Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] support pure dicts #4352

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion flax/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,21 @@ def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
)
return tuple(map(to_predicate, filters))


class HasTag(tp.Protocol):
  """Structural type for any object that exposes a string ``tag`` attribute.

  Used purely for static typing: it is the narrowed type reported by the
  ``_has_tag`` type guard.
  """

  # The tag carried by the object.
  tag: str


def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]:
  """Report whether ``x`` has a ``tag`` attribute.

  The return type is a ``TypeGuard`` so that static type checkers treat ``x``
  as :class:`HasTag` after a successful check. At runtime this is a plain
  ``hasattr`` test; the type of the attribute's value is not inspected.

  Args:
    x: any object.

  Returns:
    ``True`` if ``x`` has an attribute named ``tag``, ``False`` otherwise.
  """
  return hasattr(x, 'tag')


@dataclasses.dataclass(frozen=True)
class WithTag:
tag: str

def __call__(self, path: PathParts, x: tp.Any):
return hasattr(x, 'tag') and x.tag == self.tag
return _has_tag(x) and x.tag == self.tag

def __repr__(self):
return f'WithTag({self.tag!r})'
Expand Down
158 changes: 91 additions & 67 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from flax.nnx.statelib import FlatState, State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts
from flax.typing import Key, PathParts, is_key_like

A = tp.TypeVar('A')
B = tp.TypeVar('B')
Expand All @@ -43,6 +43,7 @@

HA = tp.TypeVar('HA', bound=tp.Hashable)
HB = tp.TypeVar('HB', bound=tp.Hashable)
KeyT = tp.TypeVar('KeyT', bound=Key)

Index = int
Names = tp.Sequence[int]
Expand Down Expand Up @@ -241,6 +242,35 @@ def __treescope_repr__(self, path, subtree_renderer):

jax.tree_util.register_static(NodeRef)

@dataclasses.dataclass(frozen=True, repr=False)
class VariableDef(reprlib.Representable):
  """Immutable description of a Variable leaf stored in a graph definition.

  The dataclass is frozen and its generated ``__repr__`` is disabled
  (``repr=False``); representation is provided by the ``__nnx_repr__`` and
  ``__treescope_repr__`` hooks below.
  """

  # The concrete Variable subclass this leaf was created from.
  type: type[Variable]
  # Integer index identifying the variable within the graph definition.
  index: int
  # Frozen mapping holding the variable's metadata.
  metadata: FrozenDict[str, tp.Any]

  def __nnx_repr__(self):
    """Yield the pieces consumed by the nnx ``reprlib`` machinery.

    Yields the object header first, then one ``Attr`` per field. The ``type``
    field is rendered by its ``__name__`` and ``metadata`` is wrapped in
    ``reprlib.PrettyMapping``.
    """
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('type', self.type.__name__)
    yield reprlib.Attr('index', self.index)
    yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))

  def __treescope_repr__(self, path, subtree_renderer):
    """Render this object through treescope as a constructor call.

    Args:
      path: forwarded unchanged to treescope.
      subtree_renderer: forwarded unchanged to treescope.

    Returns:
      Whatever ``treescope.repr_lib.render_object_constructor`` returns.
    """
    # Imported lazily so treescope is only needed when this hook is invoked.
    import treescope  # type: ignore[import-not-found,import-untyped]

    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={
        'type': self.type,
        'index': self.index,
        'metadata': self.metadata,
      },
      path=path,
      subtree_renderer=subtree_renderer,
    )


jax.tree_util.register_static(VariableDef)


@dataclasses.dataclass(frozen=True, repr=False)
class NodeDef(GraphDef[Node], reprlib.Representable):
Expand All @@ -253,7 +283,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable):
attributes: tuple[Key, ...]
subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: _HashableMapping[Key, tp.Any]
leaves: _HashableMapping[Key, NodeRef[tp.Any] | None]
leaves: _HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
metadata: tp.Any
index_mapping: FrozenDict[Index, Index] | None

