Skip to content

Commit

Permalink
simplify graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 26, 2024
1 parent 9eb0a61 commit e7dba28
Show file tree
Hide file tree
Showing 11 changed files with 301 additions and 456 deletions.
5 changes: 3 additions & 2 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from .graph import PureState as PureState
from .object import Object as Object
from .helpers import Dict as Dict
from .helpers import List as List
from .helpers import Sequential as Sequential
from .helpers import TrainState as TrainState
from .module import M as M
Expand Down Expand Up @@ -153,4 +152,6 @@
from .variables import VariableMetadata as VariableMetadata
from .variables import with_metadata as with_metadata
from .visualization import display as display
from .extract import to_tree, from_tree, TreeNode
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
69 changes: 16 additions & 53 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,58 +237,25 @@ class GraphDefState(struct.PyTreeNode):
graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False)
state: graph.GraphState = struct.field(pytree_node=True)

class StateOnly(struct.PyTreeNode):
  """Pytree node that carries a state but no graphdef.

  It exposes the same ``state`` / ``graphdef`` attribute names as
  ``GraphDefState``; reading ``graphdef`` always fails because there is
  nothing to return.
  """

  # The only payload; registered as a pytree child.
  state: graph.GraphState = struct.field(pytree_node=True)

  @property
  def graphdef(self) -> graph.GraphDef[tp.Any]:
    """Always raises ``ValueError``: this wrapper holds no graphdef."""
    raise ValueError('No graphdef available in StateOnly')


@dataclasses.dataclass(frozen=True)
class StateSequence(tp.Sequence[graph.GraphState]):
  """Read-only sequence view over the ``state`` of each wrapped entry.

  Integer indexing and iteration return the ``.state`` attribute of the
  underlying entries; slicing returns another ``StateSequence`` built from
  the sliced entries.
  """

  # Entries whose ``.state`` is exposed by this view.
  graphdef_states: tuple[GraphDefState | StateOnly, ...]

  def __len__(self):
    return len(self.graphdef_states)

  def __iter__(self):
    return (entry.state for entry in self.graphdef_states)

  @tp.overload
  def __getitem__(self, index: int) -> graph.GraphState: ...
  @tp.overload
  def __getitem__(self, index: slice) -> 'StateSequence': ...
  def __getitem__(self, index):
    """Return one state (int index) or a sub-view (slice).

    Raises:
      TypeError: if ``index`` is neither an ``int`` nor a ``slice``.
    """
    entries = self.graphdef_states
    if isinstance(index, int):
      return entries[index].state
    if isinstance(index, slice):
      return StateSequence(entries[index])
    raise TypeError(f'Invalid index type: {type(index)}')


class TreeNode(struct.PyTreeNode):
metatata: tp.Any = struct.field(pytree_node=False)
graphdef_states: tuple[GraphDefState | StateOnly, ...] = struct.field(
pytree_node=True
)
class NodeStates(struct.PyTreeNode):
_graphdef: graph.GraphDef[tp.Any] | None
states: tuple[graph.GraphState, ...]
metadata: tp.Any = struct.field(pytree_node=False)

@property
def graphdef(self) -> graph.GraphDef[tp.Any]:
return self.graphdef_states[0].graphdef
if self._graphdef is None:
raise ValueError('No graphdef available')
return self._graphdef

@property
def state(self) -> graph.GraphState:
if len(self.graphdef_states) != 1:
if len(self.states) != 1:
raise ValueError(
f'Expected exactly one GraphDefState, got {len(self.graphdef_states)}'
f'Expected exactly one GraphDefState, got {len(self.states)}'
)
return self.graphdef_states[0].state

@property
def states(self) -> tp.Sequence[graph.GraphState]:
return StateSequence(self.graphdef_states)
return self.states[0]

