Freezing a subset of the parameters using Flax and optax.masked #167
Hi all, I am trying to freeze a subset of the parameters during optimization using Flax and optax.masked. Consider a simple autoencoder consisting of an encoder and a decoder, each with two conv layers:

params
  encoder
    Conv_0
      kernel
      bias
    Conv_1
      kernel
      bias
  decoder
    Conv_0
      kernel
      bias
    Conv_1
      kernel
      bias

Now let's assume we want to freeze the encoder parameters. The mask might look as follows:

params
  encoder: False
  decoder
    Conv_0
      kernel: True
      bias: True
    Conv_1
      kernel: True
      bias: True

Then the optimizer is defined as follows:

tx = optax.masked(optax.adam(learning_rate=1.0), mask)

However, when I optimize the parameters, the encoder parameters still get updated, and no error messages are thrown. Can someone tell me what I am missing? I created a Colab with a reproducible example. Thanks!
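For reference, a minimal, self-contained sketch of the setup described above. The toy `params` dict (with arbitrary shapes) and the mask construction via `flax.traverse_util.flatten_dict`/`unflatten_dict` are illustrative choices, not taken from the Colab:

```python
import jax.numpy as jnp
import optax
from flax import traverse_util

# Toy stand-in for the parameter tree shown above (shapes are arbitrary).
params = {
    'encoder': {'Conv_0': {'kernel': jnp.ones((3, 3, 1, 8)), 'bias': jnp.zeros((8,))},
                'Conv_1': {'kernel': jnp.ones((3, 3, 8, 16)), 'bias': jnp.zeros((16,))}},
    'decoder': {'Conv_0': {'kernel': jnp.ones((3, 3, 16, 8)), 'bias': jnp.zeros((8,))},
                'Conv_1': {'kernel': jnp.ones((3, 3, 8, 1)), 'bias': jnp.zeros((1,))}},
}

# Boolean mask with the same structure as `params`:
# True = apply the inner optimizer, False = mask out (here: everything under `encoder`).
flat_params = traverse_util.flatten_dict(params)
mask = traverse_util.unflatten_dict(
    {path: 'encoder' not in path for path in flat_params}
)

tx = optax.masked(optax.adam(learning_rate=1.0), mask)
opt_state = tx.init(params)
```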
Replies: 1 comment 1 reply
@matthias-wright I believe this is addressed by this issue here. As explained there, `optax.masked` essentially "zeros out" the gradient transformation (i.e. the masked gradients are not processed by the inner optimizer). The net result is that when we mask parameters, we're not transforming their gradients before applying the update, so the raw gradient becomes the update. The solution, as explained in the above-mentioned issue, is not to use `optax.masked` but instead `optax.multi_transform`, where the gradients themselves are "zeroed out". As a bonus, if the training loop is `jit`-ed, the zeroed-out gradients are not even computed!
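To make this concrete, here is a minimal sketch of the `optax.multi_transform` approach, reusing the toy `params` from the sketch above; the label names `'trainable'` and `'frozen'` are arbitrary choices for illustration:

```python
import optax
from flax import traverse_util

# Label every leaf by whether it lives under the encoder subtree.
flat_params = traverse_util.flatten_dict(params)
param_labels = traverse_util.unflatten_dict(
    {path: ('frozen' if 'encoder' in path else 'trainable') for path in flat_params}
)

tx = optax.multi_transform(
    {'trainable': optax.adam(learning_rate=1.0),
     'frozen': optax.set_to_zero()},  # updates for the frozen leaves become zeros
    param_labels,
)
opt_state = tx.init(params)
```

With this setup, `tx.update` produces zero updates for every encoder leaf, so applying the updates leaves the encoder parameters unchanged while the decoder is trained by Adam.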