Should Flax return FrozenDicts or regular dicts? #1223

Closed
marcvanzee opened this issue Apr 8, 2021 · 26 comments

@marcvanzee
Collaborator

This topic is discussed regularly internally, and I feel we haven't reached a consensus here. Below are some arguments collected from users for both positions, feel free to add.

Arguments in favor of FrozenDict

  • @avital: If you use normal dicts, it is easy to mutate them, which means the behavior may differ depending on whether the function in which the modification is made is jitted or not. Example:
def f(params):
  params['conv1']['weight'] = ...
  return ...some computation over params

params = load_from_checkpoint()
print(f(params))
# now what is the value of params['conv1']['weight']?
# depending on whether f is jitted or not, you'd get different results
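
A minimal runnable sketch of the divergence described above, assuming a toy parameter dict and a stand-in computation (the original elides both with `...`). Under `jax.jit` the function is traced on a rebuilt copy of the input pytree, so the in-place write never reaches the caller's dict; in eager mode it does.

```python
import jax
import jax.numpy as jnp

def f(params):
    # In-place write to the dict that was passed in (the pattern under discussion).
    params['conv1']['weight'] = params['conv1']['weight'] * 0.0
    return jnp.sum(params['conv1']['weight'])

def fresh_params():
    # Toy stand-in for load_from_checkpoint(); shape and values are assumptions.
    return {'conv1': {'weight': jnp.ones((2, 2))}}

eager = fresh_params()
f(eager)                 # eager call: mutates the caller's dict

jitted = fresh_params()
jax.jit(f)(jitted)       # jitted call: mutates only the traced copy

eager_sum = float(jnp.sum(eager['conv1']['weight']))    # 0.0, weight was overwritten
jitted_sum = float(jnp.sum(jitted['conv1']['weight']))  # 4.0, weight is untouched
```

Both calls return the same value; only the side effect on `params` differs, which is what makes the bug easy to miss.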

Arguments in favor of regular dicts

  • @lucasb-eyer: Flax tells me "here's these precious weights, please hold them for me and give them back to me later on, but DON'T TOUCH". It begs the question: why give them to me in the first place, if I'm not supposed to do anything with them?

  • @avital: I also think it'd be better for Flax to return normal Python dicts, but still use FrozenDict within modules (via the mutable argument to apply).

@marcvanzee marcvanzee self-assigned this Apr 8, 2021
@marcvanzee marcvanzee added the Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. label Apr 8, 2021
@n2cholas
Contributor

n2cholas commented Apr 12, 2021

I think the Python saying "We're all consenting adults here" is pretty fitting. In my view, trading convenience for safety is reasonable here because JAX users should know (or will quickly come to learn) that under the JAX transformations, they should not mutate state. Since FrozenDicts are not as ergonomic as normal dicts, I tend to unfreeze them as soon as they're returned from init anyway.

Though I would prefer the user-facing API to just use dict, I wouldn't mind FrozenDict if the behaviour was closer to dict for non-mutating cases. In particular,

  • The ability to pass FrozenDict to flax.traverse_util.flatten_dict.
  • The views FrozenDict.keys() and FrozenDict.values() should have a similar style __repr__ to normal dicts so they are easy to inspect interactively in notebooks (right now they just show the whole FrozenDict)
  • The ability to update/merge with normal dicts

Explicit state management is one of my favourite aspects of Flax, as it gives me the ability to transparently manipulate modules/parameters without worrying about hidden side effects. I totally agree with @lucasb-eyer's point that it's counterproductive to provide explicit state without allowing the user to fully control it.

@jheek
Member

jheek commented Apr 13, 2021

I think the Python saying "We're all consenting adults here" is pretty fitting

Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link).
It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.

That said, I don't think FrozenDict has proven to be a very effective safety tool for avoiding this kind of error. We should probably keep using it internally to avoid accidental reference sharing, but for users it seems too big a burden, and it doesn't prevent the more common issues of closing over mutable state (typically created by the user) or using things like np.random in a jitted function.

I do think we should at least provide an easy way to clone a pytree if we allow it to contain mutable containers. Something like the following:

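import jax

def clone_pytree(xs):
  # Cloning is just an identity mapping over the leaves; tree_map rebuilds
  # every container, so the copy shares no dicts with the original.
  return jax.tree_util.tree_map(lambda x: x, xs)

def some_nested_transformation(variables):
  # Proposed home: flax.traverse_util.clone_pytree (a suggestion, hypothetical).
  my_copy = clone_pytree(variables)
  my_copy['batch_stats']['x'] += 2.
  return my_copy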
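
A small check of the cloning idea above, assuming plain nested dicts with JAX arrays as leaves. The variable names and values are illustrative only; the point is that the clone gets fresh containers, so edits to it leave the original untouched.

```python
import jax
import jax.numpy as jnp

def clone_pytree(xs):
    # Identity map over leaves; every dict along the way is rebuilt.
    return jax.tree_util.tree_map(lambda x: x, xs)

variables = {'batch_stats': {'x': jnp.zeros(3)}}
copy = clone_pytree(variables)

# Rebind the entry in the clone only.
copy['batch_stats']['x'] = copy['batch_stats']['x'] + 2.0

original_sum = float(variables['batch_stats']['x'].sum())  # still 0.0
copy_sum = float(copy['batch_stats']['x'].sum())           # 6.0
containers_shared = copy['batch_stats'] is variables['batch_stats']  # False
```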

Also we want to merge the chex and flax dataclass implementation. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent so ideally we would make these changes together.

@n2cholas
Contributor

Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link).
It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.

This is a good point: especially as a codebase grows, it can sneak past you. Personally, if FrozenDict better matched the ergonomics of dict outside of mutation, I would not see it as a burden, at least for my own use cases.

Also we want to merge the chex and flax dataclass implementation. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent so ideally we would make these changes together.

I actually quite like the immutable dataclasses, since the .replace(...) API is similar to that of namedtuples. In my view, the inconveniences that arise with FrozenDict don't happen here, since dataclasses don't have arbitrary structure and you don't generally manipulate that structure.
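
To make the comparison concrete, here is a standard-library sketch of the two update styles. `TrainState` and `StateTuple` are hypothetical names; `dataclasses.replace` is used as the stand-in for a `.replace(...)` method on a frozen dataclass.

```python
import dataclasses
from collections import namedtuple

@dataclasses.dataclass(frozen=True)
class TrainState:
    step: int
    lr: float

StateTuple = namedtuple('StateTuple', ['step', 'lr'])

s0 = TrainState(step=0, lr=0.1)
s1 = dataclasses.replace(s0, step=1)   # new object, s0 is unchanged

t0 = StateTuple(step=0, lr=0.1)
t1 = t0._replace(step=1)               # same idea for namedtuples
```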

@marcvanzee
Collaborator Author

Thanks for the input @n2cholas! After chatting with @jheek offline, the consensus is that it is indeed useful to return regular dicts, but that we block implementing this on merging the chex and flax dataclasses.

@lucasb-eyer
Contributor

Great, very happy about this decision. I'd just like to add that

See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link).

Is a complete red herring. This is about hidden global state, whereas this discussion is specifically about explicit, non-global state. It's actually more about rng design than anything else, and what we are talking about doing here is already the "better" rng design where the user explicitly is given, and trusted to correctly handle, the state.

@PhilipVinc
Contributor

Is there any further development on this?

@marcvanzee
Collaborator Author

marcvanzee commented Sep 6, 2021

Sorry for the delay -- I was on parental leave.

@jheek could you tell us whether any progress has been made on merging the chex and flax dataclasses?

@NeilGirdhar
Contributor

What does merging the dataclasses consist of? Are flax dataclasses going to be inheriting the mapping interface?

@jheek
Member

jheek commented Dec 7, 2021

The merging of dataclasses is taking much longer than originally anticipated. I'll bring this up in our next sync meeting because I think we should start to move towards allowing mutability independently of actually merging the implementations with chex.

@NeilGirdhar
Contributor

NeilGirdhar commented Dec 7, 2021

I think we should start to move towards allowing mutability

Sorry, but why would you do that?

The merging of dataclasses is taking much longer than originally anticipated.

Also, I still don't understand what this merge will consist of. Flax's dataclasses are well-designed: they are just frozen dataclasses that register as pytrees, have a field function that conveniently supports marking static fields, and add a replace method. Besides the replace method (which is just a shortcut to dataclasses.replace), this is a minimal interface.
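
The design described in this paragraph can be approximated without Flax, using only the standard `dataclasses` module and `jax.tree_util.register_pytree_node`. This is a sketch of the idea, not Flax's implementation; the `Layer` class, its fields, and the choice of `name` as the static field are assumptions.

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class Layer:
    weight: jnp.ndarray  # dynamic: exposed to JAX as a pytree leaf
    name: str            # static: carried as auxiliary data, not traced

    def replace(self, **changes):
        # Shortcut to dataclasses.replace, as described above.
        return dataclasses.replace(self, **changes)

def _flatten(layer):
    # (children, aux_data): children are leaves, aux_data is static.
    return (layer.weight,), layer.name

def _unflatten(name, children):
    return Layer(weight=children[0], name=name)

jax.tree_util.register_pytree_node(Layer, _flatten, _unflatten)

layer = Layer(weight=jnp.ones(2), name='dense')
doubled = jax.tree_util.tree_map(lambda w: w * 2, layer)  # only weight is mapped
leaves = jax.tree_util.tree_leaves(layer)
renamed = layer.replace(name='dense_1')
```

Because `name` travels as auxiliary data, `tree_map` never sees it, which is the behaviour that marking a field static is meant to give.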

Chex dataclasses are badly designed: they are not frozen, they can't mark static fields, and they unnecessarily expose the whole mapping interface, which means you can access fields as attributes or keys. They also expose a to_tuple method that is inferior to dataclasses.astuple, which supports nested dataclasses. The from_tuple method is also somewhat flimsy since it won't work with Python 3.10's new keyword-only arguments. This is not a minimal interface.

I was hoping to ditch tjax's dataclasses in favor of flax's, but if you're merging in any of chex's behavior, I won't be able to.

@jheek
Member

jheek commented Dec 7, 2021

We won't be removing features like frozen, static fields, and replace.
We do however want to be less strict about enforcing functional patterns.
Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.

As for the mapping interface: this is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types. This is also why chex cannot easily add static fields: tf.nest doesn't support them. We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.

@NeilGirdhar
Contributor

NeilGirdhar commented Dec 7, 2021

Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.

For statistics, in my 5500 line Jax project, I call replace 9 times. It may be slightly more awkward than writing to attributes, but I don't think it's worth giving up the safety of all of the methods on my dataclasses being verified to be pure.

This is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types.

Instead of having a gigantic interface and passing the dataclass d to tf.nest, can't users pass dataclasses.asdict(d)?

This is also why chex cannot easily add static fields because tf.nest doesn't support it.

I see. Why not create an asdict function that removes the keys corresponding to static fields? Or more conveniently, convince Tensorflow to check for an as_dynamic_dict method and call it in tf.nest?
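
One way the suggested helper could look, sketched with the standard library only. Marking static fields through `metadata={'static': True}` and the `static_field` wrapper are assumptions made for this example; `as_dynamic_dict` takes its name from the comment above. Unlike `dataclasses.asdict`, this version does not recurse into nested dataclasses.

```python
import dataclasses

def static_field(**kwargs):
    # Hypothetical marker: tags a field as static via dataclass metadata.
    return dataclasses.field(metadata={'static': True}, **kwargs)

@dataclasses.dataclass(frozen=True)
class Config:
    weight: float
    bias: float
    activation: str = static_field(default='relu')

def as_dynamic_dict(obj):
    # Keep only the fields that are not tagged as static.
    return {
        f.name: getattr(obj, f.name)
        for f in dataclasses.fields(obj)
        if not f.metadata.get('static', False)
    }

dynamic = as_dynamic_dict(Config(weight=1.0, bias=0.5))
```

The resulting plain dict could then be handed to a mapping-based tree library without the dataclass itself exposing a mapping interface.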

We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.

Yes! Thank you!

@jheek
Member

jheek commented Dec 7, 2021

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a good idea.

Yes, this is the tradeoff we have to think about, and we will discuss it further before making a final decision.

@avital
Contributor

avital commented Dec 7, 2021

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.

What is a particular form of this problem? Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc). I understand the need to ensure frozen data structures within modules (and we're not proposing that this change -- module.apply will still have a mutable argument and use FrozenDicts based on that). The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.

@PhilipVinc
Contributor

By the way, I also think that Flax returning frozen dictionaries is extremely annoying. Changing this behaviour would also address google-deepmind/optax#160

Moreover, our (NetKet) users and students learning Jax/Flax often find it confusing that they keep getting this object that they have to melt before they can edit it.

@NeilGirdhar
Contributor

NeilGirdhar commented Dec 7, 2021

The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.

Sorry, I'm not actually discussing the topic of the issue. I just noticed a comment about merging chex.dataclass, and I wanted some clarification on that.

What is a particular form of this problem?

I can't find the example, but I saw one with treex (which doesn't enforce frozen dataclasses) where someone was doing

def f(x):
    x.some_member = some_value
    return x

@jit
def g(...):
    ...
    x = f(x)  # if you forget to assign to x, you will get different behavior for the jitted and unjitted function. 
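
For contrast, a standard-library sketch of how a frozen dataclass turns the in-place write from the example above into an immediate error instead of a silent jit/no-jit divergence. The `Model` class and the values are hypothetical; `f_pure` shows the functional alternative.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Model:
    some_member: float

def f(x):
    x.some_member = 2.0   # same in-place write as in the example above
    return x

def f_pure(x):
    # Functional version: returns a new object and leaves x alone.
    return dataclasses.replace(x, some_member=2.0)

model = Model(some_member=1.0)
try:
    f(model)
    error_name = None
except dataclasses.FrozenInstanceError as err:
    error_name = type(err).__name__

updated = f_pure(model)
```

With the frozen class, forgetting to reassign the result of `f_pure` still leaves `model` unchanged in both jitted and unjitted code, so the two cannot disagree.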

Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc).

Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.

@avital
Contributor

avital commented Dec 16, 2021

Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc).

Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.

@NeilGirdhar I just mean things like updating state and params and reporting metrics -- it's totally fine to directly manipulate the variables dict in the main training loop, and people have to jump through (IMHO unnecessary) hoops to achieve this: #1729 (comment).

@NeilGirdhar
Contributor

@avital Fair enough. I need to learn Flax better before I can really suggest something. A couple other options:

An at operator that does this under the covers, so that you can write:

embedding = params['params']['Embed_0']['embedding']
norm = jnp.linalg.norm(embedding, axis=-1, keepdims=True)
new_params = params.at['params']['Embed_0']['embedding'].divide(norm + 1e-10)
state = state.replace(params=new_params)

The at operator would return a handle like the one returned by the .at property on JAX arrays (as in x.at[idx].set(...)).
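
Until something like that exists, the same non-mutating nested update can be sketched as a small helper over plain dicts. `update_in` is a hypothetical function written for this example, not a Flax or JAX API, and the float leaves stand in for arrays.

```python
def update_in(tree, path, fn):
    """Return a copy of nested dict `tree` with `fn` applied at `path`."""
    if not path:
        return fn(tree)
    key, rest = path[0], path[1:]
    new_tree = dict(tree)  # shallow copy of this level only
    new_tree[key] = update_in(tree[key], rest, fn)
    return new_tree

params = {'params': {'Embed_0': {'embedding': 4.0}, 'Dense_0': {'kernel': 1.0}}}
new_params = update_in(
    params, ('params', 'Embed_0', 'embedding'), lambda e: e / 2.0)
```

Only the dicts along the path are copied; untouched subtrees such as `Dense_0` are shared between the old and new trees, and the original `params` is left as it was.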

Or maybe a context manager that provides the handle and automatically rolls it back in when it ends:

with state.unfreeze() as unfrozen_state:
    unfrozen_state.params['params']['Embed_0']['embedding'] /= (jnp.linalg.norm(state.params['params']['Embed_0']['embedding'], axis=-1, keepdims=True) + 1e-10)

You'd still be jumping through hoops, but it's just one hoop.

@avital
Contributor

avital commented Dec 16, 2021

The problem with any hoop isn't its complexity -- it's that it's something you have to learn suddenly, when you "just wanted to try this one thing". So any hoop should be justified by the benefit it gives you (hopefully a lot). Maybe I'm just misunderstanding this, but I never understood the benefit of having module.apply and module.init return FrozenDicts. (I've always been strongly in support of FrozenDicts inside modules, which happens internally as a function of the mutable argument to module.apply.)

@avital
Contributor

avital commented Dec 16, 2021

I guess another way to put it -- if someone really wants immutable data structures, they can always do, e.g. FrozenDict(module.init(...)). So the question is: which default serves the users best?

@lucasb-eyer
Contributor

And the answer is just plain dict, at least for this user here :)

@cgarciae
Collaborator

+1 for this! I have a lot of code that immediately calls .unfreeze() right after init and apply.

@cgarciae
Collaborator

Hey @NeilGirdhar! I believe you're looking for this example from Treex's User Guide.

@marcvanzee marcvanzee removed their assignment Dec 12, 2022
@cgarciae
Collaborator

Since this would be a breaking change, we should bump Flax's version to avoid breaking open-source users who rely on semantic versioning.

@marcvanzee
Collaborator Author

FYI: @chiamp is going to look into this

@chiamp
Collaborator

chiamp commented Aug 31, 2023

Closing after #3193 landed.
