optax.mask does not play well with flax FrozenParams #160
Comments
This is intended behaviour, and I agree that it is annoying :). There's some discussion (google/flax#1223) about changing Flax to return normal dictionaries instead of `FrozenDict`s. In my own code, I immediately unfreeze parameters after initialising Flax modules. Does that help? I would prefer not to make a special case here for `FrozenDict`s, as it would break the mental model of the input label tree having the same structure (which includes type) as the actual params.
@PhilipVinc can this issue be closed? It seems like progress on this would come from the Flax side, not Optax.
@PhilipVinc I get this error when using …
Flax now uses regular dicts (see google/flax#3193). I'll close this issue, but feel free to open a new one if needed.
When the params are a Flax `FrozenDict`, the mask must also be returned (or given) as a `FrozenDict`, which is annoying.
I'm not sure this is really an Optax bug, but could something be done to alleviate this?
This also shows up in `multi_transform`, where the fix is less obvious: because it internally builds a dict, the only way to make it work is to unfreeze the params before passing them to the optimiser, which is inconvenient.