bobondemon/l0_regularization_practice

Apply L0 regularization (Learning Sparse Neural Networks through L0 Regularization) to the CIFAR10 image classification task with a GoogleNet model.

The L0 gate layer class, sparsereg.model.basic_l0_blocks.L0Gate, is modified from the author's repo. The CIFAR10 training part, including the model structure and dataloader, is modified from TUTORIAL 4: INCEPTION, RESNET AND DENSENET; I just refactored it into hydra and lightning format (my style).

Usage

Main Package Versions

hydra-core             1.2.0
pytorch-lightning      1.8.4.post0
torch                  1.10.1+cu102
torchaudio             0.10.1+cu102
torchmetrics           0.11.0
torchvision            0.11.2+cu102

How to Train

Let $\theta$ and $q$ be the parameters of the NN and the L0 gating layer, respectively.

Without L0 Gating Layer

$$\mathcal{L}(\theta) = \mathcal{L}_E(\theta)$$
  • Training from scratch
    python train.py ckpt_path=null resume_training=false without_using_gate=true 
    
  • Resume training from ckpt
    python train.py ckpt_path=path_of_ckpt/last.ckpt resume_training=true without_using_gate=true
    

With L0 Gating Layer

$$\mathcal{L}(\theta, q) = \mathcal{L}_E(\theta, q) + \lambda\,\mathcal{L}_{C_0}(q)$$
  • Training from scratch
    python train.py ckpt_path=null resume_training=false without_using_gate=false lambda_l0=1.0
    
  • $\theta$ is initialized from a pre-trained model (without L0), then fine-tuned with L0 regularization
    python train.py ckpt_path=path_of_ckpt/last.ckpt resume_training=false without_using_gate=false lambda_l0=20.0 droprate_init=0.05
    
    Setting droprate_init=0.05 starts training with the gates open (since the pre-trained model was trained without L0). Moreover, we make lambda_l0 (the weight of the regularization term) larger than in "training from scratch": because the entropy loss ($\mathcal{L}_E$) is already well trained in the pre-trained model, we can emphasize the L0 loss. See the sketch below.
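For intuition, here is a minimal sketch of why a small droprate_init means the gates start open. It assumes the common initialization from the author's original L0 repo (qz_loga initialized around log(1 - droprate_init) - log(droprate_init)), which may differ in detail from this repo:

```python
import math

# Assumed initialization: qz_loga (= ln alpha) starts near
# log(1 - droprate_init) - log(droprate_init).
droprate_init = 0.05
qz_loga_init = math.log(1 - droprate_init) - math.log(droprate_init)
print(qz_loga_init)  # ~2.94: a strong bias towards the gate being ON,
                     # so fine-tuning starts with (almost) all channels passing through
```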

How to Test

python test.py ckpt_path=path_of_ckpt/last.ckpt

Monitoring with TensorBoard

tensorboard --logdir ./outputs/train/tblog/lightning_logs/

Results

Recall the loss function:

$$\mathcal{L}(\theta, q) = \mathcal{L}_E(\theta, q) + \lambda\,\mathcal{L}_{C_0}(q)$$

We know that $\lambda$ controls the importance of the regularization term and hence the sparsity of the model. We define sparsity as:

$$\text{sparsity} = \frac{\text{number of ZERO parameters}}{\text{number of ALL parameters}}$$
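For reference, a minimal sketch (an illustrative helper, not part of this repo's API) of measuring this sparsity for a PyTorch model:

```python
import torch

def parameter_sparsity(model: torch.nn.Module) -> float:
    """Fraction of parameters that are exactly zero."""
    total = sum(p.numel() for p in model.parameters())
    zeros = sum((p == 0).sum().item() for p in model.parameters())
    return zeros / total
```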
| GoogleNet (sorted by Sparsity) | Validation Accuracy | Test Accuracy | Sparsity |
| --- | --- | --- | --- |
| (1) NO L0 | 90.12% | 89.57% | 0.0 |
| (2) NN θ and L0 train from scratch, lambda=0.25 | 88.66% | 87.87% | 0.06 |
| (3) Init NN by "NO L0" then fine tune with L0, lambda=10 | 90.44% | 90.00% | 0.10 |
| (4) NN θ and L0 train from scratch, lambda=0.5 | 86.9% | 86.56% | 0.22 |
| (5) Init NN by "NO L0" then fine tune with L0, lambda=20 | 88.8% | 88.62% | 0.39 |
| (6) NN θ and L0 train from scratch, lambda=1.0 | 83.2% | 82.79% | 0.55 |
| (7) Init NN by "NO L0" then fine tune with L0, lambda=30 | 86.5% | 85.88% | 0.64 |
| (8) Init NN by "NO L0" then fine tune with L0, lambda=50 | 80.78% | 80.22% | 0.815 |
  • "NO L0": training from scratch with NO L0 regularization term
  • "NN θ and L0 train from scratch": the NN parameters $\theta$ and the L0 parameters $\ln\alpha$ are trained from scratch together
  • "Init NN by 'NO L0' then fine tune with L0": $\theta$ is initialized from the "NO L0" model; then, by setting droprate_init=0.05 (start training with the gates open), we fine-tune $\theta$ and the L0 parameters $\ln\alpha$ together

It is clear that the more parameters are pruned, the more accuracy suffers, so we can tune $\lambda$ to reach the desired compression rate (sparsity).

Moreover, by comparing (5) and (2), we can see that with a good initialization of the NN parameters $\theta$, we reach a higher sparsity at similar accuracy compared with training from scratch.

The same holds when comparing (7) with (6).

Finally, we show the values of $\mathcal{L}_{C_0}$ in (2), (4), and (6) during training with different $\lambda$ below:

A drawback of the L0 implementation in this repo is that training with L0 regularization is roughly 2x slower than training without it; this could be the next step of improvement. Moreover, I think unstructured pruning is a good way to get a lower compression rate while keeping similar accuracy.

Introduction to L0 Regularization

Motivation

Let $\theta$ be the parameters of our model; we hope that only a small number of them are non-zero. The zero-norm measures this number, so the L0 regularization term, $\mathcal{L}_{C_0}$, can be defined as:

$$\mathcal{L}_{C_0}(\theta) = \|\theta\|_0 = \sum_{j=1}^{|\theta|} \mathbb{I}\,[\theta_j \neq 0]$$

Combined with the entropy loss, $\mathcal{L}_E$, this forms the final loss $\mathcal{L}$:

$$\mathcal{L}_E(\theta) = \frac{1}{N}\sum_{i=1}^{N} \mathcal{L}\big(\mathrm{NN}(x_i;\theta),\, y_i\big), \qquad \mathcal{L}(\theta) = \mathcal{L}_E(\theta) + \lambda\,\mathcal{L}_{C_0}(\theta)$$

However, the L0 regularization term is not differentiable. To cope with this issue, we introduce mask random variables $Z = (Z_1, \dots, Z_{|\theta|})$, where each $Z_j$ follows a Bernoulli distribution with parameter $q_j$.

Therefore, we can rewrite $\mathcal{L}_{C_0}$ in closed form:

$$\mathcal{L}_{C_0}(\theta, q) = \mathbb{E}_{Z \sim \mathrm{Bernoulli}(q)}\!\left[\sum_{j=1}^{|\theta|} \mathbb{I}\,[\theta_j Z_j \neq 0]\right] = \mathbb{E}_{Z \sim \mathrm{Bernoulli}(q)}\!\left[\sum_{j=1}^{|\theta|} Z_j\right] = \sum_{j=1}^{|\theta|} q_j$$
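In code, this closed form means the L0 penalty is simply the sum of the gates' "on" probabilities. A tiny illustrative snippet (the variable names are not this repo's API):

```python
import torch

q = torch.tensor([0.9, 0.2, 0.05, 0.7], requires_grad=True)  # Bernoulli parameters of 4 gates
expected_l0 = q.sum()  # E[sum_j I[theta_j Z_j != 0]] = sum_j q_j (assuming theta_j != 0)
print(expected_l0)     # tensor(1.8500, grad_fn=<SumBackward0>)
```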

Also, we should rewrite the entropy loss, $\mathcal{L}_E$, accordingly:

$$\mathcal{L}_E(\theta, q) = \mathbb{E}_{Z \sim \mathrm{Bernoulli}(q)}\!\left[\frac{1}{N}\sum_{i=1}^{N} \mathcal{L}\big(\mathrm{NN}(x_i;\theta \odot Z),\, y_i\big)\right], \qquad \mathcal{L}(\theta, q) = \mathcal{L}_E(\theta, q) + \lambda\,\mathcal{L}_{C_0}(q)$$

Finding the gradient of the entropy loss w.r.t. $q$ is not trivial, since we cannot simply exchange the expectation and the differentiation. Fortunately, by using the Gumbel-Softmax re-parameterization trick, we can make the random sampling (the expectation over the Bernoulli distribution) independent of $q$, so the entropy loss becomes differentiable.
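A minimal sketch of this re-parameterization, following the hard-concrete gate of the L0 paper (the constants beta, gamma, zeta below are the paper's usual defaults; the function names are illustrative, not this repo's exact API):

```python
import math
import torch

def sample_gate(qz_loga, beta=2/3, gamma=-0.1, zeta=1.1):
    """Re-parameterized hard-concrete gate: the randomness comes from u ~ Uniform(0, 1),
    which is independent of the learnable parameter qz_loga (= ln alpha)."""
    u = torch.rand_like(qz_loga)
    s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + qz_loga) / beta)
    s_stretched = s * (zeta - gamma) + gamma   # stretch to (gamma, zeta)
    return torch.clamp(s_stretched, 0.0, 1.0)  # hard-rectify into [0, 1]

def expected_l0(qz_loga, beta=2/3, gamma=-0.1, zeta=1.1):
    """Differentiable L0 penalty: sum over gates of P(gate != 0)."""
    return torch.sigmoid(qz_loga - beta * math.log(-gamma / zeta)).sum()
```

Since u does not depend on qz_loga, gradients flow through the sigmoid into qz_loga, which is exactly what allows backpropagation to update the mask parameters.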

That's it! The NN parameters $\theta$ and the mask's parameters (qz_loga in code, $\ln\alpha$ in the following figures) can now be updated by backpropagation!

Please see [L0 Regularization 詳細攻略] for a detailed explanation of the math under the hood. Sorry, it is only available in Mandarin.

Structured pruning with the L0 norm

We prune the output channels of a convolution layer:
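As a rough sketch of the idea (illustrative only; the actual layer in this repo is sparsereg.model.basic_l0_blocks.L0Gate), each output channel of the convolution is multiplied by its own gate, so channels whose gates collapse to zero can be removed after training:

```python
import torch
import torch.nn as nn

class GatedConv2d(nn.Module):
    """Illustrative sketch of structured pruning: per-output-channel hard-concrete gates."""
    def __init__(self, in_ch, out_ch, beta=2/3, gamma=-0.1, zeta=1.1, **conv_kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_kwargs)
        self.qz_loga = nn.Parameter(torch.zeros(out_ch))  # one gate per output channel
        self.beta, self.gamma, self.zeta = beta, gamma, zeta

    def forward(self, x):
        if self.training:  # sample gates with the re-parameterization trick
            u = torch.rand_like(self.qz_loga)
            s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.qz_loga) / self.beta)
        else:              # deterministic gates at test time
            s = torch.sigmoid(self.qz_loga)
        z = torch.clamp(s * (self.zeta - self.gamma) + self.gamma, 0.0, 1.0)
        return self.conv(x) * z.view(1, -1, 1, 1)

# Example: conv = GatedConv2d(3, 64, kernel_size=3, padding=1)
```

Gating whole output channels (rather than individual weights) is what makes the pruning structured: once a gate settles at zero, the corresponding filter can be physically removed from the network.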

Then we apply these L0Gate layers to prune channels in the inception block:

Finally, GoogleNet is constructed from these gated inception blocks.
