diff --git a/flax/linen/dotgetter.py b/flax/linen/dotgetter.py deleted file mode 100644 index af03c0294d..0000000000 --- a/flax/linen/dotgetter.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simple syntactic wrapper for nested dictionaries to allow dot traversal.""" -from collections.abc import MutableMapping # pylint: disable=g-importing-member -from flax import serialization -from flax.core.frozen_dict import FrozenDict -from jax import tree_util - - -def is_leaf(x): - return tree_util.treedef_is_leaf(tree_util.tree_flatten(x)[1]) - - -# TODO(jheek): remove pytype hack, probably the MutableMapping -# inheritance should be dropped. - - -# We subclass MutableMapping for automatic dict-like utility fns. -# We subclass dict so that freeze, unfreeze work transparently: -# i.e freeze(DotGetter(d)) == freeze(d) -# unfreeze(DotGetter(d)) == unfreeze(d) -class DotGetter(MutableMapping, dict): # type: ignore[misc] # pytype: disable=mro-error - """Dot-notation helper for interactive access of variable trees.""" - __slots__ = ('_data',) - - def __init__(self, data): - # Because DotGetter has an MRO error, calling `super().__init__()` is - # ambiguous. Therefore we call it on the `dict` superclass - # (`MutableMapping` is an ABC). - super(dict, self).__init__() # pylint: disable=bad-super-call - object.__setattr__(self, '_data', data) - - def __getattr__(self, key): - if is_leaf(self._data[key]): # Returns leaves unwrapped. - return self._data[key] - else: - return DotGetter(self._data[key]) - - def __setattr__(self, key, val): - if isinstance(self._data, FrozenDict): - raise ValueError("Can't set value on FrozenDict.") - self._data[key] = val - - def __getitem__(self, key): - return self.__getattr__(key) - - def __setitem__(self, key, val): - self.__setattr__(key, val) - - def __delitem__(self, key): - if isinstance(self._data, FrozenDict): - raise ValueError("Can't delete value on FrozenDict.") - del self._data[key] - - def __iter__(self): - return iter(self._data) - - def __len__(self): - return len(self._data) - - def __keytransform__(self, key): - return key - - def __dir__(self): - if isinstance(self._data, dict): - return list(self._data.keys()) - elif isinstance(self._data, FrozenDict): - return list(self._data._dict.keys()) - else: - return [] - - def __repr__(self): - return f'{self._data}' - - def __hash__(self): - # Note: will only work when wrapping FrozenDict. - return hash(self._data) - - def copy(self, **kwargs): - return self._data.__class__(self._data.copy(**kwargs)) - -tree_util.register_pytree_node( - DotGetter, - lambda x: ((x._data,), ()), # pylint: disable=protected-access - lambda _, data: data[0]) # type: ignore - -# Note: restores as raw dict, intentionally. -serialization.register_serialization_state( - DotGetter, - serialization._dict_state_dict, # pylint: disable=protected-access - serialization._restore_dict) # pylint: disable=protected-access diff --git a/tests/linen/dotgetter_test.py b/tests/linen/dotgetter_test.py deleted file mode 100644 index fb060e77a7..0000000000 --- a/tests/linen/dotgetter_test.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -import jax -from jax import random -import jax.numpy as jnp -import numpy as np - -from flax import linen as nn -from flax.core import Scope, FrozenDict, freeze, unfreeze -from flax.linen.dotgetter import DotGetter, is_leaf -from flax import serialization - -# Parse absl flags test_srcdir and test_tmpdir. -#jax.config.parse_flags_with_absl() - -class DotGetterTest(absltest.TestCase): - - def test_simple(self): - dg = DotGetter({'a': 1, 'b': {'c': 2}, 'd': {'e': {'f': 3}}}) - self.assertEqual(dg.a, 1) - self.assertEqual(dg.b.c, 2) - self.assertEqual(dg.d.e.f, 3) - self.assertEqual(dg['a'], 1) - self.assertEqual(dg['b'].c, 2) - self.assertEqual(dg['b']['c'], 2) - self.assertEqual(dg.b['c'], 2) - self.assertEqual(dg.d.e.f, 3) - - def test_simple_frozen(self): - dg = DotGetter(freeze({'a': 1, 'b': {'c': 2}, 'd': {'e': {'f': 3}}})) - self.assertEqual(dg.a, 1) - self.assertEqual(dg.b.c, 2) - self.assertEqual(dg.d.e.f, 3) - self.assertEqual(dg['a'], 1) - self.assertEqual(dg['b'].c, 2) - self.assertEqual(dg['b']['c'], 2) - self.assertEqual(dg.b['c'], 2) - self.assertEqual(dg.d.e.f, 3) - - def test_eq(self): - dg1 = DotGetter({'a': 1, 'b': {'c': 2, 'd': 3}}) - dg2 = DotGetter({'a': 1, 'b': {'c': 2, 'd': 3}}) - self.assertEqual(dg1, dg2) - self.assertEqual(freeze(dg1), dg2) - self.assertEqual(freeze(dg1), freeze(dg2)) - - def test_dir(self): - dg = DotGetter({'a': 1, 'b': {'c': 2, 'd': 3}}) - self.assertEqual(dir(dg), ['a', 'b']) - self.assertEqual(dir(dg.b), ['c', 'd']) - - def test_freeze(self): - d = {'a': 1, 'b': {'c': 2, 'd': 3}} - dg = DotGetter(d) - self.assertEqual(freeze(dg), freeze(d)) - fd = freeze({'a': 1, 'b': {'c': 2, 'd': 3}}) - fdg = DotGetter(d) - self.assertEqual(unfreeze(fdg), unfreeze(fd)) - - def test_hash(self): - d = {'a': 1, 'b': {'c': 2, 'd': 3}} - dg = DotGetter(d) - fd = freeze(d) - fdg = DotGetter(fd) - self.assertEqual(hash(fdg), hash(fd)) - with self.assertRaisesRegex(TypeError, 'unhashable'): - hash(dg) - - def test_pytree(self): - dg1 = DotGetter({'a': jnp.array([1.0]), - 'b': {'c': jnp.array([2.0]), - 'd': jnp.array([3.0])}}) - dg2 = DotGetter({'a': jnp.array([2.0]), - 'b': {'c': jnp.array([4.0]), - 'd': jnp.array([6.0])}}) - self.assertEqual(jax.tree_util.tree_map(lambda x: 2 * x, dg1), dg2) - - def test_statedict(self): - d = {'a': jnp.array([1.0]), - 'b': {'c': jnp.array([2.0]), - 'd': jnp.array([3.0])}} - dg = DotGetter(d) - ser = serialization.to_state_dict(dg) - deser = serialization.from_state_dict(dg, ser) - self.assertEqual(d, deser) - - def test_is_leaf(self): - for x in [0, 'foo', jnp.array([0.]), {}, [], (), {1, 2}]: - self.assertTrue(is_leaf(x)) - self.assertFalse(is_leaf({'a': 1})) - self.assertFalse(is_leaf([1,2,3])) - self.assertFalse(is_leaf((1,2,3))) - - -if __name__ == '__main__': - absltest.main()