Regular dict representation not indented #3280

Open
chiamp opened this issue Aug 15, 2023 · 3 comments

Comments

@chiamp
Collaborator

chiamp commented Aug 15, 2023

After the dict migration, Flax now returns regular dicts when calling the .init, .init_with_output and .apply Module methods. However, the representation of regular dicts is not as readable as the indented representation of FrozenDicts.

Regular dicts:

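import flax
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training.train_state import TrainState

class MLP(nn.Module):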
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(5)(x)
    x = nn.relu(x)
    return x
x = jnp.ones((1,3))
model = MLP()
params = model.init(jax.random.PRNGKey(0), x)['params']
state = TrainState.create(apply_fn=model.apply, params=params, tx=optax.adam(1e-3))
state
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params={'Dense_0': {'kernel': Array([[ 0.37229332, -0.4265755 , -1.1151816 , -0.09558704, -0.62169886],
       [-1.060781  ,  1.0546707 ,  0.33051118, -0.7090655 ,  0.37682843],
       [-0.30747807, -0.39064118, -0.25515485,  0.5127583 , -0.5559202 ]],      dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fe6daeca4d0>, update=<function chain.<locals>.update_fn at 0x7fe6daeca9e0>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu={'Dense_0': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32), 'kernel': Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)}}, nu={'Dense_0': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32), 'kernel': Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)}}), EmptyState()))

FrozenDicts:

state = TrainState.create(apply_fn=model.apply, params=flax.core.freeze(params), tx=optax.adam(1e-3))
state
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params=FrozenDict({
    Dense_0: {
        kernel: Array([[ 0.37229332, -0.4265755 , -1.1151816 , -0.09558704, -0.62169886],
               [-1.060781  ,  1.0546707 ,  0.33051118, -0.7090655 ,  0.37682843],
               [-0.30747807, -0.39064118, -0.25515485,  0.5127583 , -0.5559202 ]],      dtype=float32),
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fe710e21ea0>, update=<function chain.<locals>.update_fn at 0x7fe710e225f0>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu=FrozenDict({
    Dense_0: {
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
        kernel: Array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.]], dtype=float32),
    },
}), nu=FrozenDict({
    Dense_0: {
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
        kernel: Array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.]], dtype=float32),
    },
})), EmptyState()))

The indented representation can be obtained by calling flax.core.pretty_repr on the dict. Alternatively, we could subclass dict, override its __repr__ method to return an indented representation, and have Flax return this subclass from .init, .init_with_output and .apply:

@flax.struct.dataclass
class MutableDict(dict):
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
  def __repr__(self):
    return 'MutableDict' + flax.core.pretty_repr(self)

state = TrainState.create(apply_fn=model.apply, params=MutableDict(params), tx=optax.adam(1e-3))
state
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params=MutableDict{
    Dense_0: {
        kernel: Array([[ 0.37229332, -0.4265755 , -1.1151816 , -0.09558704, -0.62169886],
               [-1.060781  ,  1.0546707 ,  0.33051118, -0.7090655 ,  0.37682843],
               [-0.30747807, -0.39064118, -0.25515485,  0.5127583 , -0.5559202 ]],      dtype=float32),
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fe710e3a5f0>, update=<function chain.<locals>.update_fn at 0x7fe710e39990>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu=MutableDict{}, nu=MutableDict{}), EmptyState()))
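The empty mu=MutableDict{} and nu=MutableDict{} above are what you would expect if the decorated class is registered as a pytree node with no children: optax builds its moment estimates by mapping over params, and rebuilding a node that has no children yields an empty container. A minimal reproduction of that failure mode, assuming this is effectively what the dataclass registration amounts to for a class with no declared fields (LossyDict is a hypothetical name):

```python
import jax
import jax.numpy as jnp


class LossyDict(dict):
  """Dict subclass registered as a pytree node that exposes no children."""


# Assumption: mirrors a dataclass-style registration of a class with no
# declared fields: flatten to zero children, rebuild by calling cls().
jax.tree_util.register_pytree_node(
    LossyDict,
    lambda d: ((), None),                    # children=(), aux_data=None
    lambda aux_data, children: LossyDict(),  # rebuilt with no contents
)

params = LossyDict({'Dense_0': {'bias': jnp.zeros((5,))}})

# Roughly how an Adam-style optimizer initializes its first moment from params.
mu = jax.tree_util.tree_map(jnp.zeros_like, params)

print(len(params))                             # Python still sees one entry
print(len(jax.tree_util.tree_leaves(params)))  # but JAX sees no leaves
print(mu)                                      # {} : the rebuilt container is empty
```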

Another option is to add a section to the dict migration guide telling users that they can get the indented representation by calling flax.core.pretty_repr (although this currently works only on FrozenDicts and regular dicts, not on other objects such as TrainState).

@chiamp
Collaborator Author

chiamp commented Aug 15, 2023

@cgarciae @marcvanzee

chiamp changed the title from "Regular dict representation" to "Regular dict representation not indented" on Aug 15, 2023
@cgarciae
Collaborator

cgarciae commented Aug 24, 2023

As discussed internally, we have to make MutableDict a proper pytree. Here is an idea for the implementation:

from typing import Iterable, Mapping, MutableMapping, TypeVar, Union

import flax
import jax

A = TypeVar('A')
B = TypeVar('B')


class MutableDict(MutableMapping[A, B]):

  def __init__(
      self,
      input: Union[Mapping[A, B], Iterable[tuple[A, B]], None] = None,
      /,
      **kwargs: B,
  ):
    self._dict: dict[A, B] = dict(input, **kwargs) if input else dict(**kwargs)

  def __setitem__(self, key: A, value: B) -> None:
    self._dict[key] = value

  def __getitem__(self, key: A) -> B:
    value = self._dict[key]
    if isinstance(value, dict) and not isinstance(value, MutableDict):
      return MutableDict(value)  # type: ignore
    return value

  def __delitem__(self, key: A) -> None:
    del self._dict[key]

  def __iter__(self):
    return iter(self._dict)

  def __len__(self):
    return len(self._dict)

  def __repr__(self):
    return 'MutableDict(' + flax.core.pretty_repr(self._dict) + ')'


jax.tree_util.register_pytree_with_keys(
    MutableDict,
    lambda d: (
        tuple(
            (jax.tree_util.DictKey(key), value)
            for key, value in d._dict.items()
        ),
        tuple(d._dict.keys()),
    ),
    lambda keys, values: MutableDict(zip(keys, values)),
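    # flatten_func: must return the same children and aux_data as
    # flatten_with_keys (just without the keys), or tree_unflatten breaks.
    lambda d: (tuple(d._dict.values()), tuple(d._dict.keys())),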
)

d = MutableDict({'a': 1, 'b': {'c': 2, 'd': 3}})

print('\nprint\n--------------------------')
print(d)
print('\naccess\n--------------------------')
print(d['b'])
print('\ntree_flatten\n--------------------------')
print(jax.tree_util.tree_flatten(d)[0])
print('\ntree_flatten_with_path\n--------------------------')
print(jax.tree_util.tree_flatten_with_path(d)[0])

print
--------------------------
MutableDict({
    a: 1,
    b: {
        c: 2,
        d: 3,
    },
})

access
--------------------------
MutableDict({
    c: 2,
    d: 3,
})

tree_flatten
--------------------------
[1, 2, 3]

tree_flatten_with_path
--------------------------
[((DictKey(key='a'),), 1), ((DictKey(key='b'), DictKey(key='c')), 2), ((DictKey(key='b'), DictKey(key='d')), 3)]
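A quick way to catch a mismatch between the flatten_func and flatten_with_keys arguments of jax.tree_util.register_pytree_with_keys is to round-trip a value through tree_flatten/tree_unflatten and compare the leaves against tree_flatten_with_path. A sketch of such a check; check_pytree_roundtrip and KeyedBox are hypothetical names, with KeyedBox registered the same way as the MutableDict proposal:

```python
import jax
from jax import tree_util


def check_pytree_roundtrip(tree) -> None:
  """Hypothetical test helper for a custom pytree registration.

  Checks that flatten -> unflatten rebuilds the same structure and leaves, and
  that tree_flatten and tree_flatten_with_path agree on the leaves. Leaves are
  compared with ==, so use small test trees of Python scalars.
  """
  leaves, treedef = tree_util.tree_flatten(tree)
  rebuilt = tree_util.tree_unflatten(treedef, leaves)
  assert tree_util.tree_structure(rebuilt) == treedef
  assert tree_util.tree_leaves(rebuilt) == leaves
  path_leaves = [leaf for _, leaf in tree_util.tree_flatten_with_path(tree)[0]]
  assert path_leaves == leaves


class KeyedBox:
  """Tiny mapping-like container, registered like the MutableDict proposal."""

  def __init__(self, items):
    self._dict = dict(items)


tree_util.register_pytree_with_keys(
    KeyedBox,
    # flatten_with_keys: ((key_entry, child), ...) plus aux_data.
    lambda b: (
        tuple((tree_util.DictKey(k), v) for k, v in b._dict.items()),
        tuple(b._dict.keys()),
    ),
    # unflatten_func(aux_data, children).
    lambda keys, values: KeyedBox(zip(keys, values)),
    # flatten_func: the same children and aux_data, without the keys.
    lambda b: (tuple(b._dict.values()), tuple(b._dict.keys())),
)

box = KeyedBox({'a': 1, 'b': {'c': 2, 'd': 3}})
check_pytree_roundtrip(box)
doubled = jax.tree_util.tree_map(lambda x: 2 * x, box)
print(doubled._dict)  # {'a': 2, 'b': {'c': 4, 'd': 6}}
```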

I decided not to inherit from dict, since dict subclasses are subject to some internal optimizations, and instead implemented the MutableMapping protocol.
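For context on that choice: collections.abc.MutableMapping only requires __getitem__, __setitem__, __delitem__, __iter__ and __len__, and derives get, keys, items, update, pop, __contains__ and the rest from them. A minimal sketch (the class name Minimal is illustrative only):

```python
from collections.abc import MutableMapping


class Minimal(MutableMapping):
  """Implements only the five abstract methods; MutableMapping adds the rest."""

  def __init__(self, *args, **kwargs):
    self._dict = dict(*args, **kwargs)

  def __getitem__(self, key):
    return self._dict[key]

  def __setitem__(self, key, value):
    self._dict[key] = value

  def __delitem__(self, key):
    del self._dict[key]

  def __iter__(self):
    return iter(self._dict)

  def __len__(self):
    return len(self._dict)


m = Minimal(a=1)
m.update(b=2)                # mixin method built on __setitem__
print(m.get('c', 0))         # 0 (mixin built on __getitem__)
print(sorted(m.items()))     # [('a', 1), ('b', 2)]
print(m.pop('a'), 'a' in m)  # 1 False
print(isinstance(m, dict))   # False: not a dict subclass
```

Because such a class is not a dict subclass, any code that checks isinstance(x, dict) will not treat it as a dict, which is worth keeping in mind when swapping it in for plain dicts.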

@cgarciae
Collaborator

We might need to do some deep checks, similar to what FrozenDict has, so that we don't end up with nested MutableDicts.
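One way to read the "deep checks" suggestion, sketched below under that assumption: normalize every nested mapping to a plain dict when a value enters the container, so the container never stores instances of its own wrapper type. The helper name unwrap_nested is hypothetical, and types.MappingProxyType merely stands in for a nested wrapper type:

```python
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any


def unwrap_nested(value: Any) -> Any:
  """Hypothetical helper: recursively turn every nested Mapping into a dict.

  If a container stores only plain dicts internally, it can never end up
  holding instances of its own wrapper type; wrapping can then be done lazily
  on access, as __getitem__ does in the MutableDict proposal above.
  """
  if isinstance(value, Mapping):
    return {key: unwrap_nested(val) for key, val in value.items()}
  return value


# MappingProxyType stands in for a nested wrapper type such as MutableDict.
nested = {'a': 1, 'b': MappingProxyType({'c': 2, 'd': MappingProxyType({'e': 3})})}
plain = unwrap_nested(nested)
print(plain)  # {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
```

In the proposal, such a helper would presumably be applied in __init__ and __setitem__ before storing into self._dict.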
I am also wondering whether we should just promote/expose flax.core.pretty_repr as nn.pretty_print and use it in the guides so that users pick it up.


2 participants