Regular dict representation not indented #3280
As discussed internally, we have to make …

```python
from typing import Iterable, Mapping, MutableMapping, TypeVar, Union

import flax
import jax

A = TypeVar('A')
B = TypeVar('B')


class MutableDict(MutableMapping[A, B]):

  def __init__(
      self,
      input: Union[Mapping[A, B], Iterable[tuple[A, B]], None] = None,
      /,
      **kwargs: B,
  ):
    self._dict: dict[A, B] = dict(input, **kwargs) if input else dict(**kwargs)

  def __setitem__(self, key: A, value: B) -> None:
    self._dict[key] = value

  def __getitem__(self, key: A) -> B:
    value = self._dict[key]
    if isinstance(value, dict) and not isinstance(value, MutableDict):
      # Wrap the nested dict without copying it, so that writes such as
      # d['b']['c'] = 5 are reflected in the parent.
      wrapped = MutableDict()
      wrapped._dict = value
      return wrapped  # type: ignore
    return value

  def __delitem__(self, key: A) -> None:
    del self._dict[key]

  def __iter__(self):
    return iter(self._dict)

  def __len__(self):
    return len(self._dict)

  def __repr__(self):
    return 'MutableDict(' + flax.core.pretty_repr(self._dict) + ')'


jax.tree_util.register_pytree_with_keys(
    MutableDict,
    # flatten_with_keys: (key path entry, child) pairs, plus the keys as aux data.
    lambda d: (
        tuple(
            (jax.tree_util.DictKey(key), value)
            for key, value in d._dict.items()
        ),
        tuple(d._dict.keys()),
    ),
    # unflatten_func: rebuild from the keys (aux data) and the children.
    lambda keys, values: MutableDict(zip(keys, values)),
    # flatten_func: must yield the same children and aux data as above
    # (without key paths), otherwise unflattening the treedef fails.
    lambda d: (tuple(d._dict.values()), tuple(d._dict.keys())),
)

d = MutableDict({'a': 1, 'b': {'c': 2, 'd': 3}})
print('\nprint\n--------------------------')
print(d)
print('\naccess\n--------------------------')
print(d['b'])
print('\ntree_flatten\n--------------------------')
print(jax.tree_util.tree_flatten(d)[0])
print('\ntree_flatten_with_path\n--------------------------')
print(jax.tree_util.tree_flatten_with_path(d)[0])
```
I decided not to inherit from …

We might need to do some deep checks so we don't have nested …
After the dict migration, Flax now returns regular dicts when calling the `.init`, `.init_with_output` and `.apply` Module methods. However, the representation of regular dicts is not as readable as the indented representation of FrozenDicts.

Regular dicts:

FrozenDicts:

The indented representation can be viewed by calling `flax.core.pretty_repr` on the dict. Alternatively, we could subclass `dict`, override its `__repr__` method to return an indented representation, and have Flax return this subclass when `.init`, `.init_with_output` and `.apply` are called:

Another option is to add a section to the dict migration guide letting users know they can get the indented representation by calling `flax.core.pretty_repr` (although this currently works only on FrozenDicts and regular dicts, not on other objects like `TrainState`).
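The subclass option mentioned above (subclass `dict` and override `__repr__`) could look roughly like the sketch below. It is only an illustration under assumptions: `PrettyDict` is a hypothetical name, and `indented_repr` is a hypothetical stand-in for `flax.core.pretty_repr` so the snippet runs without Flax installed. Its exact layout (for example, quoted keys) may differ from Flax's output. In Flax, `__repr__` would presumably call `flax.core.pretty_repr` instead.

```python
from typing import Any


def indented_repr(obj: Any, num_spaces: int = 4) -> str:
  """Renders nested dicts one key per line, indenting each nesting level.

  Hypothetical stand-in for flax.core.pretty_repr; not Flax's implementation.
  """
  if not isinstance(obj, dict):
    return repr(obj)
  if not obj:
    return '{}'
  pad = ' ' * num_spaces
  lines = ['{']
  for key, value in obj.items():
    # Render the child, then shift every line of it one level to the right.
    child = indented_repr(value, num_spaces).replace('\n', '\n' + pad)
    lines.append(f'{pad}{key!r}: {child},')
  lines.append('}')
  return '\n'.join(lines)


class PrettyDict(dict):
  """A regular dict whose repr is indented (hypothetical name)."""

  def __repr__(self) -> str:
    return indented_repr(self)


variables = PrettyDict({'params': {'Dense_0': {'bias': 0.0, 'kernel': 1.0}}})
print(variables)
```

Because `PrettyDict` is still a `dict`, `isinstance` checks and equality with plain dicts keep working, and only the printed form changes. One caveat is that JAX matches pytree node types exactly, so a `dict` subclass like this is treated as a leaf unless it is registered separately (for example with `jax.tree_util.register_pytree_with_keys`, as the `MutableDict` comment does).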