
optax.masked does not play well with flax FrozenDict #160

Closed
PhilipVinc opened this issue Jul 15, 2021 · 4 comments

@PhilipVinc

The mask must be returned/given as a FrozenDict (matching the FrozenDict parameter tree), which is annoying.
I'm not sure this is really an optax bug... but could something be done to alleviate this?

This also shows up in multi_transform, where the fix is less obvious: multi_transform internally builds a plain dict, so the only way to make it work is to unfreeze the params before handing them to the optimiser, which is inconvenient (see the sketch after the traceback below).

>>> import jax.numpy as jnp
>>> import jax
>>> import optax
>>> from flax.core import freeze, unfreeze
>>> 
>>> pars = freeze({"Dense": {"kernel": jnp.zeros((2,3)), "bias": jnp.zeros((3))}, "bias":jnp.zeros(2)})
>>> op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})
>>> op.init(pars)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/filippovicentini/Documents/pythonenvs/netket_env/lib64/python3.8/site-packages/optax/_src/wrappers.py", line 311, in init_fn
    flat_params = treedef.flatten_up_to(params)
ValueError: Expected dict, got FrozenDict({
    Dense: {
        kernel: DeviceArray([[0., 0., 0.],
                     [0., 0., 0.]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
    bias: DeviceArray([0., 0.], dtype=float32),
}).
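
For completeness, a minimal sketch of the unfreeze workaround for the multi_transform case mentioned above (the "train"/"freeze" labels and the optax.set_to_zero choice are illustrative, not part of the original report):

import jax.numpy as jnp
import optax
from flax.core import freeze, unfreeze

# Same parameter tree as in the snippet above.
pars = freeze({"Dense": {"kernel": jnp.zeros((2, 3)), "bias": jnp.zeros(3)},
               "bias": jnp.zeros(2)})

# Label every leaf: update the Dense sub-tree, leave the top-level bias untouched.
labels = {"Dense": {"kernel": "train", "bias": "train"}, "bias": "freeze"}
op = optax.multi_transform({"train": optax.sgd(0.1), "freeze": optax.set_to_zero()},
                           labels)

# multi_transform builds plain dicts internally, so the FrozenDict has to be
# converted back to a plain dict before it is handed to the optimiser.
state = op.init(unfreeze(pars))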
@n2cholas
Contributor

This is intended behaviour, and I agree that it is annoying :). There's some discussion (google/flax#1223) about changing Flax to return normal dictionaries instead of Frozen ones. In my own code, I immediately unfreeze parameters after initialising Flax modules.

Does that help? I would prefer not to special-case FrozenDicts here, as it would break the mental model that the input label tree has the same structure (which includes the container type) as the actual params.
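
For reference, a sketch of that "unfreeze right after init" pattern (the nn.Dense module and input shape are just placeholders):

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.core import unfreeze

model = nn.Dense(features=3)
# Unfreeze immediately after init so the rest of the code only sees plain dicts.
params = unfreeze(model.init(jax.random.PRNGKey(0), jnp.ones((1, 2))))

# The mask can then be an ordinary dict with the same structure as the params.
op = optax.masked(optax.sgd(0.1), {"params": {"kernel": True, "bias": False}})
state = op.init(params)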

@n2cholas
Contributor

@PhilipVinc can this issue be closed? It seems like progress on this would come from the Flax side, not Optax.

@deadsoul44

@PhilipVinc I get this error when using multi_transform. What is the solution? I put unfreeze everywhere but it didn't work.

@vroulet
Collaborator

vroulet commented Feb 5, 2024

Flax now uses regular dicts per google/flax#3193. I'll close this issue, but feel free to open a new one if needed.
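
With regular dicts, the snippet from the original report works without freeze/unfreeze (a quick check, assuming a recent flax/optax):

import jax.numpy as jnp
import optax

pars = {"Dense": {"kernel": jnp.zeros((2, 3)), "bias": jnp.zeros(3)},
        "bias": jnp.zeros(2)}
op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})
state = op.init(pars)  # no FrozenDict/dict mismatch anymore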

@vroulet vroulet closed this as completed Feb 5, 2024