From 97c3e9c19acc3955b06d0d2c8efdbc715410e2e9 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 31 Oct 2024 11:42:22 -0500 Subject: [PATCH] [nnx] support pure dicts --- flax/nnx/filterlib.py | 11 ++- flax/nnx/graph.py | 158 ++++++++++++++++++------------- flax/nnx/spmd.py | 11 ++- flax/nnx/statelib.py | 10 +- flax/nnx/transforms/iteration.py | 2 +- flax/nnx/variablelib.py | 12 ++- flax/typing.py | 5 +- tests/nnx/graph_utils_test.py | 26 ++++- 8 files changed, 155 insertions(+), 80 deletions(-) diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index e702966cfd..63ed371be9 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -65,12 +65,21 @@ def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]: ) return tuple(map(to_predicate, filters)) + +class HasTag(tp.Protocol): + tag: str + + +def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]: + return hasattr(x, 'tag') + + @dataclasses.dataclass(frozen=True) class WithTag: tag: str def __call__(self, path: PathParts, x: tp.Any): - return hasattr(x, 'tag') and x.tag == self.tag + return _has_tag(x) and x.tag == self.tag def __repr__(self): return f'WithTag({self.tag!r})' diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 6cc5588b1d..3b3565ef29 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -34,7 +34,7 @@ from flax.nnx.statelib import FlatState, State from flax.nnx import variablelib from flax.nnx.variablelib import Variable, VariableState -from flax.typing import Key, PathParts +from flax.typing import Key, PathParts, is_key_like A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -43,6 +43,7 @@ HA = tp.TypeVar('HA', bound=tp.Hashable) HB = tp.TypeVar('HB', bound=tp.Hashable) +KeyT = tp.TypeVar('KeyT', bound=Key) Index = int Names = tp.Sequence[int] @@ -241,6 +242,35 @@ def __treescope_repr__(self, path, subtree_renderer): jax.tree_util.register_static(NodeRef) +@dataclasses.dataclass(frozen=True, repr=False) +class VariableDef(reprlib.Representable): + type: type[Variable] + index: int + metadata: FrozenDict[str, tp.Any] + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + yield reprlib.Attr('type', self.type.__name__) + yield reprlib.Attr('index', self.index) + yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata)) + + def __treescope_repr__(self, path, subtree_renderer): + import treescope # type: ignore[import-not-found,import-untyped] + + return treescope.repr_lib.render_object_constructor( + object_type=type(self), + attributes={ + 'type': self.type, + 'index': self.index, + 'metadata': self.metadata, + }, + path=path, + subtree_renderer=subtree_renderer, + ) + + +jax.tree_util.register_static(VariableDef) + @dataclasses.dataclass(frozen=True, repr=False) class NodeDef(GraphDef[Node], reprlib.Representable): @@ -253,7 +283,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable): attributes: tuple[Key, ...] subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]] static_fields: _HashableMapping[Key, tp.Any] - leaves: _HashableMapping[Key, NodeRef[tp.Any] | None] + leaves: _HashableMapping[Key, VariableDef | NodeRef[tp.Any]] metadata: tp.Any index_mapping: FrozenDict[Index, Index] | None @@ -265,7 +295,7 @@ def create( attributes: tuple[Key, ...], subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]], static_fields: tp.Iterable[tuple[Key, tp.Any]], - leaves: tp.Iterable[tuple[Key, NodeRef[tp.Any] | None]], + leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]], metadata: tp.Any, index_mapping: tp.Mapping[Index, Index] | None, ): @@ -380,7 +410,7 @@ def _graph_flatten( subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = [] static_fields: list[tuple[Key, tp.Any]] = [] - leaves: list[tuple[Key, NodeRef | None]] = [] + leaves: list[tuple[Key, VariableDef | NodeRef]] = [] values, metadata = node_impl.flatten(node) for key, value in values: @@ -393,10 +423,10 @@ def _graph_flatten( else: flat_state[(*path, key)] = value.to_state() variable_index = ref_index[value] = len(ref_index) - leaves.append((key, NodeRef(type(value), variable_index))) - elif is_state_leaf(value): - flat_state[(*path, key)] = value - leaves.append((key, None)) + variabledef = VariableDef( + type(value), variable_index, FrozenDict(value.get_metadata()) + ) + leaves.append((key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): path_str = '/'.join(map(str, (*path, key))) @@ -420,7 +450,7 @@ def _graph_flatten( def unflatten( graphdef: GraphDef[Node], - state: GraphState, + state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], /, *, index_ref: dict[Index, tp.Any] | None = None, @@ -441,17 +471,17 @@ def unflatten( existing graph nodes are mutated to have the new content/topology specified by the graphdef. """ + if isinstance(state, State): + state = state.raw_mapping # type: ignore if index_ref is None: index_ref = {} assert isinstance(graphdef, (NodeDef, NodeRef)) - node = _graph_unflatten( - graphdef, state.raw_mapping, index_ref, index_ref_cache - ) + node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache) return node def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], - state: tp.Mapping[Key, StateLeaf | tp.Mapping[Key, tp.Any]], + state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], index_ref: dict[Index, tp.Any], index_ref_cache: dict[Index, tp.Any] | None, ) -> Node: @@ -480,7 +510,7 @@ def _graph_unflatten( node_impl = get_node_impl_for_type(nodedef.type) def _get_children(): - children: dict[Key, StateLeaf | Node] = {} + children: dict[Key, NodeLeaf | Node] = {} # NOTE: we could allw adding new StateLeafs here if unkown_keys := set(state) - set(nodedef.attributes): @@ -491,13 +521,13 @@ def _get_children(): # - (3) the key can be a subgraph, a leaf, or a static attribute for key in nodedef.attributes: if key not in state: - # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys? # if key is not present create an empty types if key in nodedef.static_fields: children[key] = nodedef.static_fields[key] elif key in nodedef.subgraphs: # if the key is a subgraph we create an empty node subgraphdef = nodedef.subgraphs[key] + assert not isinstance(subgraphdef, VariableDef) if isinstance(subgraphdef, NodeRef): # subgraph exists, take it from the cache children[key] = index_ref[subgraphdef.index] @@ -511,10 +541,10 @@ def _get_children(): subgraphdef, substate, index_ref, index_ref_cache ) elif key in nodedef.leaves: - noderef = nodedef.leaves[key] - if noderef is not None and noderef.index in index_ref: + variabledef = nodedef.leaves[key] + if variabledef.index in index_ref: # variable exists, take it from the cache - children[key] = index_ref[noderef.index] + children[key] = index_ref[variabledef.index] else: # key for a variable is missing, raise an error raise ValueError( @@ -546,41 +576,40 @@ def _get_children(): ) elif key in nodedef.leaves: - if not is_state_leaf(value): - raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}') - - noderef = nodedef.leaves[key] - - if noderef is None: - # if the leaf is None, it means that the value was originally - # a non-VariableState leaf, however we allow providing a - # VariableState presumbly created by modifying the State - if isinstance(value, VariableState): - value = value.to_variable() - children[key] = value - elif noderef.index in index_ref: + variabledef = nodedef.leaves[key] + + if variabledef.index in index_ref: # add an existing variable - children[key] = index_ref[noderef.index] + assert isinstance(variabledef, NodeRef) + children[key] = index_ref[variabledef.index] else: # its a unseen variable, create a new one - if not isinstance(value, VariableState): - raise ValueError( - f'Expected a Variable type for {key!r}, but got {type(value)}.' - ) + assert isinstance(variabledef, VariableDef) # when idxmap is present, check if the Varable exists there # and update existing variables if it does - if index_ref_cache is not None and noderef.index in index_ref_cache: - variable = index_ref_cache[noderef.index] + if ( + index_ref_cache is not None + and variabledef.index in index_ref_cache + ): + # if variable exists, update it + variable = index_ref_cache[variabledef.index] if not isinstance(variable, Variable): raise ValueError( f'Expected a Variable type for {key!r}, but got {type(variable)}.' ) - variable.update_from_state(value) + if isinstance(value, VariableState): + variable.update_from_state(value) + else: + variable.raw_value = value else: # if it doesn't, create a new variable - assert isinstance(value, VariableState) - variable = value.to_variable() + if isinstance(value, VariableState): + variable = value.to_variable() + else: + variable = variabledef.type.from_metadata( + value, variabledef.metadata + ) children[key] = variable - index_ref[noderef.index] = variable + index_ref[variabledef.index] = variable else: raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') @@ -676,7 +705,7 @@ def _graph_pop( pass -def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]): +def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') @@ -703,26 +732,19 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]): if is_state_leaf(value): raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}') _graph_update_dynamic(current_value, value) - elif isinstance(value, VariableState): + else: # case 3: state leaf is being updated if not isinstance(current_value, Variable): raise ValueError( f'Trying to update a non-Variable attribute {key!r} with a Variable: ' f'{value!r}' ) - current_value.update_from_state(value) - elif is_state_leaf(value): - # case 4: state field is being updated - if isinstance(node_impl, PytreeNodeImpl): - raise ValueError( - f'Cannot set key {key!r} on immutable node of ' - f'type {type(node).__name__}' - ) - node_impl.set_key(node, key, value) - else: - raise ValueError( - f'Unsupported update type: {type(value)} for key {key!r}' - ) + if isinstance(value, VariableState): + # updated from VariableState + current_value.update_from_state(value) + else: + # updated from raw value + current_value.raw_value = value # -------------------------------------------------------- # UpdateContext @@ -1251,12 +1273,11 @@ def split( states = _split_state(state, filters) return graphdef, *states - def merge( graphdef: GraphDef[A], - state: GraphState, + state: tp.Mapping[KeyT, tp.Any], /, - *states: GraphState, + *states: tp.Mapping[KeyT, tp.Any], ) -> A: """The inverse of :func:`split`. @@ -1293,13 +1314,15 @@ def merge( Returns: The merged :class:`Module`. """ - state = GraphState.merge(state, *states) + state = State.merge(state, *states) node = unflatten(graphdef, state) return node -def update(node, state: State, /, *states: State) -> None: - """Update the given graph node with a new :class:`State` in-place. +def update( + node, state: tp.Mapping[KeyT, tp.Any], /, *states: tp.Mapping[KeyT, tp.Any] +) -> None: + """Update the given graph node with a new state(s) in-place. Example usage:: @@ -1325,9 +1348,10 @@ def update(node, state: State, /, *states: State) -> None: *states: Additional :class:`State` objects. """ if states: - state = GraphState.merge(state, *states) - - _graph_update_dynamic(node, state.raw_mapping) + state = State.merge(state, *states) + if isinstance(state, State): + state = state.raw_mapping + _graph_update_dynamic(node, state) def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]: for path, value in iter_graph(node): @@ -1741,7 +1765,7 @@ def _key_path_to_key(key: tp.Any) -> Key: elif isinstance( key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey) ): - if not isinstance(key.key, Key): + if not is_key_like(key.key): raise ValueError( f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.' ) diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index fd9deb89f8..b6995b136d 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -31,19 +31,26 @@ F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) PARTITION_NAME = 'partition_name' +class HasSharding(tp.Protocol): + sharding: tuple[str | None, ...] | None -def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A: + +def _has_sharding(x: tp.Any) -> tp.TypeGuard[HasSharding]: + return hasattr(x, 'sharding') and x.sharding is not None + +def add_axis(tree: A, index: int, params: tp.Mapping) -> A: axis_name = _get_partition_name(params) def _add_axis(x: tp.Any): if isinstance(x, variablelib.VariableState): - if hasattr(x, 'sharding') and x.sharding is not None: + if _has_sharding(x) and x.sharding is not None: sharding: list[str | None] = list(x.sharding) while len(sharding) < index: sharding.append(None) sharding.insert(index, axis_name) x.sharding = tuple(sharding) # type: ignore + assert isinstance(x, variablelib.VariableState) x.add_axis(index, axis_name) return x diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 2442035f85..df299ea54d 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -319,7 +319,9 @@ def filter( return states # type: ignore[bad-return-type] @staticmethod - def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]: + def merge( + state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V] + ) -> State[K, V]: """The inverse of :meth:`split() `. ``merge`` takes one or more ``State``'s and creates @@ -352,14 +354,16 @@ def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]: The merged ``State``. """ if not states: - return state + if isinstance(state, State): + return state + return State(state) states = (state, *states) new_state: FlatState[V] = {} for state in states: - new_state.update(state.flat_state()) # type: ignore[attribute-error] # pytype is wrong here + new_state.update(traversals.flatten_mapping(state)) # type: ignore[attribute-error] # pytype is wrong here return State.from_flat_path(new_state) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index a59fdbd8fa..20366c3e1f 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1342,7 +1342,7 @@ def per_node_def(nd: graph.NodeDef | tp.Any): for sub_nd in nd.subgraphs.values(): per_node_def(sub_nd) for l in nd.leaves.values(): - if isinstance(l, graph.NodeRef) and l.index >= 0: + if isinstance(l, (graph.VariableDef, graph.NodeRef)) and l.index >= 0: global_index_mapping[l.index] = l.index return diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 26ef67745c..91d6c861d9 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -351,6 +351,14 @@ def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]: vars(obj).update(attributes) return obj + @classmethod + def from_metadata(cls, value: A, attributes: tp.Mapping[str, tp.Any]): + obj = object.__new__(cls) + vars(obj).update( + attributes, raw_value=value, _trace_state=tracers.TraceState() + ) + return obj + def copy(self: Variable[A]) -> Variable[A]: obj = object.__new__(type(self)) attributes = vars(self).copy() @@ -359,9 +367,7 @@ def copy(self: Variable[A]) -> Variable[A]: return obj def to_state(self: Variable[A]) -> VariableState[A]: - metadata = vars(self).copy() - del metadata['raw_value'] - del metadata['_trace_state'] + metadata = self.get_metadata() return VariableState(type(self), self.raw_value, **metadata) def __nnx_repr__(self): diff --git a/flax/typing.py b/flax/typing.py index 964de057db..7200095319 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -19,9 +19,9 @@ Generic, Optional, Protocol, + TypeGuard, TypeVar, Union, - runtime_checkable, ) from collections.abc import Callable, Hashable, Mapping, Sequence @@ -41,11 +41,12 @@ Shape = Sequence[int] K = TypeVar('K') -@runtime_checkable class Key(Hashable, Protocol): def __lt__(self: K, value: K, /) -> bool: ... +def is_key_like(x: Any) -> TypeGuard[Key]: + return hasattr(x, '__hash__') and hasattr(x, '__lt__') Path = str PathParts = tuple[Key, ...] diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index bfbb70465d..8983acbe7f 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -82,6 +82,17 @@ def test_unflatten(self): assert g[0] is g[2] + def test_unflatten_pure_dict(self): + a = Dict(a=1, b=nnx.Param(2)) + g = List([a, 3, a, nnx.Param(4)]) + + graphdef, state = nnx.split(g) + pure_state = state.to_pure_dict() + + g = nnx.merge(graphdef, pure_state) + + assert g[0] is g[2] + def test_unflatten_pytree(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] @@ -107,7 +118,20 @@ def test_update_dynamic(self): graphdef, state = nnx.split(g) state[0]['b'].value = 3 - nnx.graph.update(g, state) + nnx.update(g, state) + + assert g[0]['b'].value == 3 + assert g[2]['b'].value == 3 + + def test_update_from_pure_dict(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + graphdef, state = nnx.split(g) + pure_state = state.to_pure_dict() + + pure_state[0]['b'] = 3 + nnx.update(g, pure_state) assert g[0]['b'].value == 3 assert g[2]['b'].value == 3