Freezing a subset of the parameters using Flax and optax.masked #167

@matthias-wright I believe this is addressed by this issue here. As explained there, `optax.masked` essentially "zeros out" the gradient transformation (i.e., the gradients are not processed by the optimizer). The net result is that when we mask parameters, the gradient is not transformed before the update is applied, so the raw gradient becomes the update and the masked parameters still change. The solution, as explained in the above-mentioned issue, is not to use `optax.masked` but instead to use `optax.multi_transform`, where the updates for the frozen parameters are set to zero. As a bonus, if the training loop is jit-ed, the zeroed-out gradients are not even computed!
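
Here is a minimal sketch of that `multi_transform` approach. It assumes a recent Flax where `init` returns a plain dict; the two-layer `Model`, the layer names `frozen_layer`/`trainable_layer`, and the learning rate are made up for illustration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import traverse_util


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(16, name='frozen_layer')(x))
        return nn.Dense(1, name='trainable_layer')(x)


model = Model()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Label each parameter leaf 'frozen' or 'trainable' based on its path.
flat = traverse_util.flatten_dict(params, sep='/')
labels = traverse_util.unflatten_dict(
    {path: ('frozen' if 'frozen_layer' in path else 'trainable') for path in flat},
    sep='/')

# multi_transform routes each labeled subtree through its own transformation;
# set_to_zero() replaces the frozen parameters' updates with zeros.
tx = optax.multi_transform(
    {'trainable': optax.adam(1e-3), 'frozen': optax.set_to_zero()},
    labels)
opt_state = tx.init(params)

# One update step: the frozen_layer parameters are left untouched.
def loss_fn(p):
    return jnp.sum(model.apply(p, jnp.ones((1, 8))) ** 2)

grads = jax.grad(loss_fn)(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

Because `set_to_zero()` produces zero updates for the frozen subtree, wrapping the step in `jax.jit` lets XLA prune the corresponding gradient computation entirely, which is where the bonus mentioned above comes from.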