@classmethod
def from_split(
Expand All @@ -299,15 +266,11 @@ def from_split(
*states: graph.GraphState,
metadata: tp.Any = None,
):
states = (state, *states)
return cls(
metadata, tuple(GraphDefState(graphdef, state) for state in states)
)
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

@classmethod
def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
states = (state, *states)
return cls(None, tuple(StateOnly(state) for state in states))
return cls(_graphdef=None, states=(state, *states), metadata=None)

@classmethod
def from_prefixes(
Expand All @@ -317,13 +280,13 @@ def from_prefixes(
*,
metadata: tp.Any = None,
):
return cls(metadata, tuple(prefixes))
return cls(_graphdef=None, states=tuple(prefixes), metadata=metadata)


def default_split_fn(
ctx: graph.SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf
) -> tp.Any:
return TreeNode.from_split(*ctx.split(leaf))
return NodeStates.from_split(*ctx.split(leaf))


def to_tree(
Expand Down Expand Up @@ -370,13 +333,13 @@ def to_tree(
def merge_tree_node(
ctx: graph.MergeContext, path: KeyPath, prefix: Prefix, leaf: Leaf
) -> tp.Any:
if not isinstance(leaf, TreeNode):
if not isinstance(leaf, NodeStates):
raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}')
return ctx.merge(leaf.graphdef, *leaf.states)


def is_tree_node(x):
return isinstance(x, TreeNode)
return isinstance(x, NodeStates)


def from_tree(
Expand Down
147 changes: 10 additions & 137 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import contextlib
import dataclasses
import enum
import functools
import threading
import typing as tp
Expand Down Expand Up @@ -51,9 +50,8 @@
AuxData = tp.TypeVar('AuxData')

StateLeaf = VariableState[tp.Any]
NodeLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
Expand All @@ -64,50 +62,29 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, Variable)


class _HashById(tp.Hashable, tp.Generic[A]):
  """Wrapper that hashes and compares by the identity of the wrapped object.

  RefMap wraps its keys in this class so that lookups are driven by object
  id instead of the key's own ``__hash__`` / ``__eq__``.
  """

  __slots__ = ('_value',)

  def __init__(self, value: A):
    self._value = value

  def __eq__(self, other: tp.Any) -> bool:
    # Only another wrapper around the very same object counts as equal.
    if not isinstance(other, _HashById):
      return False
    return self._value is other._value

  def __hash__(self) -> int:
    return id(self._value)

  @property
  def value(self) -> A:
    """The wrapped object."""
    return self._value

class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]):
"""A mapping that uses object id as the hash for the keys."""

def __init__(
self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), /
):
self._mapping: dict[_HashById[A], B] = {}
self._mapping: dict[int, tuple[A, B]] = {}
self.update(mapping)

def __getitem__(self, key: A) -> B:
return self._mapping[_HashById(key)]
return self._mapping[id(key)][1]

def __contains__(self, key: object) -> bool:
return _HashById(key) in self._mapping
return id(key) in self._mapping

def __setitem__(self, key: A, value: B):
self._mapping[_HashById(key)] = value
self._mapping[id(key)] = (key, value)

def __delitem__(self, key: A):
del self._mapping[_HashById(key)]
del self._mapping[id(key)]

def __iter__(self) -> tp.Iterator[A]:
return (x.value for x in self._mapping)
return (key for key, _ in self._mapping.values())

def __len__(self) -> int:
return len(self._mapping)
Expand Down Expand Up @@ -637,7 +614,7 @@ def graph_pop(
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates)
flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates)
_graph_pop(node, id_to_index, path_parts, flat_states, predicates)
return tuple(
GraphState.from_flat_path(flat_state) for flat_state in flat_states
Expand All @@ -648,7 +625,7 @@ def _graph_pop(
node: tp.Any,
id_to_index: dict[int, Index],
path_parts: PathParts,
flat_states: tuple[GraphFlatState, ...],
flat_states: tuple[FlatState[StateLeaf], ...],
predicates: tuple[filterlib.Predicate, ...],
) -> None:
if not is_node(node):
Expand Down Expand Up @@ -743,110 +720,6 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
f'Unsupported update type: {type(value)} for key {key!r}'
)


class _StaticModuleStatus(enum.Enum):
  """How a node was encountered during a static graph update."""

  # The node was attached as a new subgraph.
  NEW = enum.auto()
  # The node matched an existing subgraph that is being updated.
  UPDATED = enum.auto()


# TODO(cgarciae): remove once transform init are reimplemented
def update_from(node: Node, updates: Node) -> None:
  """Update ``node`` in place from ``updates``.

  The static structure is synchronized first with ``graph_update_static``;
  afterwards the state obtained by splitting ``updates`` is applied to
  ``node`` with ``update``.
  """
  graph_update_static(node, updates)
  _graphdef, new_state = split(updates)
  update(node, new_state)


# TODO(cgarciae): remove once transform init are reimplemented
def graph_update_static(node: Node, updates: Node) -> None:
  """Entry point for the recursive static update of ``node`` from ``updates``.

  Starts the traversal with an empty visit cache, the ``UPDATED`` status and
  an empty path.
  """
  visited: dict[int, _StaticModuleStatus] = {}
  _graph_update_static(
    node=node,
    updates=updates,
    cache=visited,
    status=_StaticModuleStatus.UPDATED,
    path=(),
  )


