Skip to content

Commit

Permalink
Merge pull request #4376 from google:nnx-benchmark
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 695939844
  • Loading branch information
Flax Authors committed Nov 13, 2024
2 parents 86ff7af + 347211b commit 480a196
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 34 deletions.
31 changes: 15 additions & 16 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import numpy as np
import typing_extensions as tpe

from flax.core.frozen_dict import FrozenDict
from flax.nnx import filterlib, reprlib
from flax.nnx.proxy_caller import (
ApplyCaller,
Expand Down Expand Up @@ -183,7 +182,7 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
return _node_impl_for_type[x]


class _HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]):
self._mapping = dict(mapping)

Expand All @@ -204,7 +203,7 @@ def __hash__(self) -> int:

def __eq__(self, other: tp.Any) -> bool:
return (
isinstance(other, _HashableMapping) and self._mapping == other._mapping
isinstance(other, HashableMapping) and self._mapping == other._mapping
)

def __repr__(self) -> str:
Expand Down Expand Up @@ -246,7 +245,7 @@ def __treescope_repr__(self, path, subtree_renderer):
class VariableDef(reprlib.Representable):
type: type[Variable]
index: int
metadata: FrozenDict[str, tp.Any]
metadata: HashableMapping[str, tp.Any]

def __nnx_repr__(self):
yield reprlib.Object(type=type(self))
Expand All @@ -272,7 +271,7 @@ def __treescope_repr__(self, path, subtree_renderer):
jax.tree_util.register_static(VariableDef)


@dataclasses.dataclass(frozen=True, repr=False)
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(GraphDef[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
:class:`Module`. A ``GraphDef`` can be generated by either
Expand All @@ -281,11 +280,11 @@ class NodeDef(GraphDef[Node], reprlib.Representable):
type: tp.Type[Node]
index: int
attributes: tuple[Key, ...]
subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: _HashableMapping[Key, tp.Any]
leaves: _HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: HashableMapping[Key, tp.Any]
leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
metadata: tp.Any
index_mapping: FrozenDict[Index, Index] | None
index_mapping: HashableMapping[Index, Index] | None

@classmethod
def create(
Expand All @@ -303,11 +302,11 @@ def create(
type=type,
index=index,
attributes=attributes,
subgraphs=_HashableMapping(subgraphs),
static_fields=_HashableMapping(static_fields),
leaves=_HashableMapping(leaves),
subgraphs=HashableMapping(subgraphs),
static_fields=HashableMapping(static_fields),
leaves=HashableMapping(leaves),
metadata=metadata,
index_mapping=FrozenDict(index_mapping)
index_mapping=HashableMapping(index_mapping)
if index_mapping is not None
else None,
)
Expand Down Expand Up @@ -424,7 +423,7 @@ def _graph_flatten(
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, FrozenDict(value.get_metadata())
type(value), variable_index, HashableMapping(value.get_metadata())
)
leaves.append((key, variabledef))
else:
Expand Down Expand Up @@ -794,7 +793,7 @@ def split(
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=FrozenDict(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index)
)

return graphdef, *states
Expand Down Expand Up @@ -984,7 +983,7 @@ def split(
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=FrozenDict(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index)
)

self.flatten_end(ref_index)
Expand Down
15 changes: 8 additions & 7 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx import (
extract,
filterlib,
Expand Down Expand Up @@ -428,7 +427,7 @@ def _custom_vjp_split_fn(
nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)

def _extract_index_mappings(x, *, index_mappings: deque[FrozenDict]):
def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]):
if isinstance(x, graph.NodeDef):
assert x.index_mapping is not None
index_mappings.append(x.index_mapping)
Expand Down Expand Up @@ -466,7 +465,9 @@ def __call__(self, *pure_args):
(args_out, out), ctxtag=self.ctxtag
)
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[FrozenDict] = extract.get_broadcast_state(self.ctxtag)
index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state(
self.ctxtag
)

pure_args_out, pure_out = jax.tree.map(
functools.partial(_extract_index_mappings, index_mappings=index_mappings),
Expand Down Expand Up @@ -519,8 +520,8 @@ def __call__(self, *pure_args):

if update_context_active:
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[FrozenDict] = extract.get_broadcast_state(
self.ctxtag
index_mappings: deque[graph.HashableMapping] = (
extract.get_broadcast_state(self.ctxtag)
)
pure_args_out, pure_out = jax.tree.map(
functools.partial(
Expand Down Expand Up @@ -631,7 +632,7 @@ def __call__(
for i, x in enumerate(tree_node_args)
if i not in self.jax_nondiff_argnums
)
index_mappings: deque[FrozenDict] = deque()
index_mappings: deque[graph.HashableMapping] = deque()
with extract.broadcast_state(self.ctxtag, index_mappings):
if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
raise ValueError()
Expand Down Expand Up @@ -663,7 +664,7 @@ def __call__(
# insert index_mappings
def _insert_index_mappings(x):
if isinstance(x, graph.NodeDef):
index_mapping: FrozenDict = index_mappings.popleft()
index_mapping: graph.HashableMapping = index_mappings.popleft()
return dataclasses.replace(x, index_mapping=index_mapping)
return x

Expand Down
16 changes: 9 additions & 7 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def check_carry_same_references(key_path, arg, out):

def _extract_index_mappings(
pure_carry_arg_out,
carry_index_mappings: list[FrozenDict[int, int]],
carry_index_mappings: list[graph.HashableMapping[int, int]],
/,
):
def extract_index_mappings(x):
Expand All @@ -675,7 +675,7 @@ def extract_index_mappings(x):

def _insert_index_mappings(
pure_carry_arg_out,
carry_index_mappings: deque[FrozenDict[int, int]],
carry_index_mappings: deque[graph.HashableMapping[int, int]],
/,
):
def insert_index_mappings(x):
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def __call__(

# next we have to remove all the index_mappings from the NodeDefs
# in the carry outputs because they are not present in the inputs
carry_index_mappings: list[FrozenDict[int, int]] = []
carry_index_mappings: list[graph.HashableMapping[int, int]] = []
pure_carry_arg_out = _extract_index_mappings(
pure_carry_arg_out, carry_index_mappings
)
Expand Down Expand Up @@ -1347,10 +1347,12 @@ def per_node_def(nd: graph.NodeDef | tp.Any):
return

per_node_def(ns._graphdef)
return dataclasses.replace(ns, _graphdef=dataclasses.replace(
ns._graphdef,
index_mapping=FrozenDict(global_index_mapping)
))
return dataclasses.replace(
ns,
_graphdef=dataclasses.replace(
ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping)
),
)

return jax.tree.map(per_node_state, tree,
is_leaf=lambda x: isinstance(x, extract.NodeStates))
Expand Down
14 changes: 10 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 480a196

Please sign in to comment.