Apply L0 regularization (from "Learning Sparse Neural Networks through L0 Regularization") to the CIFAR10 image classification task with a GoogleNet model.
The L0 gate layer class, `sparsereg.model.basic_l0_blocks.L0Gate`, is modified from the authors' repo.
The CIFAR10 training code, including the model structure and the dataloader, is also modified from Tutorial 4: Inception, ResNet and DenseNet. I refactored it into Hydra and Lightning format (my own style).
Main Package Versions
hydra-core 1.2.0
pytorch-lightning 1.8.4.post0
torch 1.10.1+cu102
torchaudio 0.10.1+cu102
torchmetrics 0.11.0
torchvision 0.11.2+cu102
Define
Without L0 Gating Layer
- Training from scratch
python train.py ckpt_path=null resume_training=false without_using_gate=true
- Resume training from ckpt
python train.py ckpt_path=path_of_ckpt/last.ckpt resume_training=true without_using_gate=true
With L0 Gating Layer
- Training from scratch
python train.py ckpt_path=null resume_training=false without_using_gate=false lambda_l0=1.0
- Initialize the NN parameters from a pre-trained model (without L0), then fine-tune with L0 regularization
python train.py ckpt_path=path_of_ckpt/last.ckpt resume_training=false without_using_gate=false lambda_l0=20.0 droprate_init=0.05
Setting droprate_init=0.05 starts training with the gates open (since the pre-trained model was trained without L0). Moreover, we make lambda_l0 (the weight of the regularization term) larger than in "training from scratch", because the entropy loss is already well trained by the pre-trained model, so we can put more emphasis on the L0 loss.
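The following minimal pure-Python sketch shows why droprate_init=0.05 means "gate open". It makes two assumptions that are not confirmed for this repo's L0Gate. The first is the hard concrete constants of Louizos et al. (β = 2/3, γ = -0.1, ζ = 1.1). The second is the qz_loga initialization log(1 - droprate_init) - log(droprate_init), as used in the paper authors' reference code. The helper names are hypothetical.

```python
import math

# Hard concrete constants from Louizos et al. (2018); assumed to match this repo's L0Gate.
BETA = 2.0 / 3.0   # temperature
GAMMA = -0.1       # lower stretch limit
ZETA = 1.1         # upper stretch limit


def init_log_alpha(droprate_init: float) -> float:
    """Mean initial log_alpha (qz_loga) for a given droprate_init (hypothetical helper).

    Mirrors the paper authors' reference init log(1 - p) - log(p);
    the real layer may add small Gaussian noise on top of this mean.
    """
    return math.log(1.0 - droprate_init) - math.log(droprate_init)


def prob_gate_open(log_alpha: float) -> float:
    """P(z != 0) for a hard concrete gate, i.e. its expected L0 contribution."""
    x = log_alpha - BETA * math.log(-GAMMA / ZETA)
    return 1.0 / (1.0 + math.exp(-x))


for p in (0.5, 0.05):
    la = init_log_alpha(p)
    print(f"droprate_init={p}: log_alpha={la:.3f}, P(gate open)={prob_gate_open(la):.3f}")
```

Under these assumptions, droprate_init=0.05 gives P(gate open) ≈ 0.99, compared with ≈ 0.83 for droprate_init=0.5. The fine-tuned network therefore starts very close to the pre-trained "NO L0" model.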
python test.py ckpt_path=path_of_ckpt/last.ckpt
tensorboard --logdir ./outputs/train/tblog/lightning_logs/
Recap the Loss function:
We know that
GoogleNet (sorted by Sparsity) | Validation Accuracy | Test Accuracy | Sparsity |
---|---|---|---|
(1) NO L0 | 90.12% | 89.57% | 0.0 |
(2) NN and L0 train from scratch | 88.66% | 87.87% | 0.06 |
(3) Init NN by "NO L0" then fine-tune with L0, lambda=10 | 90.44% | 90.00% | 0.10 |
(4) NN and L0 train from scratch | 86.9% | 86.56% | 0.22 |
(5) Init NN by "NO L0" then fine-tune with L0, lambda=20 | 88.8% | 88.62% | 0.39 |
(6) NN and L0 train from scratch | 83.2% | 82.79% | 0.55 |
(7) Init NN by "NO L0" then fine-tune with L0, lambda=30 | 86.5% | 85.88% | 0.64 |
(8) Init NN by "NO L0" then fine-tune with L0, lambda=50 | 80.78% | 80.22% | 0.815 |
- "NO L0": training from scratch with no L0 regularization term
- "NN and L0 train from scratch": the NN parameters and the L0 parameters are trained from scratch together
- "Init NN by 'NO L0' then fine-tune with L0": the NN parameters are initialized from the "NO L0" model; then, setting droprate_init=0.05 (to start training with the gates open), we fine-tune the NN parameters and L0 together
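Clearly, pruning more parameters costs more accuracy. So we can fine-tune
Moreover, comparing (5) with (2), we can see that with a good initialization of the NN parameters, fine-tuning with L0 reaches both higher sparsity (0.39 vs. 0.06) and higher test accuracy (88.62% vs. 87.87%).
The same holds for (7) vs. (6): sparsity 0.64 vs. 0.55, test accuracy 85.88% vs. 82.79%.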
Finally, we show the values of
The drawback of the L0 implementation in this repo is that training with L0 regularization seems about 2 times slower than training without L0; this may be the next thing to improve. Moreover, I think unstructured pruning is a good way to get a lower compression rate while keeping similar accuracy.
Let
Combined with entropy loss,
However, the L0 regularization term is not differentiable. To cope with this issue, we apply a mask random variable
Therefore, we can rewrite
Also, we should rewrite the entropy loss,
To find the gradient w.r.t.
That's it! NN parameters qz_loga
in code and
Please see [L0 Regularization 詳細攻略] ("L0 Regularization: A Detailed Guide") for a detailed explanation of the math under the hood. Sorry, it is only in Mandarin.
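For reference, these are the main equations of Louizos et al. (2018), written in the paper's notation. They are assumed, not confirmed, to match the formulas missing from the derivation above. The symbols map as follows:

- $\theta$ is the NN parameters.
- $\log\alpha_j$ is the gate parameters (`qz_loga` in code).
- $\lambda$ is `lambda_l0`.
- $\mathcal{L}$ is the cross-entropy ("entropy") loss.
- $\beta=2/3$, $\gamma=-0.1$, $\zeta=1.1$ are the paper's default hard concrete constants.

```latex
\begin{aligned}
&\text{Objective:} && \mathcal{R}(\theta,\log\alpha) = \mathbb{E}_{u\sim\mathcal{U}(0,1)}\Big[\frac{1}{N}\sum_{i=1}^{N}\mathcal{L}\big(h(x_i;\theta\odot z),\,y_i\big)\Big] + \lambda\sum_{j}\Pr(z_j\neq 0)\\
&\text{Gate sample:} && s_j = \operatorname{Sigmoid}\!\Big(\frac{\log u_j-\log(1-u_j)+\log\alpha_j}{\beta}\Big),\quad \bar s_j = s_j(\zeta-\gamma)+\gamma,\quad z_j=\min\big(1,\max(0,\bar s_j)\big)\\
&\text{Expected } L_0: && \Pr(z_j\neq 0) = \operatorname{Sigmoid}\!\big(\log\alpha_j-\beta\log(-\gamma/\zeta)\big)\\
&\text{Test time:} && \hat z_j = \min\!\big(1,\max\big(0,\operatorname{Sigmoid}(\log\alpha_j)(\zeta-\gamma)+\gamma\big)\big)
\end{aligned}
```

Because $z_j$ is a deterministic, almost-everywhere differentiable function of $\log\alpha_j$ once $u_j$ is fixed, the cross-entropy term can be back-propagated to $\log\alpha$ with the reparameterization trick. The expected-$L_0$ term is already in closed form.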
We prune the output channels of a convolution layer:
Then we apply these L0Gate layers to prune channels in the inception block:
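When a gate closes an output channel of a convolution, that filter's whole weight slice (in_channels × k × k) and its bias can be removed. The hypothetical helper `pruned_param_fraction` below counts the removed parameters for a list of gated convolutions, such as the convs inside one inception block. Whether the Sparsity column above is computed this way, rather than as the fraction of closed gates, is an assumption. The toy channel counts are made up.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GatedConv:
    """Hypothetical description of one gated convolution (not a class from the repo)."""
    in_channels: int
    kernel_size: int
    gates: List[float]  # test-time gate per output channel; 0.0 means pruned
    bias: bool = True

    def params_per_filter(self) -> int:
        # One output filter holds in_channels * k * k weights, plus an optional bias.
        return self.in_channels * self.kernel_size ** 2 + (1 if self.bias else 0)


def pruned_param_fraction(convs: List[GatedConv]) -> float:
    """Fraction of conv parameters removed by closed output-channel gates."""
    total = pruned = 0
    for conv in convs:
        per_filter = conv.params_per_filter()
        total += per_filter * len(conv.gates)
        pruned += per_filter * sum(1 for g in conv.gates if g == 0.0)
    return pruned / total if total else 0.0


# Toy branch: a 1x1 reduction conv (4 outputs) feeding a 3x3 conv (2 outputs).
block = [
    GatedConv(in_channels=4, kernel_size=1, gates=[1.0, 0.0, 0.0, 0.7]),
    GatedConv(in_channels=4, kernel_size=3, gates=[1.0, 0.4]),
]
print(pruned_param_fraction(block))
```

For the toy branch this gives 10/94 ≈ 0.106. The count is deliberately conservative. Closing two outputs of the 1x1 conv also makes the matching input slices of the following 3x3 conv removable, and this helper does not count those.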
Finally, GoogleNet is constructed from these gated inception blocks.
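Once all gated blocks are in place, training minimizes the entropy loss plus lambda_l0 times the expected L0 of every gate. The pure-Python sketch below shows the two gradient paths for log α (`qz_loga`). The expected-L0 term has a closed-form derivative. The cross-entropy term reaches log α through the reparameterized sample z, with u held fixed. `total_loss` is hypothetical and assumes an unnormalized sum over gate layers; the repo may scale the penalty differently. The constants are the paper's defaults.

```python
import math

BETA, GAMMA, ZETA = 2.0 / 3.0, -0.1, 1.1  # paper defaults; assumed to match L0Gate
SHIFT = BETA * math.log(-GAMMA / ZETA)    # constant offset inside P(z != 0)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def expected_l0(log_alphas):
    """Differentiable surrogate for the number of open gates in one layer."""
    return sum(sigmoid(la - SHIFT) for la in log_alphas)


def expected_l0_grad(log_alphas):
    """d expected_l0 / d log_alpha_j = sigmoid'(log_alpha_j - SHIFT)."""
    return [sigmoid(la - SHIFT) * (1.0 - sigmoid(la - SHIFT)) for la in log_alphas]


def total_loss(ce_loss, gate_layers, lambda_l0):
    """Hypothetical: cross-entropy + lambda_l0 * expected L0 summed over gate layers."""
    return ce_loss + lambda_l0 * sum(expected_l0(layer) for layer in gate_layers)


def z_given_u(log_alpha, u):
    """Reparameterized hard concrete sample: deterministic in log_alpha once u is fixed."""
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    return min(1.0, max(0.0, s * (ZETA - GAMMA) + GAMMA))


def dz_dlog_alpha(log_alpha, u):
    """Pathwise gradient of z w.r.t. log_alpha; zero where the clamp is active."""
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    s_bar = s * (ZETA - GAMMA) + GAMMA
    if s_bar <= 0.0 or s_bar >= 1.0:
        return 0.0
    return (ZETA - GAMMA) * s * (1.0 - s) / BETA


layers = [[2.9, 3.0, -4.0], [0.5, 1.0]]  # made-up qz_loga values for two gate layers
print(total_loss(ce_loss=0.35, gate_layers=layers, lambda_l0=1.0))
```

In a real PyTorch layer, autograd computes both derivatives. They are written out here only to make the two gradient paths explicit. A gate whose clamp is active (z exactly 0 or 1) passes no cross-entropy gradient to its log α for that sample, but it still receives the expected-L0 gradient.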