diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index e532632676..fc8af3f70f 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -35,7 +35,6 @@ from .graph import PureState as PureState from .object import Object as Object from .helpers import Dict as Dict -from .helpers import List as List from .helpers import Sequential as Sequential from .helpers import TrainState as TrainState from .module import M as M diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 5468a5a987..3d78fc8261 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -16,7 +16,6 @@ import contextlib import dataclasses -import enum import functools import threading import typing as tp @@ -743,110 +742,6 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]): f'Unsupported update type: {type(value)} for key {key!r}' ) - -class _StaticModuleStatus(enum.Enum): - NEW = enum.auto() - UPDATED = enum.auto() - - -# TODO(cgarciae): remove once transform init are reimplemented -def update_from(node: Node, updates: Node) -> None: - graph_update_static(node, updates) - _, state = split(updates) - update(node, state) - - -# TODO(cgarciae): remove once transform init are reimplemented -def graph_update_static(node: Node, updates: Node) -> None: - cache: dict[int, _StaticModuleStatus] = {} - _graph_update_static(node, updates, cache, _StaticModuleStatus.UPDATED, ()) - - -def _graph_update_static( - node: Node, - updates: Node, - cache: dict[int, _StaticModuleStatus], - status: _StaticModuleStatus, - path: PathParts, -) -> None: - if type(node) != type(updates): - raise ValueError( - f'Trying to update a node with a different type: ' - f'expected {type(node).__name__!r}, ' - f'but got {type(updates).__name__!r}' - ) - if not is_node(node): - raise ValueError(f'Unsupported node type: {type(node)}') - - if id(updates) in cache: - if cache[id(updates)] != status: - str_path = '/'.join(str(p) for p in path) - if status is _StaticModuleStatus.NEW: - raise ValueError( - f'Trying to add a new node at path {str_path!r} but a' - ' node with the same reference has been updated' - ) - else: - raise ValueError( - f'Trying to update a node at path {str_path!r} but a new' - ' node with the same reference has been added' - ) - return - - cache[id(updates)] = status - - node_impl = get_node_impl(node) - node_dict = node_impl.node_dict(node) - updates_dict = node_impl.node_dict(updates) - for name, value_updates in updates_dict.items(): - # case 1: trying to update a Variable, skip - if is_state_leaf(value_updates): - continue - elif is_node(value_updates): - # case 2: updating an existing subgraph - if name in node_dict: - _graph_update_static( - node_dict[name], - value_updates, - cache, - _StaticModuleStatus.UPDATED, - (*path, name), - ) - else: - # case 3: adding a new subgraph - if isinstance(node_impl, PytreeNodeImpl): - raise ValueError( - f'Cannot set key {name!r} on immutable node of ' - f'type {type(node).__name__}' - ) - - # check if the subgraph is already in the cache - if id(value_updates) in cache: - # if its in the cache, check its status is not NEW - if cache[id(value_updates)] is not _StaticModuleStatus.NEW: - raise ValueError( - f'Trying to add a new node at path {name!r} but a ' - 'node with the same reference has been updated' - ) - else: - cache[id(value_updates)] = _StaticModuleStatus.NEW - - node_impl.set_key(node, name, value_updates) - else: # static field - if isinstance(node_impl, PytreeNodeImpl): - if name in node_dict and node_dict[name] == value_updates: - # if the value is the same, skip - continue - # if trying - raise ValueError( - f'Cannot update key {name!r} on immutable node of ' - f'type {type(node).__name__}. Current value is {node_dict[name]!r}, ' - f'new value is {value_updates!r}.' - ) - - node_impl.set_key(node, name, value_updates) - - # -------------------------------------------------------- # UpdateContext # -------------------------------------------------------- diff --git a/flax/nnx/helpers.py b/flax/nnx/helpers.py index cf8e44dc0d..96622f0e40 100644 --- a/flax/nnx/helpers.py +++ b/flax/nnx/helpers.py @@ -20,7 +20,6 @@ import jax.numpy as jnp import optax -from flax.nnx.graph import Key from flax.nnx.module import GraphDef, Module from flax.nnx.proxy_caller import ApplyCaller from flax.nnx.rnglib import Rngs @@ -63,53 +62,6 @@ def __iter__(self) -> tp.Iterator[str]: def __len__(self) -> int: return len(vars(self)) - -class List(Module, tp.Generic[A]): - def __init__(self, elems: tp.Iterable[A], /): - i = 0 - for i, value in enumerate(elems): - setattr(self, str(i), value) - self._length = i + 1 - - def __getitem__(self, key: int) -> A: - if key >= len(self) or key < -len(self): - raise IndexError(f'index {key} out of range for {self}') - if key < 0: - key = self._length + key - return getattr(self, str(key)) - - def __setitem__(self, key: int, value: A): - if key >= len(self): - raise IndexError(f'index {key} out of range for {self}') - setattr(self, str(key), value) - - def __iter__(self) -> tp.Iterator[A]: - for i in range(len(self)): - yield getattr(self, str(i)) - - def __len__(self) -> int: - return self._length - - def _graph_node_flatten(self): - nodes: list[tuple[Key, tp.Any]] = sorted( - (int(key), value) - for key, value in vars(self).items() - if key not in ('_object__state', '_length') - ) - nodes.append(('_length', self._length)) - return nodes, (type(self), self._object__state._initializing) - - def _graph_node_set_key(self, key: Key, value: tp.Any): - if isinstance(key, int): - key = str(key) - return super()._graph_node_set_key(key, value) - - def _graph_node_pop_key(self, key: Key): - if isinstance(key, int): - key = str(key) - return super()._graph_node_pop_key(key) - - class Sequential(Module): def __init__(self, *fns: tp.Callable[..., tp.Any]): self.layers = list(fns) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 57b0f2e3c1..2a00b5db9d 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -23,6 +23,26 @@ import jax import jax.numpy as jnp +class List(nnx.Module): + def __init__(self, items): + self.items = list(items) + + def __getitem__(self, idx): + return self.items[idx] + + def __setitem__(self, idx, value): + self.items[idx] = value + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + self.items = dict(*args, **kwargs) + + def __getitem__(self, key): + return self.items[key] + + def __setitem__(self, key, value): + self.items[key] = value class StatefulLinear(nnx.Module): def __init__(self, din, dout, rngs): @@ -54,8 +74,8 @@ def test_flatten(self): assert g[3] in refmap def test_unflatten(self): - a = nnx.Dict(a=1, b=nnx.Param(2)) - g = nnx.List([a, 3, a, nnx.Param(4)]) + a = Dict(a=1, b=nnx.Param(2)) + g = List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.split(g) g = nnx.merge(graphdef, state) @@ -72,8 +92,8 @@ def test_unflatten_pytree(self): assert g[0] is not g[2] def test_unflatten_empty(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - g = nnx.List([a, 3, a, nnx.Param(4)]) + a = Dict({'a': 1, 'b': nnx.Param(2)}) + g = List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.split(g) @@ -92,46 +112,6 @@ def test_update_dynamic(self): assert g[0]['b'].value == 3 assert g[2]['b'].value == 3 - def test_update_static(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - g = nnx.List([a, 3, a, nnx.Param(4)]) - - g2 = nnx.graph.clone(g) - g2[0]['a'] = 5 - - nnx.graph.graph_update_static(g, g2) - - assert g[0]['a'] == 5 - assert g[2]['a'] == 5 - - def test_update_static_inconsistent_types(self): - a = {'a': 1, 'b': nnx.Param(2)} - g = [a, 3, a, nnx.Param(4)] - g2 = [a, a, 3, nnx.Param(4)] - - with self.assertRaisesRegex( - ValueError, 'Trying to update a node with a different type' - ): - nnx.graph.graph_update_static(g, g2) - - def test_update_static_add_new(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - b = nnx.List([5, 6]) - g = nnx.List([a, 3, a, nnx.Param(4)]) - g2 = nnx.List([a, 3, a, nnx.Param(4), b]) - - nnx.graph.graph_update_static(g, g2) - - assert g[4][0] == 5 - assert g[4][1] == 6 - - def test_update_static_add_shared_error(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - g = nnx.List([a, 3, a, nnx.Param(4)]) - g2 = nnx.List([a, 3, a, nnx.Param(4), a]) - - with self.assertRaisesRegex(ValueError, 'Trying to add a new node at path'): - nnx.graph.graph_update_static(g, g2) def test_module_list(self): rngs = nnx.Rngs(0) diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index a3f7bf8c22..2aff69a144 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -14,7 +14,7 @@ from copy import deepcopy import dataclasses -from typing import Any, TypeVar +from typing import TypeVar from absl.testing import absltest from flax import nnx, errors @@ -24,6 +24,35 @@ A = TypeVar('A') +class List(nnx.Module): + def __init__(self, items): + self.items = list(items) + + def __getitem__(self, idx): + return self.items[idx] + + def __setitem__(self, idx, value): + self.items[idx] = value + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + self.items = dict(*args, **kwargs) + + def __getitem__(self, key): + return vars(self)['items'][key] + + def __setitem__(self, key, value): + vars(self)['items'][key] = value + + def __getattr__(self, key): + attrs = vars(self) + if 'items' not in attrs: + raise AttributeError('items') + elif key not in attrs['items']: + raise AttributeError(key) + return attrs['items'][key] + class TestModule(absltest.TestCase): def test_has_module_state(self): @@ -34,7 +63,7 @@ class Foo(nnx.Module): ... assert hasattr(foo, '_object__state') def test_trace_level(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit def f(): @@ -47,24 +76,24 @@ def f(): f() def test_tree_map(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) graphdef, state = nnx.split(m) state = jax.tree.map(lambda x: x + 1, state) def test_split_2(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) graphdef, empty, some = nnx.split(m, None, ...) some = jax.tree.map(lambda x: x + 1, some) def test_split_merge(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit - def g(graphdef: nnx.GraphDef[nnx.Dict[int]], state: nnx.State): + def g(graphdef: nnx.GraphDef[Dict], state: nnx.State): m = nnx.merge(graphdef, state) m.a = 2 return nnx.split(m) @@ -77,7 +106,7 @@ def g(graphdef: nnx.GraphDef[nnx.Dict[int]], state: nnx.State): def test_no_trace_level_error_on_grad(self): # No trace level error occurs because jax doesn't update # its top trace for grad. - m = nnx.Dict(a=nnx.Param(1.0)) + m = Dict(a=nnx.Param(1.0)) @jax.grad def f(_): @@ -104,8 +133,8 @@ def __call__(self, x, *, rngs: nnx.Rngs): assert isinstance(y, jax.Array) def test_shared_module(self): - m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) - m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) + m1 = Dict(a=nnx.Param(1), b=nnx.Param(2)) + m2 = Dict(x=m1, y=m1, z=nnx.Param(3)) m3 = nnx.merge(*nnx.split(m2)) @@ -131,10 +160,10 @@ def test_deref_through_jit(self): r1 = nnx.Variable(1) r2 = nnx.Variable(2) - m = m0 = nnx.Dict({'a': nnx.List([r1, r2]), 'b': r1}) + m = m0 = Dict({'a': List([r1, r2]), 'b': r1}) @jax.jit - def f(graphdef: nnx.GraphDef[nnx.Dict[Any]], state: nnx.State): + def f(graphdef: nnx.GraphDef[Dict], state: nnx.State): m = nnx.merge(graphdef, state) assert m['a'][0] is m['b'] @@ -154,10 +183,10 @@ def f(graphdef: nnx.GraphDef[nnx.Dict[Any]], state: nnx.State): assert m['b'] is not m0['b'] def test_cross_barrier(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit - def g(graphdef: nnx.GraphDef[nnx.Dict[nnx.Param[int]]], state: nnx.State): + def g(graphdef: nnx.GraphDef[Dict], state: nnx.State): m = nnx.merge(graphdef, state) m.a.value += 1 return nnx.split(m) @@ -170,7 +199,7 @@ def g(graphdef: nnx.GraphDef[nnx.Dict[nnx.Param[int]]], state: nnx.State): def test_no_rejit(self): n = 0 - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit def g(state_and_def): @@ -202,10 +231,10 @@ def test_deref_number_of_fields(self): r1 = nnx.Variable(1) r2 = nnx.Variable(2) v1 = 3 - m = nnx.Dict( + m = Dict( { - 'a': nnx.List([r1, r2, v1]), - 'b': nnx.Dict({'c': r1, 'd': r2}), + 'a': List([r1, r2, v1]), + 'b': Dict({'c': r1, 'd': r2}), } ) @@ -214,9 +243,9 @@ def test_deref_number_of_fields(self): assert len(jax.tree_util.tree_leaves(p)) == 2 def test_clone(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), 3]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), 3]), + b=Dict(c=nnx.Param(1), d=nnx.Param(2)), ) m2 = nnx.clone(m) diff --git a/tests/nnx/partitioning_test.py b/tests/nnx/partitioning_test.py index 92c878cb2e..bb859de3a6 100644 --- a/tests/nnx/partitioning_test.py +++ b/tests/nnx/partitioning_test.py @@ -12,16 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING from absl.testing import absltest from flax import nnx import jax +class List(nnx.Module): + def __init__(self, items): + vars(self).update({str(i): item for i, item in enumerate(items)}) + + def __getitem__(self, idx): + return getattr(self, str(idx)) + + def __setitem__(self, idx, value): + setattr(self, str(idx), value) + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + vars(self).update(dict(*args, **kwargs)) + + def __getitem__(self, key): + return vars(self)[key] + + def __setitem__(self, key, value): + vars(self)[key] = value + + if TYPE_CHECKING: + + def __getattr__(self, key): ... + + class TestPartitioning(absltest.TestCase): def test_partition(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.BatchStat(2)]), + m = Dict( + a=List([nnx.Param(1), nnx.BatchStat(2)]), b=nnx.Param(2), c=100, ) @@ -32,41 +59,41 @@ def test_partition(self): self.assertLen(rest, 1) # check params - self.assertEqual(params['a'][0].value, m.a[0].value) + self.assertEqual(params['a']['0'].value, m.a['0'].value) self.assertEqual(params['b'].value, m.b.value) # check rest - self.assertEqual(rest['a'][1].value, m.a[1].value) + self.assertEqual(rest['a']['1'].value, m.a['1'].value) m2 = nnx.merge(graphdef, params, rest) - self.assertEqual(m2.a[0].value, m.a[0].value) - self.assertEqual(m2.a[1].value, m.a[1].value) + self.assertEqual(m2.a['0'].value, m.a['0'].value) + self.assertEqual(m2.a['1'].value, m.a['1'].value) self.assertEqual(m2.b.value, m.b.value) self.assertEqual(m2.c, 100) def test_complete_partitioning(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) # no error nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable) def test_complete_partitioning_plus_ellipsis(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) # no error if additional ... is passed at the end nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable, ...) def test_inclomplete_partition_error(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) with self.assertRaisesRegex( @@ -75,9 +102,9 @@ def test_inclomplete_partition_error(self): nnx.split(m, nnx.Param) def test_ellipsis_not_last_error(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) with self.assertRaisesRegex( @@ -86,8 +113,8 @@ def test_ellipsis_not_last_error(self): nnx.split(m, ..., nnx.Param) def test_update_from(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.BatchStat(3)]), + m = Dict( + a=List([nnx.Param(1), nnx.BatchStat(3)]), b=nnx.Param(2), c=100, ) @@ -105,8 +132,8 @@ def test_update_from(self): self.assertEqual(m.c, 100) def test_update_from_with_array_leaf(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.BatchStat(3)]), + m = Dict( + a=List([nnx.Param(1), nnx.BatchStat(3)]), b=nnx.Param(2), c=nnx.Variable(jax.numpy.array(100)), ) @@ -124,8 +151,8 @@ def test_update_from_with_array_leaf(self): self.assertEqual(m.c.value, 200) def test_grad_example(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1.0), nnx.BatchStat(-10)]), + m = Dict( + a=List([nnx.Param(1.0), nnx.BatchStat(-10)]), b=nnx.Param(2.0), c=100, ) @@ -144,8 +171,8 @@ def loss(params): self.assertEqual(m.c, 100) def test_get_paritition(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(10.0), nnx.Param(20.0)]), + m = Dict( + a=List([nnx.Param(10.0), nnx.Param(20.0)]), b=nnx.Param(10.0), c=7, d=5.0, @@ -155,10 +182,10 @@ def test_get_paritition(self): self.assertIsNot(vars(m.a)['0'], vars(m)['b']) state = nnx.state(m, nnx.Variable) - self.assertEqual(state['a'][0].value, m.a[0].value) - self.assertEqual(state['a'][1].value, m.a[1].value) + self.assertEqual(state['a']['0'].value, m.a['0'].value) + self.assertEqual(state['a']['1'].value, m.a['1'].value) self.assertEqual(state['b'].value, m.b.value) - self.assertIsNot(state.b, state.a[0]) + self.assertIsNot(state.b, state.a['0']) self.assertLen(state.flat_state(), 3) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 824e7b6b0e..84a833041c 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -26,12 +26,38 @@ import numpy as np +class List(nnx.Module): + def __init__(self, items): + vars(self).update({str(i): item for i, item in enumerate(items)}) + + def __getitem__(self, idx): + return getattr(self, str(idx)) + + def __setitem__(self, idx, value): + setattr(self, str(idx), value) + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + vars(self).update(dict(*args, **kwargs)) + + def __getitem__(self, key): + return vars(self)[key] + + def __setitem__(self, key, value): + vars(self)[key] = value + + if tp.TYPE_CHECKING: + + def __getattr__(self, key): ... + + class TestJIT(absltest.TestCase): def test_jit(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @nnx.jit - def g(m: nnx.Dict): + def g(m: Dict): m.a = 2 return 1.0 @@ -354,15 +380,15 @@ def test_grad(self): p1 = nnx.Param(10.0) p2 = nnx.Param(20.0) - m = nnx.Dict( - a=nnx.List([p1, p2]), + m = Dict( + a=List([p1, p2]), b=p1, c=7, d=5.0, ) @nnx.grad - def f(m: nnx.Dict): + def f(m: Dict): # sum all params return m['a'][0].value + m['a'][1].value + m['b'].value @@ -370,10 +396,10 @@ def f(m: nnx.Dict): assert m.a[0] is m.b assert isinstance(grads, nnx.State) - assert grads['a'][0].value == 2.0 - assert issubclass(grads.a[0].type, nnx.Variable) - assert grads['a'][1].value == 1.0 - assert issubclass(grads.a[1].type, nnx.Variable) + assert grads['a']['0'].value == 2.0 + assert issubclass(grads.a['0'].type, nnx.Variable) + assert grads['a']['1'].value == 1.0 + assert issubclass(grads.a['1'].type, nnx.Variable) assert len(grads.flat_state()) == 2 nnx.update(m, grads) @@ -386,57 +412,57 @@ def f(m: nnx.Dict): assert m['d'] == 5.0 def test_grad_with_multiple_ref_types(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(10.0), nnx.BatchStat(20.0)]), + m = Dict( + a=List([nnx.Param(10.0), nnx.BatchStat(20.0)]), b=nnx.Param(10.0), c=7, d=5.0, ) @nnx.grad - def f(m: nnx.Dict): + def f(m: Dict): # sum all params return m.a[0].value + m.a[1].value + m.b.value grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a'][0].value == 1.0 - assert issubclass(grads.a[0].type, nnx.Param) + assert grads['a']['0'].value == 1.0 + assert issubclass(grads.a['0'].type, nnx.Param) assert len(grads) == 2 nnx.update(m, grads) - assert m.a[0].value == 1.0 - assert m.a[1].value == 20.0 + assert m.a['0'].value == 1.0 + assert m.a['1'].value == 20.0 assert m.b.value == 1.0 assert m.c == 7 assert m.d == 5.0 def test_grad_with_type_predicate(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(10.0), nnx.BatchStat(20.0)]), + m = Dict( + a=List([nnx.Param(10.0), nnx.BatchStat(20.0)]), b=nnx.Param(10.0), c=7, d=5.0, ) @nnx.grad(argnums=nnx.DiffState(0, nnx.BatchStat)) - def f(m: nnx.Dict): + def f(m: Dict): # sum all params return m.a[0].value + m.a[1].value + m.b.value grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a'][1].value == 1.0 - assert issubclass(grads.a[1].type, nnx.BatchStat) + assert grads['a']['1'].value == 1.0 + assert issubclass(grads.a['1'].type, nnx.BatchStat) assert len(grads) == 1 nnx.update(m, grads) - assert m.a[0].value == 10.0 - assert m.a[1].value == 1.0 + assert m.a['0'].value == 10.0 + assert m.a['1'].value == 1.0 assert m.b.value == 10.0 assert m.c == 7 assert m.d == 5.0