This repository contains a PyTorch implementation of the paper *The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks* by Jonathan Frankle and Michael Carbin. The code is structured so that it can be adapted to other models and datasets.
Install the requirements:

pip3 install -r requirements.txt

Run the code, for example:

python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35
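As a rough, framework-free illustration of what a run like the one above does, the sketch below applies iterative magnitude pruning to a flat list of weights. Each cycle prunes `prune_percent`% of the still-alive weights with the smallest magnitude, then either rewinds the survivors to their original initial values (`lt`) or gives them fresh random values (`reinit`). The function names, the flat-list representation, the Gaussian initialization and the noise that stands in for training are all assumptions made for illustration; `main.py` operates on PyTorch tensors and its exact pruning rule may differ.

```python
import random
from typing import List


def prune_by_percent(weights: List[float], mask: List[int], percent: float) -> List[int]:
    """Return a new mask with `percent`% of the still-alive weights set to 0.

    Alive weights are ranked by absolute value and the smallest ones are pruned.
    """
    alive = [i for i, m in enumerate(mask) if m]
    n_prune = int(len(alive) * percent / 100)
    # Alive indices ordered from smallest to largest magnitude.
    order = sorted(alive, key=lambda i: abs(weights[i]))
    new_mask = list(mask)
    for i in order[:n_prune]:
        new_mask[i] = 0
    return new_mask


def reset_weights(init_weights: List[float], mask: List[int], prune_type: str,
                  rng: random.Random) -> List[float]:
    """Prepare the (masked) weights for the next training cycle."""
    if prune_type == "lt":
        # Lottery ticket: surviving weights go back to their original initial values.
        return [w0 * m for w0, m in zip(init_weights, mask)]
    if prune_type == "reinit":
        # Baseline: same sparsity mask, freshly drawn random values (assumed Gaussian).
        return [rng.gauss(0.0, 0.1) * m for m in mask]
    raise ValueError(f"unknown prune_type: {prune_type!r}")


def remaining_fraction(prune_percent: float, cycles: int) -> float:
    """Fraction of weights left if every cycle prunes prune_percent% of the survivors."""
    return (1 - prune_percent / 100) ** cycles


if __name__ == "__main__":
    rng = random.Random(0)
    init = [rng.gauss(0.0, 0.1) for _ in range(1000)]  # hypothetical initial weights
    mask = [1] * len(init)
    weights = list(init)
    for cycle in range(3):
        # Stand-in for training: a real run would optimize `weights` here.
        trained = [w + rng.gauss(0.0, 0.01) * m for w, m in zip(weights, mask)]
        mask = prune_by_percent(trained, mask, percent=10)
        weights = reset_weights(init, mask, "lt", rng)
        # prints "0 0.9", then "1 0.81", then "2 0.729"
        print(cycle, sum(mask) / len(mask))
```

Under the same assumption (every cycle removes `prune_percent`% of the remaining weights), `remaining_fraction(10, 35)` is about 0.025, so roughly 2.5% of the weights would survive 35 cycles at 10%. Integer rounding of the per-cycle prune count makes the real mask density drift slightly from this closed form.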
- `--prune_type`: Type of pruning.
  - Options: `lt` (Lottery Ticket Hypothesis), `reinit` (random reinitialization)
  - Default: `lt`
- `--arch_type`: Type of architecture.
  - Options: `fc1` (simple fully connected network), `lenet5` (LeNet5), `AlexNet` (AlexNet), `resnet18` (ResNet18), `vgg16` (VGG16)
  - Default: `fc1`
- `--dataset`: Choice of dataset.
  - Options: `mnist`, `fashionmnist`, `cifar10`, `cifar100`
  - Default: `mnist`
- `--prune_percent`: Percentage of weights to be pruned after each cycle.
  - Default: `10`
- `--prune_iterations`: Number of pruning cycles to perform.
  - Default: `35`
- `--lr`: Learning rate.
  - Default: `1.2e-3`
- `--batch_size`: Batch size.
  - Default: `60`
- `--end_iter`: Number of epochs.
  - Default: `100`
- `--print_freq`: Frequency for printing accuracy and loss.
  - Default: `1`
- `--valid_freq`: Frequency of validation.
  - Default: `1`
- `--gpu`: Which GPU the program should use.
  - Default: `0`
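If you want to reuse these flags in your own script, they map directly onto Python's standard `argparse` module. The sketch below is a hypothetical parser that reproduces the documented names and defaults; it is not copied from `main.py`, and the value types (for example `int` for `--prune_percent` and a string for `--gpu`) as well as the `choices` restrictions are assumptions.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    p = argparse.ArgumentParser(description="Lottery Ticket Hypothesis experiments")
    p.add_argument("--prune_type", default="lt", choices=["lt", "reinit"],
                   help="lt: Lottery Ticket Hypothesis, reinit: random reinitialization")
    p.add_argument("--arch_type", default="fc1",
                   choices=["fc1", "lenet5", "AlexNet", "resnet18", "vgg16"])
    p.add_argument("--dataset", default="mnist",
                   choices=["mnist", "fashionmnist", "cifar10", "cifar100"])
    p.add_argument("--prune_percent", type=int, default=10,
                   help="percentage of weights pruned after each cycle")
    p.add_argument("--prune_iterations", type=int, default=35,
                   help="number of pruning cycles")
    p.add_argument("--lr", type=float, default=1.2e-3, help="learning rate")
    p.add_argument("--batch_size", type=int, default=60)
    p.add_argument("--end_iter", type=int, default=100, help="number of epochs")
    p.add_argument("--print_freq", type=int, default=1)
    p.add_argument("--valid_freq", type=int, default=1)
    p.add_argument("--gpu", type=str, default="0", help="which GPU to use")
    return p


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse `argv` (or sys.argv[1:] when None) with the sketch parser."""
    return build_parser().parse_args(argv)


if __name__ == "__main__":
    args = parse(["--prune_type=lt", "--arch_type=fc1", "--dataset=mnist",
                  "--prune_percent=10", "--prune_iterations=35"])
    print(args.prune_type, args.arch_type, args.prune_percent)  # prints: lt fc1 10
```

Using `choices` makes a typo such as `--prune_type=lottery` fail immediately with a usage message instead of silently falling through an `if`/`elif` chain.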
- Adding a new architecture:
  - For example, if you want to add an architecture named `new_model` with `mnist` dataset compatibility:
    - Go to the `/archs/mnist/` directory and create a file `new_model.py`.
    - Paste your PyTorch-compatible model inside `new_model.py`.
    - IMPORTANT: Make sure the input size, number of classes, number of channels and batch size in your `new_model.py` match the corresponding dataset that you are adding (in this case, `mnist`).
    - Open `main.py`, go to line 36 and look for the comment `# Data Loader`. Find your corresponding dataset (in this case, `mnist`) and add `new_model` at the end of the line `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet`, so that it reads `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet, new_model`.
    - Go to line 82 and add the following: `elif args.arch_type == "new_model": model = new_model.new_model_name().to(device)`. Here, `new_model_name()` is the name of the model class that you have defined inside `new_model.py`.
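Each architecture added this way needs another `elif` branch in `main.py`. If you would rather not edit that chain by line number, a name-to-constructor registry is one common alternative. The sketch below is hypothetical and framework-free: `MODEL_REGISTRY`, `register`, `build_model` and `NewModelName` are made-up names, and in this repository the constructor would return a PyTorch `nn.Module` that still needs `.to(device)`.

```python
from typing import Callable, Dict

# Maps an --arch_type value to a zero-argument model constructor.
MODEL_REGISTRY: Dict[str, Callable[[], object]] = {}


def register(name: str) -> Callable[[Callable[[], object]], Callable[[], object]]:
    """Decorator that makes a constructor selectable by its --arch_type name."""
    def wrap(factory: Callable[[], object]) -> Callable[[], object]:
        if name in MODEL_REGISTRY:
            raise ValueError(f"arch_type {name!r} is already registered")
        MODEL_REGISTRY[name] = factory
        return factory
    return wrap


def build_model(arch_type: str) -> object:
    """Look up and instantiate the model, failing loudly on unknown names."""
    try:
        factory = MODEL_REGISTRY[arch_type]
    except KeyError:
        known = ", ".join(sorted(MODEL_REGISTRY))
        raise ValueError(f"unknown arch_type {arch_type!r}; known: {known}") from None
    return factory()


@register("new_model")
class NewModelName:
    """Hypothetical placeholder for the model defined in archs/mnist/new_model.py."""

    def __init__(self) -> None:
        self.in_features = 1 * 28 * 28  # MNIST: 1 channel, 28x28 pixels
        self.num_classes = 10


if __name__ == "__main__":
    model = build_model("new_model")
    print(type(model).__name__, model.in_features)  # prints: NewModelName 784
```

With this pattern, adding an architecture would mean decorating its class and importing its module once; a single `build_model(args.arch_type)` call would then take the place of the `elif` chain.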
- Adding a new dataset:
  - For example, if you want to add a dataset named `new_dataset` with `fc1` architecture compatibility:
    - Go to `/archs` and create a directory named `new_dataset`.
    - Go to `/archs/new_dataset/` and add a file named `fc1.py`, or copy it from an existing dataset folder.
    - IMPORTANT: Make sure the input size, number of classes, number of channels and batch size in your `fc1.py` match the corresponding dataset that you are adding (in this case, `new_dataset`).
    - Open `main.py`, go to line 58 and add the following:

          elif args.dataset == "new_dataset":
              traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)
              testdataset = datasets.new_dataset('../data', train=False, transform=transform)
              from archs.new_dataset import fc1

      Note that, as of now, you can only add datasets that are natively available in PyTorch (i.e. in `torchvision.datasets`); `datasets.new_dataset` stands for the corresponding torchvision dataset class.
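The IMPORTANT step above is the usual source of shape errors: for a fully connected model such as `fc1`, the first layer has to accept the flattened image and the last layer has to produce one output per class. The hypothetical helper below records the standard torchvision image shapes of the four datasets this repository already supports and checks a model's dimensions against them; `DATASET_SPECS`, `fc_input_features` and `check_model_dims` are illustrative names, and a `new_dataset` entry would have to be filled in with that dataset's real shape.

```python
from typing import Dict, NamedTuple


class DatasetSpec(NamedTuple):
    channels: int
    height: int
    width: int
    num_classes: int


# Image shapes of the torchvision datasets supported by this repository.
DATASET_SPECS: Dict[str, DatasetSpec] = {
    "mnist": DatasetSpec(1, 28, 28, 10),
    "fashionmnist": DatasetSpec(1, 28, 28, 10),
    "cifar10": DatasetSpec(3, 32, 32, 10),
    "cifar100": DatasetSpec(3, 32, 32, 100),
    # "new_dataset": DatasetSpec(...),  # fill in the real shape of your dataset
}


def fc_input_features(dataset: str) -> int:
    """Number of inputs a fully connected first layer needs for a flattened image."""
    spec = DATASET_SPECS[dataset]
    return spec.channels * spec.height * spec.width


def check_model_dims(dataset: str, in_features: int, num_classes: int) -> None:
    """Raise ValueError if a model's first/last layer sizes do not fit the dataset."""
    expected_in = fc_input_features(dataset)
    expected_out = DATASET_SPECS[dataset].num_classes
    if in_features != expected_in:
        raise ValueError(f"{dataset}: expected {expected_in} input features, got {in_features}")
    if num_classes != expected_out:
        raise ValueError(f"{dataset}: expected {expected_out} classes, got {num_classes}")


if __name__ == "__main__":
    for name, spec in DATASET_SPECS.items():
        print(name, fc_input_features(name), spec.num_classes)
```

Running such a check once when the model is built turns a shape mismatch into a readable error instead of a matrix-multiplication failure deep inside the first forward pass.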
- Generating combined plots:
  - Go to `combine_plots.py` and add/remove the datasets/archs whose combined plot you want to generate (assuming that you have already run `main.py` for those datasets/archs and produced the weights).
  - Run `python3 combine_plots.py`.
  - Go to `/plots/lt/combined_plots/` to see the graphs.
Kindly raise an issue if you have any problem with the instructions.
Dataset | fc1 | LeNet5 | AlexNet | VGG16 | ResNet18 |
---|---|---|---|---|---|
MNIST | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
CIFAR10 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
FashionMNIST | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
CIFAR100 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
Lottery-Ticket-Hypothesis-in-Pytorch
├── archs
│ ├── cifar10
│ │ ├── AlexNet.py
│ │ ├── densenet.py
│ │ ├── fc1.py
│ │ ├── LeNet5.py
│ │ ├── resnet.py
│ │ └── vgg.py
│ ├── cifar100
│ │ ├── AlexNet.py
│ │ ├── fc1.py
│ │ ├── LeNet5.py
│ │ ├── resnet.py
│ │ └── vgg.py
│ └── mnist
│ ├── AlexNet.py
│ ├── fc1.py
│ ├── LeNet5.py
│ ├── resnet.py
│ └── vgg.py
├── combine_plots.py
├── dumps
├── main.py
├── plots
├── README.md
├── requirements.txt
├── saves
└── utils.py
Parts of the code were borrowed from ktkth5.
Open a new issue or submit a pull request if you face any difficulty with the code base or want to contribute to it.