Skip to content

Commit

Permalink
[nnx] optimize graph
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 27, 2024
1 parent 7b50ffe commit 4a65bda
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 37 deletions.
51 changes: 27 additions & 24 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,16 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
pop_key: tp.Callable[[Node, Key], Leaf]
create_empty: tp.Callable[[AuxData], Node]
clear: tp.Callable[[Node], None]
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None]

def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]):
for key, value in items:
self.set_key(node, key, value)
# def init(self, node: Node, items: tp.Iterable[tuple[Key, Leaf]]):
# for key, value in items:
# self.set_key(node, key, value)


@dataclasses.dataclass(frozen=True, slots=True)
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]


NodeImpl = tp.Union[
Expand All @@ -137,6 +138,7 @@ def register_graph_node_type(
pop_key: tp.Callable[[Node, Key], Leaf],
create_empty: tp.Callable[[AuxData], Node],
clear: tp.Callable[[Node], None],
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None],
):
if type in GRAPH_REGISTRY:
raise ValueError(f'Node type {type} is already registered.')
Expand All @@ -148,12 +150,13 @@ def register_graph_node_type(
pop_key=pop_key,
create_empty=create_empty,
clear=clear,
init=init,
)

def register_pytree_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node],
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node],
):
if type in PYTREE_REGISTRY:
raise ValueError(f'Node type {type} is already registered.')
Expand Down Expand Up @@ -202,8 +205,8 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]):
self._mapping = dict(mapping)
def __init__(self, mapping: tp.Mapping[HA, HB], no_copy: bool = False):
self._mapping = mapping if no_copy else dict(mapping)

def __contains__(self, key: object) -> bool:
return key in self._mapping
Expand Down Expand Up @@ -444,7 +447,7 @@ def _graph_flatten(
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, HashableMapping(value.get_metadata())
type(value), variable_index, HashableMapping(value._var_metadata)
)
attributes.append(LeafAttribute(key, variabledef))
else:
Expand Down Expand Up @@ -528,7 +531,7 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(nodedef.type)

def _get_children():
children: dict[Key, NodeLeaf | Node] = {}
children: list[tuple[Key, NodeLeaf | Node]] = []
state_keys: set = set(state.keys())