def _graph_update_static(
  node: Node,
  updates: Node,
  cache: dict[int, _StaticModuleStatus],
  status: _StaticModuleStatus,
  path: PathParts,
) -> None:
  """Recursively copy the static structure of ``updates`` onto ``node``.

  Walks the entries of ``updates``: state leaves are skipped, sub-nodes that
  already exist on ``node`` are updated recursively, sub-nodes that do not
  exist yet are attached by reference, and every other value is treated as a
  static field and assigned with ``set_key``.

  Args:
    node: the node being modified in place.
    updates: the node providing the new structure; must have exactly the
      same type as ``node``.
    cache: maps ``id`` of every ``updates`` node seen so far to the status
      under which it was first seen. Used to stop on shared references and
      to detect a reference used both as a new and as an updated node.
    status: whether ``updates`` is being visited as a new or an updated node.
    path: keys leading to ``node``; only used in error messages.

  Raises:
    ValueError: if the two nodes differ in type, if ``node`` is not a
      supported node type, if the same reference appears with conflicting
      statuses, or if the update would add or change a key on an immutable
      ``PytreeNodeImpl`` node.
  """
  if type(node) is not type(updates):
    raise ValueError(
      f'Trying to update a node with a different type: '
      f'expected {type(node).__name__!r}, '
      f'but got {type(updates).__name__!r}'
    )
  if not is_node(node):
    raise ValueError(f'Unsupported node type: {type(node)}')

  updates_id = id(updates)
  if updates_id in cache:
    # Already visited: only legal if it was visited under the same status.
    if cache[updates_id] != status:
      str_path = '/'.join(str(p) for p in path)
      if status is _StaticModuleStatus.NEW:
        raise ValueError(
          f'Trying to add a new node at path {str_path!r} but a'
          ' node with the same reference has been updated'
        )
      else:
        raise ValueError(
          f'Trying to update a node at path {str_path!r} but a new'
          ' node with the same reference has been added'
        )
    return

  cache[updates_id] = status

  node_impl = get_node_impl(node)
  node_dict = node_impl.node_dict(node)
  updates_dict = node_impl.node_dict(updates)
  for name, value_updates in updates_dict.items():
    # case 1: trying to update a Variable, skip
    if is_state_leaf(value_updates):
      continue
    elif is_node(value_updates):
      # case 2: updating an existing subgraph
      if name in node_dict:
        _graph_update_static(
          node_dict[name],
          value_updates,
          cache,
          _StaticModuleStatus.UPDATED,
          (*path, name),
        )
      else:
        # case 3: adding a new subgraph
        if isinstance(node_impl, PytreeNodeImpl):
          raise ValueError(
            f'Cannot set key {name!r} on immutable node of '
            f'type {type(node).__name__}'
          )

        # check if the subgraph is already in the cache
        if id(value_updates) in cache:
          # if it's in the cache, check that its status is NEW
          if cache[id(value_updates)] is not _StaticModuleStatus.NEW:
            raise ValueError(
              f'Trying to add a new node at path {name!r} but a '
              'node with the same reference has been updated'
            )
        else:
          cache[id(value_updates)] = _StaticModuleStatus.NEW

        node_impl.set_key(node, name, value_updates)
    else:  # static field
      if isinstance(node_impl, PytreeNodeImpl):
        if name in node_dict:
          if node_dict[name] == value_updates:
            # if the value is the same, skip
            continue
          # trying to change an existing static field of an immutable node
          raise ValueError(
            f'Cannot update key {name!r} on immutable node of '
            f'type {type(node).__name__}. Current value is {node_dict[name]!r}, '
            f'new value is {value_updates!r}.'
          )
        # The key does not exist yet: there is no current value to report.
        # Looking it up for the message would raise KeyError and hide the
        # real problem, so report it as a rejected new key instead.
        raise ValueError(
          f'Cannot set key {name!r} on immutable node of '
          f'type {type(node).__name__}'
        )

      node_impl.set_key(node, name, value_updates)


# --------------------------------------------------------
# UpdateContext
# --------------------------------------------------------
Expand Down Expand Up @@ -1598,7 +1471,7 @@ def pop(
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates)
flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates)
_graph_pop(
node=node,
id_to_index=id_to_index,
Expand Down
Loading

0 comments on commit e7dba28

Please sign in to comment.