sparsity regularisation #235
-
Hi, how can I add an L1 regularization to a subset of parameters?
-
Hi! Thanks a lot for the question!
The `optax.masked` wrapper can be used to transform only a subset of parameters using a mask (a pytree of booleans, or a function of the params that returns one). The docstring of `optax.masked` has an example with L2 regularisation (using `add_decayed_weights`). In the case of weight decay, using a mask is so common that `add_decayed_weights` has a `mask` option, which uses `optax.masked` under the hood.

As far as I am aware, optax currently does not have a gradient transformation for L1 regularisation, but this should be easy to implement by mirroring what happens in the L2 case (`optax.add_decayed_weights`). We should definitely add the L1 functionality too, though; would you be keen to implement it and file a PR? No worrie…

I hope this helps! Let me know if you have any questions!