diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index e532632676..04554ea7d4 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -35,7 +35,6 @@ from .graph import PureState as PureState from .object import Object as Object from .helpers import Dict as Dict -from .helpers import List as List from .helpers import Sequential as Sequential from .helpers import TrainState as TrainState from .module import M as M @@ -153,4 +152,6 @@ from .variables import VariableMetadata as VariableMetadata from .variables import with_metadata as with_metadata from .visualization import display as display -from .extract import to_tree, from_tree, TreeNode +from .extract import to_tree as to_tree +from .extract import from_tree as from_tree +from .extract import NodeStates as NodeStates diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 845544c307..948d8f9fce 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -237,58 +237,25 @@ class GraphDefState(struct.PyTreeNode): graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False) state: graph.GraphState = struct.field(pytree_node=True) -class StateOnly(struct.PyTreeNode): - state: graph.GraphState = struct.field(pytree_node=True) - - @property - def graphdef(self) -> graph.GraphDef[tp.Any]: - raise ValueError('No graphdef available in StateOnly') - - -@dataclasses.dataclass(frozen=True) -class StateSequence(tp.Sequence[graph.GraphState]): - graphdef_states: tuple[GraphDefState | StateOnly, ...] - - @tp.overload - def __getitem__(self, index: int) -> graph.GraphState: ... - @tp.overload - def __getitem__(self, index: slice) -> 'StateSequence': ... - def __getitem__(self, index): - if isinstance(index, slice): - return StateSequence(self.graphdef_states[index]) - elif isinstance(index, int): - return self.graphdef_states[index].state - else: - raise TypeError(f'Invalid index type: {type(index)}') - - def __len__(self): - return len(self.graphdef_states) - def __iter__(self): - return (s.state for s in self.graphdef_states) - - -class TreeNode(struct.PyTreeNode): - metatata: tp.Any = struct.field(pytree_node=False) - graphdef_states: tuple[GraphDefState | StateOnly, ...] = struct.field( - pytree_node=True - ) +class NodeStates(struct.PyTreeNode): + _graphdef: graph.GraphDef[tp.Any] | None + states: tuple[graph.GraphState, ...] + metadata: tp.Any = struct.field(pytree_node=False) @property def graphdef(self) -> graph.GraphDef[tp.Any]: - return self.graphdef_states[0].graphdef + if self._graphdef is None: + raise ValueError('No graphdef available') + return self._graphdef @property def state(self) -> graph.GraphState: - if len(self.graphdef_states) != 1: + if len(self.states) != 1: raise ValueError( - f'Expected exactly one GraphDefState, got {len(self.graphdef_states)}' + f'Expected exactly one GraphDefState, got {len(self.states)}' ) - return self.graphdef_states[0].state - - @property - def states(self) -> tp.Sequence[graph.GraphState]: - return StateSequence(self.graphdef_states) + return self.states[0] @classmethod def from_split( @@ -299,15 +266,11 @@ def from_split( *states: graph.GraphState, metadata: tp.Any = None, ): - states = (state, *states) - return cls( - metadata, tuple(GraphDefState(graphdef, state) for state in states) - ) + return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata) @classmethod def from_states(cls, state: graph.GraphState, *states: graph.GraphState): - states = (state, *states) - return cls(None, tuple(StateOnly(state) for state in states)) + return cls(_graphdef=None, states=(state, *states), metadata=None) @classmethod def from_prefixes( @@ -317,13 +280,13 @@ def from_prefixes( *, metadata: tp.Any = None, ): - return cls(metadata, tuple(prefixes)) + return cls(_graphdef=None, states=tuple(prefixes), metadata=metadata) def default_split_fn( ctx: graph.SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf ) -> tp.Any: - return TreeNode.from_split(*ctx.split(leaf)) + return NodeStates.from_split(*ctx.split(leaf)) def to_tree( @@ -370,13 +333,13 @@ def to_tree( def merge_tree_node( ctx: graph.MergeContext, path: KeyPath, prefix: Prefix, leaf: Leaf ) -> tp.Any: - if not isinstance(leaf, TreeNode): + if not isinstance(leaf, NodeStates): raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') return ctx.merge(leaf.graphdef, *leaf.states) def is_tree_node(x): - return isinstance(x, TreeNode) + return isinstance(x, NodeStates) def from_tree( diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 5468a5a987..65eccfa906 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -16,7 +16,6 @@ import contextlib import dataclasses -import enum import functools import threading import typing as tp @@ -51,9 +50,8 @@ AuxData = tp.TypeVar('AuxData') StateLeaf = VariableState[tp.Any] -NodeLeaf = VariableState[tp.Any] +NodeLeaf = Variable[tp.Any] GraphState = State[Key, StateLeaf] -GraphFlatState = FlatState[StateLeaf] def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: @@ -64,50 +62,29 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]: return isinstance(x, Variable) -class _HashById(tp.Hashable, tp.Generic[A]): - """A wrapper around a value that uses its id for hashing and equality. - This is used by RefMap to explicitly use object id as the hash for the keys. - """ - - __slots__ = ('_value',) - - def __init__(self, value: A): - self._value = value - - @property - def value(self) -> A: - return self._value - - def __hash__(self) -> int: - return id(self._value) - - def __eq__(self, other: tp.Any) -> bool: - return isinstance(other, _HashById) and self._value is other._value - - class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]): """A mapping that uses object id as the hash for the keys.""" def __init__( self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), / ): - self._mapping: dict[_HashById[A], B] = {} + self._mapping: dict[int, tuple[A, B]] = {} self.update(mapping) def __getitem__(self, key: A) -> B: - return self._mapping[_HashById(key)] + return self._mapping[id(key)][1] def __contains__(self, key: object) -> bool: - return _HashById(key) in self._mapping + return id(key) in self._mapping def __setitem__(self, key: A, value: B): - self._mapping[_HashById(key)] = value + self._mapping[id(key)] = (key, value) def __delitem__(self, key: A): - del self._mapping[_HashById(key)] + del self._mapping[id(key)] def __iter__(self) -> tp.Iterator[A]: - return (x.value for x in self._mapping) + return (key for key, _ in self._mapping.values()) def __len__(self) -> int: return len(self._mapping) @@ -637,7 +614,7 @@ def graph_pop( id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates) + flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates) _graph_pop(node, id_to_index, path_parts, flat_states, predicates) return tuple( GraphState.from_flat_path(flat_state) for flat_state in flat_states @@ -648,7 +625,7 @@ def _graph_pop( node: tp.Any, id_to_index: dict[int, Index], path_parts: PathParts, - flat_states: tuple[GraphFlatState, ...], + flat_states: tuple[FlatState[StateLeaf], ...], predicates: tuple[filterlib.Predicate, ...], ) -> None: if not is_node(node): @@ -743,110 +720,6 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]): f'Unsupported update type: {type(value)} for key {key!r}' ) - -class _StaticModuleStatus(enum.Enum): - NEW = enum.auto() - UPDATED = enum.auto() - - -# TODO(cgarciae): remove once transform init are reimplemented -def update_from(node: Node, updates: Node) -> None: - graph_update_static(node, updates) - _, state = split(updates) - update(node, state) - - -# TODO(cgarciae): remove once transform init are reimplemented -def graph_update_static(node: Node, updates: Node) -> None: - cache: dict[int, _StaticModuleStatus] = {} - _graph_update_static(node, updates, cache, _StaticModuleStatus.UPDATED, ()) - - -def _graph_update_static( - node: Node, - updates: Node, - cache: dict[int, _StaticModuleStatus], - status: _StaticModuleStatus, - path: PathParts, -) -> None: - if type(node) != type(updates): - raise ValueError( - f'Trying to update a node with a different type: ' - f'expected {type(node).__name__!r}, ' - f'but got {type(updates).__name__!r}' - ) - if not is_node(node): - raise ValueError(f'Unsupported node type: {type(node)}') - - if id(updates) in cache: - if cache[id(updates)] != status: - str_path = '/'.join(str(p) for p in path) - if status is _StaticModuleStatus.NEW: - raise ValueError( - f'Trying to add a new node at path {str_path!r} but a' - ' node with the same reference has been updated' - ) - else: - raise ValueError( - f'Trying to update a node at path {str_path!r} but a new' - ' node with the same reference has been added' - ) - return - - cache[id(updates)] = status - - node_impl = get_node_impl(node) - node_dict = node_impl.node_dict(node) - updates_dict = node_impl.node_dict(updates) - for name, value_updates in updates_dict.items(): - # case 1: trying to update a Variable, skip - if is_state_leaf(value_updates): - continue - elif is_node(value_updates): - # case 2: updating an existing subgraph - if name in node_dict: - _graph_update_static( - node_dict[name], - value_updates, - cache, - _StaticModuleStatus.UPDATED, - (*path, name), - ) - else: - # case 3: adding a new subgraph - if isinstance(node_impl, PytreeNodeImpl): - raise ValueError( - f'Cannot set key {name!r} on immutable node of ' - f'type {type(node).__name__}' - ) - - # check if the subgraph is already in the cache - if id(value_updates) in cache: - # if its in the cache, check its status is not NEW - if cache[id(value_updates)] is not _StaticModuleStatus.NEW: - raise ValueError( - f'Trying to add a new node at path {name!r} but a ' - 'node with the same reference has been updated' - ) - else: - cache[id(value_updates)] = _StaticModuleStatus.NEW - - node_impl.set_key(node, name, value_updates) - else: # static field - if isinstance(node_impl, PytreeNodeImpl): - if name in node_dict and node_dict[name] == value_updates: - # if the value is the same, skip - continue - # if trying - raise ValueError( - f'Cannot update key {name!r} on immutable node of ' - f'type {type(node).__name__}. Current value is {node_dict[name]!r}, ' - f'new value is {value_updates!r}.' - ) - - node_impl.set_key(node, name, value_updates) - - # -------------------------------------------------------- # UpdateContext # -------------------------------------------------------- @@ -1598,7 +1471,7 @@ def pop( id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates) + flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates) _graph_pop( node=node, id_to_index=id_to_index, diff --git a/flax/nnx/helpers.py b/flax/nnx/helpers.py index cf8e44dc0d..96622f0e40 100644 --- a/flax/nnx/helpers.py +++ b/flax/nnx/helpers.py @@ -20,7 +20,6 @@ import jax.numpy as jnp import optax -from flax.nnx.graph import Key from flax.nnx.module import GraphDef, Module from flax.nnx.proxy_caller import ApplyCaller from flax.nnx.rnglib import Rngs @@ -63,53 +62,6 @@ def __iter__(self) -> tp.Iterator[str]: def __len__(self) -> int: return len(vars(self)) - -class List(Module, tp.Generic[A]): - def __init__(self, elems: tp.Iterable[A], /): - i = 0 - for i, value in enumerate(elems): - setattr(self, str(i), value) - self._length = i + 1 - - def __getitem__(self, key: int) -> A: - if key >= len(self) or key < -len(self): - raise IndexError(f'index {key} out of range for {self}') - if key < 0: - key = self._length + key - return getattr(self, str(key)) - - def __setitem__(self, key: int, value: A): - if key >= len(self): - raise IndexError(f'index {key} out of range for {self}') - setattr(self, str(key), value) - - def __iter__(self) -> tp.Iterator[A]: - for i in range(len(self)): - yield getattr(self, str(i)) - - def __len__(self) -> int: - return self._length - - def _graph_node_flatten(self): - nodes: list[tuple[Key, tp.Any]] = sorted( - (int(key), value) - for key, value in vars(self).items() - if key not in ('_object__state', '_length') - ) - nodes.append(('_length', self._length)) - return nodes, (type(self), self._object__state._initializing) - - def _graph_node_set_key(self, key: Key, value: tp.Any): - if isinstance(key, int): - key = str(key) - return super()._graph_node_set_key(key, value) - - def _graph_node_pop_key(self, key: Key): - if isinstance(key, int): - key = str(key) - return super()._graph_node_pop_key(key) - - class Sequential(Module): def __init__(self, *fns: tp.Callable[..., tp.Any]): self.layers = list(fns) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 650bd4696c..9e55f70906 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -73,7 +73,7 @@ def __call__(self, *pure_args): nondiff_states: deque[State | None] = extract.get_broadcast_state('grad') def _grad_merge_fn( - ctx: graph.MergeContext, path, prefix, value: extract.TreeNode + ctx: graph.MergeContext, path, prefix, value: extract.NodeStates ): nondiff = nondiff_states.popleft() if nondiff is None: @@ -149,11 +149,11 @@ def _grad_split_fn( ): if prefix is None: nondiff_states.append(None) - return extract.TreeNode.from_split(*ctx.split(value)) + return extract.NodeStates.from_split(*ctx.split(value)) else: graphdef, diff, nondiff = ctx.split(value, prefix.filter, ...) # type: ignore[misc] nondiff_states.append(nondiff) - return extract.TreeNode.from_split(graphdef, diff) + return extract.NodeStates.from_split(graphdef, diff) arg_filters = tuple(index_filter.get(i) for i in range(len(args))) pure_args = extract.to_tree( @@ -165,9 +165,9 @@ def _grad_split_fn( def process_grads(grads): return jax.tree.map( - lambda x: x.state if isinstance(x, extract.TreeNode) else x, + lambda x: x.state if isinstance(x, extract.NodeStates) else x, grads, - is_leaf=lambda x: isinstance(x, extract.TreeNode), + is_leaf=lambda x: isinstance(x, extract.NodeStates), ) def process_out(pure_out: A, /) -> A: @@ -367,7 +367,7 @@ def _custom_vjp_merge_fn( ctx: graph.MergeContext, path, prefix: bool | DiffState, - value: extract.TreeNode, + value: extract.NodeStates, *, nondiff_states: deque[extract.GraphDefState], ): @@ -390,7 +390,7 @@ def _custom_vjp_split_fn( graphdef, passed = ctx.split(value) broadcast = State({}) # type: ignore[var-annotated] nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) - return extract.TreeNode.from_split(graphdef, passed) + return extract.NodeStates.from_split(graphdef, passed) elif prefix is True: # pure differentiable arg, we pass all the state through # but we return a TreeNode.from_states which doesn't have a graphdef @@ -398,7 +398,7 @@ def _custom_vjp_split_fn( graphdef, passed = ctx.split(value) broadcast = State({}) nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) - return extract.TreeNode.from_states(passed) + return extract.NodeStates.from_states(passed) else: # differentiable arg with DiffState filter, we use the filter to split the state # as before we return a TreeNode.from_states to keep the gradients clean @@ -406,7 +406,7 @@ def _custom_vjp_split_fn( # which is broadcasted during the forward pass graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...) # type: ignore[misc] nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) - return extract.TreeNode.from_states(passed) + return extract.NodeStates.from_states(passed) class CustomVjpMetadata(struct.PyTreeNode): @@ -491,9 +491,9 @@ def __call__(self, *args): nondiff = extract.from_tree(nondiff) residual = extract.from_tree(pure_residual) pure_g = jax.tree.map( - lambda x: x.state if isinstance(x, extract.TreeNode) else x, + lambda x: x.state if isinstance(x, extract.NodeStates) else x, pure_g, - is_leaf=lambda x: isinstance(x, extract.TreeNode), + is_leaf=lambda x: isinstance(x, extract.NodeStates), ) tangent = self.bwd(*nondiff, residual, pure_g) @@ -502,7 +502,7 @@ def state_to_tree_node(is_tree_node: bool, x): if is_tree_node: if not isinstance(x, State): raise ValueError(f'Expected State, got {type(x)}') - return extract.TreeNode.from_states(x) + return extract.NodeStates.from_states(x) return x pure_tangent = jax.tree.map( @@ -567,9 +567,9 @@ def __call__( tuple(x for x in arg_filters if x is not False), ) tree_node_args = jax.tree.map( - lambda x: isinstance(x, extract.TreeNode), + lambda x: isinstance(x, extract.NodeStates), pure_args, - is_leaf=lambda x: isinstance(x, extract.TreeNode), + is_leaf=lambda x: isinstance(x, extract.NodeStates), ) tangent_tree_node_args = tuple( arg diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 1f63654d63..88d99e8f7b 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -75,10 +75,10 @@ def __hash__(self): def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): if isinstance(prefix, StateSharding): - return extract.TreeNode.from_split( + return extract.NodeStates.from_split( *ctx.split(x, *prefix.filters), metadata=prefix ) - return extract.TreeNode.from_split(*ctx.split(x)) + return extract.NodeStates.from_split(*ctx.split(x)) @dataclasses.dataclass(eq=False) @@ -290,13 +290,13 @@ def jit( ) # type: ignore[return-value] kwarg_shardings = None jax_in_shardings = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.shardings, metadata=x) + lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x) if isinstance(x, StateSharding) else x, in_shardings, ) jax_out_shardings = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.shardings, metadata=x) + lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x) if isinstance(x, StateSharding) else x, out_shardings, diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index c169a91fa1..31429e3296 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -97,48 +97,42 @@ def __hash__(self): return hash((self.filters, self.axes)) -AxisFn = tp.Callable[ - [extract.GraphDefState, int, tp.Mapping], extract.GraphDefState -] +AxisFn = tp.Callable[[graph.GraphState, int, tp.Mapping], graph.GraphState] def _update_variable_sharding_metadata( tree, transform_metadata, axis_fn: AxisFn ): - def _update_axes_fn(tree_node): - if isinstance(tree_node, extract.TreeNode) and isinstance( - tree_node.metatata, (StateAxes, int) + def _update_axes_fn(node_states): + if isinstance(node_states, extract.NodeStates) and isinstance( + node_states.metadata, (StateAxes, int) ): - if isinstance(tree_node.metatata, int): - graph_def_state = tree_node.graphdef_states[0] - assert isinstance(graph_def_state, extract.GraphDefState) - graphdef_state = axis_fn( - graph_def_state, tree_node.metatata, transform_metadata - ) - return tree_node.replace(graphdef_states=(graphdef_state,)) + if isinstance(node_states.metadata, int): + state = node_states.state + assert isinstance(state, State) + state = axis_fn(state, node_states.metadata, transform_metadata) + return node_states.replace(states=(state,)) else: - graphdef_states_out: list[extract.GraphDefState] = [] - for graphdef_state, axis in zip( - tree_node.graphdef_states, tree_node.metatata.axes - ): - assert isinstance(graphdef_state, extract.GraphDefState) + states_out: list[graph.GraphState] = [] + for state, axis in zip(node_states.states, node_states.metadata.axes): + assert isinstance(state, graph.State) if isinstance(axis, int): - graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) - graphdef_states_out.append(graphdef_state) - return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) - return tree_node + state = axis_fn(state, axis, transform_metadata) + states_out.append(state) + return node_states.replace(states=tuple(states_out)) + return node_states return jax.tree.map( - _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.TreeNode) + _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates) ) def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x): if isinstance(prefix, StateAxes): - return extract.TreeNode.from_split( - *ctx.split(x, *prefix.filters), metadata=prefix + return extract.NodeStates.from_split( + *ctx.split(x, *prefix.filters), metadata=prefix ) - return extract.TreeNode.from_split(*ctx.split(x), metadata=prefix) + return extract.NodeStates.from_split(*ctx.split(x), metadata=prefix) @dataclasses.dataclass(eq=False) @@ -306,16 +300,16 @@ def vmap( ) # type: ignore[return-value] jax_in_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - in_axes, + lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + in_axes, ) jax_out_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - out_axes, + lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + out_axes, ) vmapped_fn = jax.vmap( VmapFn(f, transform_metadata, in_axes, out_axes), @@ -526,16 +520,16 @@ def pmap( ) # type: ignore[return-value] jax_in_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - in_axes, + lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + in_axes, ) jax_out_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - out_axes, + lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + out_axes, ) pmapped_fn = jax.pmap( PmapFn(f, transform_metadata, in_axes, out_axes), @@ -645,21 +639,21 @@ def _extract_index_mappings( /, ): def extract_index_mappings(x): - if isinstance(x, extract.GraphDefState) and isinstance( - x.graphdef, graph.NodeDef + if isinstance(x, extract.NodeStates) and isinstance( + x._graphdef, graph.NodeDef ): - index_mapping = x.graphdef.index_mapping + index_mapping = x._graphdef.index_mapping assert index_mapping is not None carry_index_mappings.append(index_mapping) x = x.replace( - graphdef=dataclasses.replace(x.graphdef, index_mapping=None) + _graphdef=dataclasses.replace(x._graphdef, index_mapping=None) ) return x pure_carry_arg_out = jax.tree.map( extract_index_mappings, pure_carry_arg_out, - is_leaf=lambda x: isinstance(x, extract.GraphDefState), + is_leaf=lambda x: isinstance(x, extract.NodeStates), ) return pure_carry_arg_out @@ -670,19 +664,19 @@ def _insert_index_mappings( /, ): def insert_index_mappings(x): - if isinstance(x, extract.GraphDefState) and isinstance( - x.graphdef, graph.NodeDef + if isinstance(x, extract.NodeStates) and isinstance( + x._graphdef, graph.NodeDef ): index_mapping = carry_index_mappings.popleft() x = x.replace( - graphdef=dataclasses.replace(x.graphdef, index_mapping=index_mapping) + _graphdef=dataclasses.replace(x._graphdef, index_mapping=index_mapping) ) return x pure_carry_arg_out = jax.tree.map( insert_index_mappings, pure_carry_arg_out, - is_leaf=lambda x: isinstance(x, extract.GraphDefState), + is_leaf=lambda x: isinstance(x, extract.NodeStates), ) return pure_carry_arg_out @@ -717,8 +711,8 @@ def _scan_split_in( vectorized_states.append(State({})) carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) - return extract.TreeNode.from_split( - graphdef, *vectorized_states, metadata=prefix + return extract.NodeStates.from_split( + graphdef, *vectorized_states, metadata=prefix ) elif isinstance(prefix, int): graphdef, state = ctx.split(x) @@ -741,8 +735,8 @@ def _scan_split_in( vectorized_states.append(State({})) carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) - return extract.TreeNode.from_split( - graphdef, *vectorized_states, metadata=prefix + return extract.NodeStates.from_split( + graphdef, *vectorized_states, metadata=prefix ) else: if isinstance(prefix, StateAxes): @@ -808,8 +802,8 @@ def _scan_split_out( if is_input_arg: carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) - return extract.TreeNode.from_split( - graphdef, *vectorized_states, metadata=prefix + return extract.NodeStates.from_split( + graphdef, *vectorized_states, metadata=prefix ) elif isinstance(prefix, int): graphdef, state = ctx.split(x) @@ -834,8 +828,8 @@ def _scan_split_out( if is_input_arg: carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) - return extract.TreeNode.from_split( - graphdef, *vectorized_states, metadata=prefix + return extract.NodeStates.from_split( + graphdef, *vectorized_states, metadata=prefix ) else: if isinstance(prefix, StateAxes): @@ -868,7 +862,7 @@ def _scan_merge_in( prefix, x, ): - if isinstance(x, extract.TreeNode): + if isinstance(x, extract.NodeStates): carry_states = carry_deque.popleft() broadcast_states = broadcast_deque.popleft() return ctx.merge(x.graphdef, *x.states, *carry_states, *broadcast_states) @@ -891,7 +885,7 @@ def _scan_merge_out( assert isinstance(path[0], jax.tree_util.SequenceKey) is_input_arg = path[0].idx == 0 - if isinstance(x, extract.TreeNode): + if isinstance(x, extract.NodeStates): states: list[State] = [] if is_input_arg: carry_states = deque(carry_deque.popleft()) @@ -1005,7 +999,7 @@ def __call__( merge_fn=functools.partial( _scan_merge_in, carry_deque, broadcast_deque, broadcast_arrays ), - is_leaf=lambda x: isinstance(x, (extract.TreeNode, Broadcasted)), + is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)), map_non_graph_nodes=True, ctxtag='scan', ) @@ -1268,7 +1262,7 @@ def scan_wrapper(*args, **kwargs): merge_fn=functools.partial( _scan_merge_out, carry_deque_out, broadcast_deque_out ), - is_leaf=lambda x: isinstance(x, (extract.TreeNode, Broadcasted)), + is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)), map_non_graph_nodes=True, ctxtag='scan', ) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 57b0f2e3c1..a84f1b1626 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -23,6 +23,26 @@ import jax import jax.numpy as jnp +class List(nnx.Module): + def __init__(self, items): + self.items = list(items) + + def __getitem__(self, idx): + return self.items[idx] + + def __setitem__(self, idx, value): + self.items[idx] = value + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + self.items = dict(*args, **kwargs) + + def __getitem__(self, key): + return self.items[key] + + def __setitem__(self, key, value): + self.items[key] = value class StatefulLinear(nnx.Module): def __init__(self, din, dout, rngs): @@ -54,8 +74,8 @@ def test_flatten(self): assert g[3] in refmap def test_unflatten(self): - a = nnx.Dict(a=1, b=nnx.Param(2)) - g = nnx.List([a, 3, a, nnx.Param(4)]) + a = Dict(a=1, b=nnx.Param(2)) + g = List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.split(g) g = nnx.merge(graphdef, state) @@ -72,8 +92,8 @@ def test_unflatten_pytree(self): assert g[0] is not g[2] def test_unflatten_empty(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - g = nnx.List([a, 3, a, nnx.Param(4)]) + a = Dict({'a': 1, 'b': nnx.Param(2)}) + g = List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.split(g) @@ -92,46 +112,6 @@ def test_update_dynamic(self): assert g[0]['b'].value == 3 assert g[2]['b'].value == 3 - def test_update_static(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - g = nnx.List([a, 3, a, nnx.Param(4)]) - - g2 = nnx.graph.clone(g) - g2[0]['a'] = 5 - - nnx.graph.graph_update_static(g, g2) - - assert g[0]['a'] == 5 - assert g[2]['a'] == 5 - - def test_update_static_inconsistent_types(self): - a = {'a': 1, 'b': nnx.Param(2)} - g = [a, 3, a, nnx.Param(4)] - g2 = [a, a, 3, nnx.Param(4)] - - with self.assertRaisesRegex( - ValueError, 'Trying to update a node with a different type' - ): - nnx.graph.graph_update_static(g, g2) - - def test_update_static_add_new(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - b = nnx.List([5, 6]) - g = nnx.List([a, 3, a, nnx.Param(4)]) - g2 = nnx.List([a, 3, a, nnx.Param(4), b]) - - nnx.graph.graph_update_static(g, g2) - - assert g[4][0] == 5 - assert g[4][1] == 6 - - def test_update_static_add_shared_error(self): - a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) - g = nnx.List([a, 3, a, nnx.Param(4)]) - g2 = nnx.List([a, 3, a, nnx.Param(4), a]) - - with self.assertRaisesRegex(ValueError, 'Trying to add a new node at path'): - nnx.graph.graph_update_static(g, g2) def test_module_list(self): rngs = nnx.Rngs(0) @@ -621,10 +601,10 @@ def test_to_tree_simple(self): t2 = pure_tree[2]['b'] self.assertEqual(pure_tree[1], 1) - self.assertIsInstance(t1, nnx.TreeNode) - assert isinstance(t1, nnx.TreeNode) - self.assertIsInstance(t2, nnx.TreeNode) - assert isinstance(t2, nnx.TreeNode) + self.assertIsInstance(t1, nnx.NodeStates) + assert isinstance(t1, nnx.NodeStates) + self.assertIsInstance(t2, nnx.NodeStates) + assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef) self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef) self.assertLen(t1.states[0].flat_state(), 2) @@ -656,10 +636,10 @@ def __init__(self): t2 = pure_tree[2]['b'] self.assertEqual(pure_tree[1], 1) - self.assertIsInstance(t1, nnx.TreeNode) - assert isinstance(t1, nnx.TreeNode) - self.assertIsInstance(t2, nnx.TreeNode) - assert isinstance(t2, nnx.TreeNode) + self.assertIsInstance(t1, nnx.NodeStates) + assert isinstance(t1, nnx.NodeStates) + self.assertIsInstance(t2, nnx.NodeStates) + assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef) self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef) self.assertLen(t1.states[0].flat_state(), 1) @@ -683,10 +663,10 @@ def f(pure_tree): t2 = pure_tree2[2]['b'] # self.assertEqual(pure_tree2[1], 1) - self.assertIsInstance(t1, nnx.TreeNode) - assert isinstance(t1, nnx.TreeNode) - self.assertIsInstance(t2, nnx.TreeNode) - assert isinstance(t2, nnx.TreeNode) + self.assertIsInstance(t1, nnx.NodeStates) + assert isinstance(t1, nnx.NodeStates) + self.assertIsInstance(t2, nnx.NodeStates) + assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef) self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef) self.assertLen(t1.states[0].flat_state(), 1) @@ -738,19 +718,19 @@ def __init__(self, a, b): m1_axes = StateAxes(None, 0) in_axes = (m1_axes, None, {'b': m1_axes}) jax_in_axes = jax.tree.map( - lambda x: nnx.TreeNode.from_prefixes((x.params, x.batch_stats)) - if isinstance(x, StateAxes) - else x, - in_axes, + lambda x: nnx.NodeStates.from_prefixes((x.params, x.batch_stats)) + if isinstance(x, StateAxes) + else x, + in_axes, ) out_axes = 0 def split_fn(ctx: nnx.SplitContext, path, prefix, x): if isinstance(prefix, StateAxes): - return nnx.TreeNode.from_split( - *ctx.split(x, nnx.Param, nnx.BatchStat) + return nnx.NodeStates.from_split( + *ctx.split(x, nnx.Param, nnx.BatchStat) ) - return nnx.TreeNode.from_split(*ctx.split(x)) + return nnx.NodeStates.from_split(*ctx.split(x)) pure_args = nnx.to_tree( args, ctxtag=ctxtag, prefix=in_axes, split_fn=split_fn diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index a3f7bf8c22..2aff69a144 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -14,7 +14,7 @@ from copy import deepcopy import dataclasses -from typing import Any, TypeVar +from typing import TypeVar from absl.testing import absltest from flax import nnx, errors @@ -24,6 +24,35 @@ A = TypeVar('A') +class List(nnx.Module): + def __init__(self, items): + self.items = list(items) + + def __getitem__(self, idx): + return self.items[idx] + + def __setitem__(self, idx, value): + self.items[idx] = value + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + self.items = dict(*args, **kwargs) + + def __getitem__(self, key): + return vars(self)['items'][key] + + def __setitem__(self, key, value): + vars(self)['items'][key] = value + + def __getattr__(self, key): + attrs = vars(self) + if 'items' not in attrs: + raise AttributeError('items') + elif key not in attrs['items']: + raise AttributeError(key) + return attrs['items'][key] + class TestModule(absltest.TestCase): def test_has_module_state(self): @@ -34,7 +63,7 @@ class Foo(nnx.Module): ... assert hasattr(foo, '_object__state') def test_trace_level(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit def f(): @@ -47,24 +76,24 @@ def f(): f() def test_tree_map(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) graphdef, state = nnx.split(m) state = jax.tree.map(lambda x: x + 1, state) def test_split_2(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) graphdef, empty, some = nnx.split(m, None, ...) some = jax.tree.map(lambda x: x + 1, some) def test_split_merge(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit - def g(graphdef: nnx.GraphDef[nnx.Dict[int]], state: nnx.State): + def g(graphdef: nnx.GraphDef[Dict], state: nnx.State): m = nnx.merge(graphdef, state) m.a = 2 return nnx.split(m) @@ -77,7 +106,7 @@ def g(graphdef: nnx.GraphDef[nnx.Dict[int]], state: nnx.State): def test_no_trace_level_error_on_grad(self): # No trace level error occurs because jax doesn't update # its top trace for grad. - m = nnx.Dict(a=nnx.Param(1.0)) + m = Dict(a=nnx.Param(1.0)) @jax.grad def f(_): @@ -104,8 +133,8 @@ def __call__(self, x, *, rngs: nnx.Rngs): assert isinstance(y, jax.Array) def test_shared_module(self): - m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) - m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) + m1 = Dict(a=nnx.Param(1), b=nnx.Param(2)) + m2 = Dict(x=m1, y=m1, z=nnx.Param(3)) m3 = nnx.merge(*nnx.split(m2)) @@ -131,10 +160,10 @@ def test_deref_through_jit(self): r1 = nnx.Variable(1) r2 = nnx.Variable(2) - m = m0 = nnx.Dict({'a': nnx.List([r1, r2]), 'b': r1}) + m = m0 = Dict({'a': List([r1, r2]), 'b': r1}) @jax.jit - def f(graphdef: nnx.GraphDef[nnx.Dict[Any]], state: nnx.State): + def f(graphdef: nnx.GraphDef[Dict], state: nnx.State): m = nnx.merge(graphdef, state) assert m['a'][0] is m['b'] @@ -154,10 +183,10 @@ def f(graphdef: nnx.GraphDef[nnx.Dict[Any]], state: nnx.State): assert m['b'] is not m0['b'] def test_cross_barrier(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit - def g(graphdef: nnx.GraphDef[nnx.Dict[nnx.Param[int]]], state: nnx.State): + def g(graphdef: nnx.GraphDef[Dict], state: nnx.State): m = nnx.merge(graphdef, state) m.a.value += 1 return nnx.split(m) @@ -170,7 +199,7 @@ def g(graphdef: nnx.GraphDef[nnx.Dict[nnx.Param[int]]], state: nnx.State): def test_no_rejit(self): n = 0 - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @jax.jit def g(state_and_def): @@ -202,10 +231,10 @@ def test_deref_number_of_fields(self): r1 = nnx.Variable(1) r2 = nnx.Variable(2) v1 = 3 - m = nnx.Dict( + m = Dict( { - 'a': nnx.List([r1, r2, v1]), - 'b': nnx.Dict({'c': r1, 'd': r2}), + 'a': List([r1, r2, v1]), + 'b': Dict({'c': r1, 'd': r2}), } ) @@ -214,9 +243,9 @@ def test_deref_number_of_fields(self): assert len(jax.tree_util.tree_leaves(p)) == 2 def test_clone(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), 3]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), 3]), + b=Dict(c=nnx.Param(1), d=nnx.Param(2)), ) m2 = nnx.clone(m) diff --git a/tests/nnx/partitioning_test.py b/tests/nnx/partitioning_test.py index 92c878cb2e..bb859de3a6 100644 --- a/tests/nnx/partitioning_test.py +++ b/tests/nnx/partitioning_test.py @@ -12,16 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING from absl.testing import absltest from flax import nnx import jax +class List(nnx.Module): + def __init__(self, items): + vars(self).update({str(i): item for i, item in enumerate(items)}) + + def __getitem__(self, idx): + return getattr(self, str(idx)) + + def __setitem__(self, idx, value): + setattr(self, str(idx), value) + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + vars(self).update(dict(*args, **kwargs)) + + def __getitem__(self, key): + return vars(self)[key] + + def __setitem__(self, key, value): + vars(self)[key] = value + + if TYPE_CHECKING: + + def __getattr__(self, key): ... + + class TestPartitioning(absltest.TestCase): def test_partition(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.BatchStat(2)]), + m = Dict( + a=List([nnx.Param(1), nnx.BatchStat(2)]), b=nnx.Param(2), c=100, ) @@ -32,41 +59,41 @@ def test_partition(self): self.assertLen(rest, 1) # check params - self.assertEqual(params['a'][0].value, m.a[0].value) + self.assertEqual(params['a']['0'].value, m.a['0'].value) self.assertEqual(params['b'].value, m.b.value) # check rest - self.assertEqual(rest['a'][1].value, m.a[1].value) + self.assertEqual(rest['a']['1'].value, m.a['1'].value) m2 = nnx.merge(graphdef, params, rest) - self.assertEqual(m2.a[0].value, m.a[0].value) - self.assertEqual(m2.a[1].value, m.a[1].value) + self.assertEqual(m2.a['0'].value, m.a['0'].value) + self.assertEqual(m2.a['1'].value, m.a['1'].value) self.assertEqual(m2.b.value, m.b.value) self.assertEqual(m2.c, 100) def test_complete_partitioning(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) # no error nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable) def test_complete_partitioning_plus_ellipsis(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) # no error if additional ... is passed at the end nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable, ...) def test_inclomplete_partition_error(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) with self.assertRaisesRegex( @@ -75,9 +102,9 @@ def test_inclomplete_partition_error(self): nnx.split(m, nnx.Param) def test_ellipsis_not_last_error(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), - b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + m = Dict( + a=List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) with self.assertRaisesRegex( @@ -86,8 +113,8 @@ def test_ellipsis_not_last_error(self): nnx.split(m, ..., nnx.Param) def test_update_from(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.BatchStat(3)]), + m = Dict( + a=List([nnx.Param(1), nnx.BatchStat(3)]), b=nnx.Param(2), c=100, ) @@ -105,8 +132,8 @@ def test_update_from(self): self.assertEqual(m.c, 100) def test_update_from_with_array_leaf(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1), nnx.BatchStat(3)]), + m = Dict( + a=List([nnx.Param(1), nnx.BatchStat(3)]), b=nnx.Param(2), c=nnx.Variable(jax.numpy.array(100)), ) @@ -124,8 +151,8 @@ def test_update_from_with_array_leaf(self): self.assertEqual(m.c.value, 200) def test_grad_example(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(1.0), nnx.BatchStat(-10)]), + m = Dict( + a=List([nnx.Param(1.0), nnx.BatchStat(-10)]), b=nnx.Param(2.0), c=100, ) @@ -144,8 +171,8 @@ def loss(params): self.assertEqual(m.c, 100) def test_get_paritition(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(10.0), nnx.Param(20.0)]), + m = Dict( + a=List([nnx.Param(10.0), nnx.Param(20.0)]), b=nnx.Param(10.0), c=7, d=5.0, @@ -155,10 +182,10 @@ def test_get_paritition(self): self.assertIsNot(vars(m.a)['0'], vars(m)['b']) state = nnx.state(m, nnx.Variable) - self.assertEqual(state['a'][0].value, m.a[0].value) - self.assertEqual(state['a'][1].value, m.a[1].value) + self.assertEqual(state['a']['0'].value, m.a['0'].value) + self.assertEqual(state['a']['1'].value, m.a['1'].value) self.assertEqual(state['b'].value, m.b.value) - self.assertIsNot(state.b, state.a[0]) + self.assertIsNot(state.b, state.a['0']) self.assertLen(state.flat_state(), 3) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 824e7b6b0e..84a833041c 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -26,12 +26,38 @@ import numpy as np +class List(nnx.Module): + def __init__(self, items): + vars(self).update({str(i): item for i, item in enumerate(items)}) + + def __getitem__(self, idx): + return getattr(self, str(idx)) + + def __setitem__(self, idx, value): + setattr(self, str(idx), value) + + +class Dict(nnx.Module): + def __init__(self, *args, **kwargs): + vars(self).update(dict(*args, **kwargs)) + + def __getitem__(self, key): + return vars(self)[key] + + def __setitem__(self, key, value): + vars(self)[key] = value + + if tp.TYPE_CHECKING: + + def __getattr__(self, key): ... + + class TestJIT(absltest.TestCase): def test_jit(self): - m = nnx.Dict(a=nnx.Param(1)) + m = Dict(a=nnx.Param(1)) @nnx.jit - def g(m: nnx.Dict): + def g(m: Dict): m.a = 2 return 1.0 @@ -354,15 +380,15 @@ def test_grad(self): p1 = nnx.Param(10.0) p2 = nnx.Param(20.0) - m = nnx.Dict( - a=nnx.List([p1, p2]), + m = Dict( + a=List([p1, p2]), b=p1, c=7, d=5.0, ) @nnx.grad - def f(m: nnx.Dict): + def f(m: Dict): # sum all params return m['a'][0].value + m['a'][1].value + m['b'].value @@ -370,10 +396,10 @@ def f(m: nnx.Dict): assert m.a[0] is m.b assert isinstance(grads, nnx.State) - assert grads['a'][0].value == 2.0 - assert issubclass(grads.a[0].type, nnx.Variable) - assert grads['a'][1].value == 1.0 - assert issubclass(grads.a[1].type, nnx.Variable) + assert grads['a']['0'].value == 2.0 + assert issubclass(grads.a['0'].type, nnx.Variable) + assert grads['a']['1'].value == 1.0 + assert issubclass(grads.a['1'].type, nnx.Variable) assert len(grads.flat_state()) == 2 nnx.update(m, grads) @@ -386,57 +412,57 @@ def f(m: nnx.Dict): assert m['d'] == 5.0 def test_grad_with_multiple_ref_types(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(10.0), nnx.BatchStat(20.0)]), + m = Dict( + a=List([nnx.Param(10.0), nnx.BatchStat(20.0)]), b=nnx.Param(10.0), c=7, d=5.0, ) @nnx.grad - def f(m: nnx.Dict): + def f(m: Dict): # sum all params return m.a[0].value + m.a[1].value + m.b.value grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a'][0].value == 1.0 - assert issubclass(grads.a[0].type, nnx.Param) + assert grads['a']['0'].value == 1.0 + assert issubclass(grads.a['0'].type, nnx.Param) assert len(grads) == 2 nnx.update(m, grads) - assert m.a[0].value == 1.0 - assert m.a[1].value == 20.0 + assert m.a['0'].value == 1.0 + assert m.a['1'].value == 20.0 assert m.b.value == 1.0 assert m.c == 7 assert m.d == 5.0 def test_grad_with_type_predicate(self): - m = nnx.Dict( - a=nnx.List([nnx.Param(10.0), nnx.BatchStat(20.0)]), + m = Dict( + a=List([nnx.Param(10.0), nnx.BatchStat(20.0)]), b=nnx.Param(10.0), c=7, d=5.0, ) @nnx.grad(argnums=nnx.DiffState(0, nnx.BatchStat)) - def f(m: nnx.Dict): + def f(m: Dict): # sum all params return m.a[0].value + m.a[1].value + m.b.value grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a'][1].value == 1.0 - assert issubclass(grads.a[1].type, nnx.BatchStat) + assert grads['a']['1'].value == 1.0 + assert issubclass(grads.a['1'].type, nnx.BatchStat) assert len(grads) == 1 nnx.update(m, grads) - assert m.a[0].value == 10.0 - assert m.a[1].value == 1.0 + assert m.a['0'].value == 10.0 + assert m.a['1'].value == 1.0 assert m.b.value == 10.0 assert m.c == 7 assert m.d == 5.0