Expand All @@ -265,7 +295,7 @@ def create(
attributes: tuple[Key, ...],
subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]],
static_fields: tp.Iterable[tuple[Key, tp.Any]],
leaves: tp.Iterable[tuple[Key, NodeRef[tp.Any] | None]],
leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]],
metadata: tp.Any,
index_mapping: tp.Mapping[Index, Index] | None,
):
Expand Down Expand Up @@ -380,7 +410,7 @@ def _graph_flatten(

subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
static_fields: list[tuple[Key, tp.Any]] = []
leaves: list[tuple[Key, NodeRef | None]] = []
leaves: list[tuple[Key, VariableDef | NodeRef]] = []

values, metadata = node_impl.flatten(node)
for key, value in values:
Expand All @@ -393,10 +423,10 @@ def _graph_flatten(
else:
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
leaves.append((key, NodeRef(type(value), variable_index)))
elif is_state_leaf(value):
flat_state[(*path, key)] = value
leaves.append((key, None))
variabledef = VariableDef(
type(value), variable_index, FrozenDict(value.get_metadata())
)
leaves.append((key, variabledef))
else:
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
Expand All @@ -420,7 +450,7 @@ def _graph_flatten(

def unflatten(
graphdef: GraphDef[Node],
state: GraphState,
state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]],
/,
*,
index_ref: dict[Index, tp.Any] | None = None,
Expand All @@ -441,17 +471,17 @@ def unflatten(
existing graph nodes are mutated to have the new content/topology
specified by the graphdef.
"""
if isinstance(state, State):
state = state.raw_mapping # type: ignore
if index_ref is None:
index_ref = {}
assert isinstance(graphdef, (NodeDef, NodeRef))
node = _graph_unflatten(
graphdef, state.raw_mapping, index_ref, index_ref_cache
)
node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache)
return node

def _graph_unflatten(
nodedef: NodeDef[Node] | NodeRef[Node],
state: tp.Mapping[Key, StateLeaf | tp.Mapping[Key, tp.Any]],
state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]],
index_ref: dict[Index, tp.Any],
index_ref_cache: dict[Index, tp.Any] | None,
) -> Node:
Expand Down Expand Up @@ -480,7 +510,7 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(nodedef.type)

def _get_children():
children: dict[Key, StateLeaf | Node] = {}
children: dict[Key, NodeLeaf | Node] = {}

# NOTE: we could allow adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
Expand All @@ -491,13 +521,13 @@ def _get_children():
# - (3) the key can be a subgraph, a leaf, or a static attribute
for key in nodedef.attributes:
if key not in state:
# TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
# if key is not present, create an empty type
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key in nodedef.subgraphs:
# if the key is a subgraph we create an empty node
subgraphdef = nodedef.subgraphs[key]
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
children[key] = index_ref[subgraphdef.index]
Expand All @@ -511,10 +541,10 @@ def _get_children():
subgraphdef, substate, index_ref, index_ref_cache
)
elif key in nodedef.leaves:
noderef = nodedef.leaves[key]
if noderef is not None and noderef.index in index_ref:
variabledef = nodedef.leaves[key]
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[noderef.index]
children[key] = index_ref[variabledef.index]
else:
# key for a variable is missing, raise an error
raise ValueError(
Expand Down Expand Up @@ -546,41 +576,40 @@ def _get_children():
)

elif key in nodedef.leaves:
if not is_state_leaf(value):
raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')

noderef = nodedef.leaves[key]

if noderef is None:
# if the leaf is None, it means that the value was originally
# a non-VariableState leaf, however we allow providing a
# VariableState presumably created by modifying the State
if isinstance(value, VariableState):
value = value.to_variable()
children[key] = value
elif noderef.index in index_ref:
variabledef = nodedef.leaves[key]

if variabledef.index in index_ref:
# add an existing variable
children[key] = index_ref[noderef.index]
assert isinstance(variabledef, NodeRef)
children[key] = index_ref[variabledef.index]
else:
# it's an unseen variable, create a new one
if not isinstance(value, VariableState):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(value)}.'
)
assert isinstance(variabledef, VariableDef)
# when idxmap is present, check if the Variable exists there
# and update existing variables if it does
if index_ref_cache is not None and noderef.index in index_ref_cache:
variable = index_ref_cache[noderef.index]
if (
index_ref_cache is not None
and variabledef.index in index_ref_cache
):
# if variable exists, update it
variable = index_ref_cache[variabledef.index]
if not isinstance(variable, Variable):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
variable.update_from_state(value)
if isinstance(value, VariableState):
variable.update_from_state(value)
else:
variable.raw_value = value
else: # if it doesn't, create a new variable
assert isinstance(value, VariableState)
variable = value.to_variable()
if isinstance(value, VariableState):
variable = value.to_variable()
else:
variable = variabledef.type.from_metadata(
value, variabledef.metadata
)
children[key] = variable
index_ref[noderef.index] = variable
index_ref[variabledef.index] = variable
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

Expand Down Expand Up @@ -676,7 +705,7 @@ def _graph_pop(
pass


def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]):
if not is_node(node):
raise RuntimeError(f'Unsupported type: {type(node)}')

Expand All @@ -703,26 +732,19 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
if is_state_leaf(value):
raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
_graph_update_dynamic(current_value, value)
elif isinstance(value, VariableState):
else:
# case 3: state leaf is being updated
if not isinstance(current_value, Variable):
raise ValueError(
f'Trying to update a non-Variable attribute {key!r} with a Variable: '
f'{value!r}'
)
current_value.update_from_state(value)
elif is_state_leaf(value):
# case 4: state field is being updated
if isinstance(node_impl, PytreeNodeImpl):
raise ValueError(
f'Cannot set key {key!r} on immutable node of '
f'type {type(node).__name__}'
)
node_impl.set_key(node, key, value)
else:
raise ValueError(
f'Unsupported update type: {type(value)} for key {key!r}'
)
if isinstance(value, VariableState):
# updated from VariableState
current_value.update_from_state(value)
else:
# updated from raw value
current_value.raw_value = value

# --------------------------------------------------------
# UpdateContext
Expand Down Expand Up @@ -1251,12 +1273,11 @@ def split(
states = _split_state(state, filters)
return graphdef, *states


def merge(
graphdef: GraphDef[A],
state: GraphState,
state: tp.Mapping[KeyT, tp.Any],
/,
*states: GraphState,
*states: tp.Mapping[KeyT, tp.Any],
) -> A:
"""The inverse of :func:`split`.

Expand Down Expand Up @@ -1293,13 +1314,15 @@ def merge(
Returns:
The merged :class:`Module`.
"""
state = GraphState.merge(state, *states)
state = State.merge(state, *states)
node = unflatten(graphdef, state)
return node


def update(node, state: State, /, *states: State) -> None:
"""Update the given graph node with a new :class:`State` in-place.
def update(
node, state: tp.Mapping[KeyT, tp.Any], /, *states: tp.Mapping[KeyT, tp.Any]
) -> None:
"""Update the given graph node with a new state(s) in-place.

Example usage::

Expand All @@ -1325,9 +1348,10 @@ def update(node, state: State, /, *states: State) -> None:
*states: Additional :class:`State` objects.
"""
if states:
state = GraphState.merge(state, *states)

_graph_update_dynamic(node, state.raw_mapping)
state = State.merge(state, *states)
if isinstance(state, State):
state = state.raw_mapping
_graph_update_dynamic(node, state)

def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]:
for path, value in iter_graph(node):
Expand Down Expand Up @@ -1741,7 +1765,7 @@ def _key_path_to_key(key: tp.Any) -> Key:
elif isinstance(
key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
):
if not isinstance(key.key, Key):
if not is_key_like(key.key):
raise ValueError(
f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
)
Expand Down
11 changes: 9 additions & 2 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,26 @@
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
PARTITION_NAME = 'partition_name'

class HasSharding(tp.Protocol):
sharding: tuple[str | None, ...] | None

def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A:

def _has_sharding(x: tp.Any) -> tp.TypeGuard[HasSharding]:
  """Report whether ``x`` has a ``sharding`` attribute that is not ``None``.

  The return type is a ``TypeGuard`` so that static type checkers treat ``x``
  as :class:`HasSharding` after a successful check.

  Args:
    x: any object.

  Returns:
    ``True`` only if ``x`` has a ``sharding`` attribute and its value is not
    ``None``; ``False`` otherwise.
  """
  return hasattr(x, 'sharding') and x.sharding is not None

def add_axis(tree: A, index: int, params: tp.Mapping) -> A:
axis_name = _get_partition_name(params)

def _add_axis(x: tp.Any):
if isinstance(x, variablelib.VariableState):
if hasattr(x, 'sharding') and x.sharding is not None:
if _has_sharding(x) and x.sharding is not None:
sharding: list[str | None] = list(x.sharding)
while len(sharding) < index:
sharding.append(None)
sharding.insert(index, axis_name)
x.sharding = tuple(sharding) # type: ignore

assert isinstance(x, variablelib.VariableState)
x.add_axis(index, axis_name)
return x

Expand Down
Loading
Loading