# for every key in attributes there are 6 possible cases:
Expand All @@ -539,28 +542,29 @@ def _get_children():
if key not in state:
# if the key is not present, create an empty value of the matching type
if type(attribute) is StaticAttribute:
children[key] = attribute.value
children.append((key, attribute.value))
elif type(attribute) is SubGraphAttribute:
# if the key is a subgraph we create an empty node
subgraphdef = attribute.value
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
children[key] = index_ref[subgraphdef.index]
children.append((key, index_ref[subgraphdef.index]))
else:
# create a node from an empty state, reasoning:
# * it's a node with no state
# * it's a node with state but only through references to already
# created nodes
substate = {}
children[key] = _graph_unflatten(
subnode = _graph_unflatten(
subgraphdef, substate, index_ref, index_ref_cache
)
children.append((key, subnode))
elif type(attribute) is LeafAttribute:
variabledef = attribute.value
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[variabledef.index]
children.append((key, index_ref[variabledef.index]))
else:
# key for a variable is missing, raise an error
raise ValueError(
Expand All @@ -587,19 +591,20 @@ def _get_children():
subgraphdef = attribute.value

if isinstance(subgraphdef, NodeRef):
children[key] = index_ref[subgraphdef.index]
children.append((key, index_ref[subgraphdef.index]))
else:
children[key] = _graph_unflatten(
subnode = _graph_unflatten(
subgraphdef, value, index_ref, index_ref_cache
)
children.append((key, subnode))

elif type(attribute) is LeafAttribute:
variabledef = attribute.value

if variabledef.index in index_ref:
# add an existing variable
assert isinstance(variabledef, NodeRef)
children[key] = index_ref[variabledef.index]
children.append((key, index_ref[variabledef.index]))
else:
# it's an unseen variable, create a new one
assert isinstance(variabledef, VariableDef)
Expand All @@ -626,7 +631,7 @@ def _get_children():
variable = variabledef.type.from_metadata(
value, variabledef.metadata
)
children[key] = variable
children.append((key, variable))
index_ref[variabledef.index] = variable
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
Expand All @@ -651,13 +656,11 @@ def _get_children():
else:
node = node_impl.create_empty(nodedef.metadata)
index_ref[nodedef.index] = node
children = _get_children()
node_impl.init(node, tuple(children.items()))
node_impl.init(node, _get_children())
else:
# if the node type does not support the creation of an empty object it means
# that it cannot reference itself, so we can create its children first
children = _get_children()
node = node_impl.unflatten(tuple(children.items()), nodedef.metadata)
node = node_impl.unflatten(_get_children(), nodedef.metadata)

return node

Expand Down Expand Up @@ -816,7 +819,7 @@ def split(
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True)
)

return graphdef, *states
Expand Down Expand Up @@ -1006,7 +1009,7 @@ def split(
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True)
)

self.flatten_end(ref_index)
Expand Down Expand Up @@ -1787,7 +1790,7 @@ def is_pytree_node(x: tp.Any) -> bool:
elif isinstance(x, Variable):
return False
# known pytree types
elif isinstance(x, (VariableState, State)):
elif type(x) is VariableState or type(x) is State:
return True
else:
return not jax.tree_util.all_leaves((x,))
Expand Down
21 changes: 11 additions & 10 deletions flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from flax.nnx import graph
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key
from flax import errors

G = tp.TypeVar('G', bound='Object')
Expand Down Expand Up @@ -109,10 +108,11 @@ def __init_subclass__(cls) -> None:
graph.register_graph_node_type(
type=cls,
flatten=cls._graph_node_flatten,
set_key=cls._graph_node_set_key,
pop_key=cls._graph_node_pop_key,
set_key=cls._graph_node_set_key, # type: ignore
pop_key=cls._graph_node_pop_key, # type: ignore
create_empty=cls._graph_node_create_empty,
clear=cls._graph_node_clear,
init=cls._graph_node_init, # type: ignore
)

if not tp.TYPE_CHECKING:
Expand Down Expand Up @@ -189,14 +189,12 @@ def __treescope_repr__(self, path, subtree_renderer):

# Graph Definition
def _graph_node_flatten(self):
nodes = sorted(
(key, value)
for key, value in vars(self).items()
if key != '_object__state'
)
nodes = vars(self).copy()
del nodes['_object__state']
nodes = sorted(nodes.items())
return nodes, (type(self), self._object__state._initializing)

def _graph_node_set_key(self, key: Key, value: tp.Any):
def _graph_node_set_key(self, key: str, value: tp.Any):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
elif (
Expand All @@ -208,7 +206,7 @@ def _graph_node_set_key(self, key: Key, value: tp.Any):
else:
setattr(self, key, value)

def _graph_node_pop_key(self, key: Key):
def _graph_node_pop_key(self, key: str):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
return vars(self).pop(key)
Expand All @@ -225,3 +223,6 @@ def _graph_node_clear(self):
module_vars = vars(self)
module_vars.clear()
module_vars['_object__state'] = module_state

# Bulk-initialize this object's attributes from (name, value) pairs.
# Passed as the `init` callback to `graph.register_graph_node_type`.
# Writes straight into the instance dict via `vars(self).update(...)`, so it
# does not go through `setattr` or the key checks in `_graph_node_set_key`.
def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
vars(self).update(attributes)
5 changes: 2 additions & 3 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def copy_from(self, other: Variable[A]) -> None:
def update_from_state(self, variable_state: VariableState[A]):
vars_self = vars(self)
vars_self['raw_value'] = variable_state.value
vars_self['_var_metadata'] = variable_state.get_metadata().copy()
vars_self['_var_metadata'] = variable_state._var_metadata.copy()

@property
def value(self) -> A:
Expand Down Expand Up @@ -308,8 +308,7 @@ def copy(self: Variable[A]) -> Variable[A]:
return obj

def to_state(self: Variable[A]) -> VariableState[A]:
metadata = self.get_metadata()
return VariableState(type(self), self.raw_value, **metadata)
return VariableState(type(self), self.raw_value, **self._var_metadata)

def __nnx_repr__(self):
yield reprlib.Object(type=type(self))
Expand Down

0 comments on commit 4a65bda

Please sign in to comment.