diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 2339f5c168..d5b0ba34d7 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -110,15 +110,16 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): pop_key: tp.Callable[[Node, Key], Leaf] create_empty: tp.Callable[[AuxData], Node] clear: tp.Callable[[Node], None] + init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None] - def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]): - for key, value in items: - self.set_key(node, key, value) + # def init(self, node: Node, items: tp.Iterable[tuple[Key, Leaf]]): + # for key, value in items: + # self.set_key(node, key, value) @dataclasses.dataclass(frozen=True, slots=True) class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): - unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node] + unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node] NodeImpl = tp.Union[ @@ -137,6 +138,7 @@ def register_graph_node_type( pop_key: tp.Callable[[Node, Key], Leaf], create_empty: tp.Callable[[AuxData], Node], clear: tp.Callable[[Node], None], + init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None], ): if type in GRAPH_REGISTRY: raise ValueError(f'Node type {type} is already registered.') @@ -148,12 +150,13 @@ def register_graph_node_type( pop_key=pop_key, create_empty=create_empty, clear=clear, + init=init, ) def register_pytree_node_type( type: type, flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], - unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node], + unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node], ): if type in PYTREE_REGISTRY: raise ValueError(f'Node type {type} is already registered.') @@ -202,8 +205,8 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): - def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]): - self._mapping = dict(mapping) + def __init__(self, mapping: tp.Mapping[HA, HB], no_copy: bool = False): + self._mapping = mapping if no_copy else dict(mapping) def __contains__(self, key: object) -> bool: return key in self._mapping @@ -444,7 +447,7 @@ def _graph_flatten( flat_state[(*path, key)] = value.to_state() variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( - type(value), variable_index, HashableMapping(value.get_metadata()) + type(value), variable_index, HashableMapping(value._var_metadata) ) attributes.append(LeafAttribute(key, variabledef)) else: @@ -528,7 +531,7 @@ def _graph_unflatten( node_impl = get_node_impl_for_type(nodedef.type) def _get_children(): - children: dict[Key, NodeLeaf | Node] = {} + children: list[tuple[Key, NodeLeaf | Node]] = [] state_keys: set = set(state.keys()) # for every key in attributes there are 6 possible cases: @@ -539,28 +542,29 @@ def _get_children(): if key not in state: # if key is not present create an empty types if type(attribute) is StaticAttribute: - children[key] = attribute.value + children.append((key, attribute.value)) elif type(attribute) is SubGraphAttribute: # if the key is a subgraph we create an empty node subgraphdef = attribute.value assert not isinstance(subgraphdef, VariableDef) if isinstance(subgraphdef, NodeRef): # subgraph exists, take it from the cache - children[key] = index_ref[subgraphdef.index] + children.append((key, index_ref[subgraphdef.index])) else: # create a node from an empty state, reasoning: # * its a node with no state # * its a node with state but only through references of already # created nodes substate = {} - children[key] = _graph_unflatten( + subnode = _graph_unflatten( subgraphdef, substate, index_ref, index_ref_cache ) + children.append((key, subnode)) elif type(attribute) is LeafAttribute: variabledef = attribute.value if variabledef.index in index_ref: # variable exists, take it from the cache - children[key] = index_ref[variabledef.index] + children.append((key, index_ref[variabledef.index])) else: # key for a variable is missing, raise an error raise ValueError( @@ -587,11 +591,12 @@ def _get_children(): subgraphdef = attribute.value if isinstance(subgraphdef, NodeRef): - children[key] = index_ref[subgraphdef.index] + children.append((key, index_ref[subgraphdef.index])) else: - children[key] = _graph_unflatten( + subnode = _graph_unflatten( subgraphdef, value, index_ref, index_ref_cache ) + children.append((key, subnode)) elif type(attribute) is LeafAttribute: variabledef = attribute.value @@ -599,7 +604,7 @@ def _get_children(): if variabledef.index in index_ref: # add an existing variable assert isinstance(variabledef, NodeRef) - children[key] = index_ref[variabledef.index] + children.append((key, index_ref[variabledef.index])) else: # its a unseen variable, create a new one assert isinstance(variabledef, VariableDef) @@ -626,7 +631,7 @@ def _get_children(): variable = variabledef.type.from_metadata( value, variabledef.metadata ) - children[key] = variable + children.append((key, variable)) index_ref[variabledef.index] = variable else: raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') @@ -651,13 +656,11 @@ def _get_children(): else: node = node_impl.create_empty(nodedef.metadata) index_ref[nodedef.index] = node - children = _get_children() - node_impl.init(node, tuple(children.items())) + node_impl.init(node, _get_children()) else: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first - children = _get_children() - node = node_impl.unflatten(tuple(children.items()), nodedef.metadata) + node = node_impl.unflatten(_get_children(), nodedef.metadata) return node @@ -816,7 +819,7 @@ def split( if ctx.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(ctx.index_ref, self.ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index) + graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True) ) return graphdef, *states @@ -1006,7 +1009,7 @@ def split( if self.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(self.index_ref, ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index) + graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True) ) self.flatten_end(ref_index) @@ -1787,7 +1790,7 @@ def is_pytree_node(x: tp.Any) -> bool: elif isinstance(x, Variable): return False # knon pytree types - elif isinstance(x, (VariableState, State)): + elif type(x) is VariableState or type(x) is State: return True else: return not jax.tree_util.all_leaves((x,)) diff --git a/flax/nnx/object.py b/flax/nnx/object.py index c63506fc48..afa41cdb7b 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -30,7 +30,6 @@ ) from flax.nnx import graph from flax.nnx.variablelib import Variable, VariableState -from flax.typing import Key from flax import errors G = tp.TypeVar('G', bound='Object') @@ -109,10 +108,11 @@ def __init_subclass__(cls) -> None: graph.register_graph_node_type( type=cls, flatten=cls._graph_node_flatten, - set_key=cls._graph_node_set_key, - pop_key=cls._graph_node_pop_key, + set_key=cls._graph_node_set_key, # type: ignore + pop_key=cls._graph_node_pop_key, # type: ignore create_empty=cls._graph_node_create_empty, clear=cls._graph_node_clear, + init=cls._graph_node_init, # type: ignore ) if not tp.TYPE_CHECKING: @@ -189,14 +189,12 @@ def __treescope_repr__(self, path, subtree_renderer): # Graph Definition def _graph_node_flatten(self): - nodes = sorted( - (key, value) - for key, value in vars(self).items() - if key != '_object__state' - ) + nodes = vars(self).copy() + del nodes['_object__state'] + nodes = sorted(nodes.items()) return nodes, (type(self), self._object__state._initializing) - def _graph_node_set_key(self, key: Key, value: tp.Any): + def _graph_node_set_key(self, key: str, value: tp.Any): if not isinstance(key, str): raise KeyError(f'Invalid key: {key!r}') elif ( @@ -208,7 +206,7 @@ def _graph_node_set_key(self, key: Key, value: tp.Any): else: setattr(self, key, value) - def _graph_node_pop_key(self, key: Key): + def _graph_node_pop_key(self, key: str): if not isinstance(key, str): raise KeyError(f'Invalid key: {key!r}') return vars(self).pop(key) @@ -225,3 +223,6 @@ def _graph_node_clear(self): module_vars = vars(self) module_vars.clear() module_vars['_object__state'] = module_state + + def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]): + vars(self).update(attributes) \ No newline at end of file diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 7af20cdb73..4752a9b7bd 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -216,7 +216,7 @@ def copy_from(self, other: Variable[A]) -> None: def update_from_state(self, variable_state: VariableState[A]): vars_self = vars(self) vars_self['raw_value'] = variable_state.value - vars_self['_var_metadata'] = variable_state.get_metadata().copy() + vars_self['_var_metadata'] = variable_state._var_metadata.copy() @property def value(self) -> A: @@ -308,8 +308,7 @@ def copy(self: Variable[A]) -> Variable[A]: return obj def to_state(self: Variable[A]) -> VariableState[A]: - metadata = self.get_metadata() - return VariableState(type(self), self.raw_value, **metadata) + return VariableState(type(self), self.raw_value, **self._var_metadata) def __nnx_repr__(self): yield reprlib.Object(type=type(self))