Skip to content

Commit

Permalink
simplify graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 17, 2024
1 parent 9eb0a61 commit 4f62541
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 270 deletions.
1 change: 0 additions & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from .graph import PureState as PureState
from .object import Object as Object
from .helpers import Dict as Dict
from .helpers import List as List
from .helpers import Sequential as Sequential
from .helpers import TrainState as TrainState
from .module import M as M
Expand Down
105 changes: 0 additions & 105 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import contextlib
import dataclasses
import enum
import functools
import threading
import typing as tp
Expand Down Expand Up @@ -743,110 +742,6 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
f'Unsupported update type: {type(value)} for key {key!r}'
)


class _StaticModuleStatus(enum.Enum):
NEW = enum.auto()
UPDATED = enum.auto()


# TODO(cgarciae): remove once transform init are reimplemented
def update_from(node: Node, updates: Node) -> None:
graph_update_static(node, updates)
_, state = split(updates)
update(node, state)


# TODO(cgarciae): remove once transform init are reimplemented
def graph_update_static(node: Node, updates: Node) -> None:
cache: dict[int, _StaticModuleStatus] = {}
_graph_update_static(node, updates, cache, _StaticModuleStatus.UPDATED, ())


def _graph_update_static(
node: Node,
updates: Node,
cache: dict[int, _StaticModuleStatus],
status: _StaticModuleStatus,
path: PathParts,
) -> None:
if type(node) != type(updates):
raise ValueError(
f'Trying to update a node with a different type: '
f'expected {type(node).__name__!r}, '
f'but got {type(updates).__name__!r}'
)
if not is_node(node):
raise ValueError(f'Unsupported node type: {type(node)}')

if id(updates) in cache:
if cache[id(updates)] != status:
str_path = '/'.join(str(p) for p in path)
if status is _StaticModuleStatus.NEW:
raise ValueError(
f'Trying to add a new node at path {str_path!r} but a'
' node with the same reference has been updated'
)
else:
raise ValueError(
f'Trying to update a node at path {str_path!r} but a new'
' node with the same reference has been added'
)
return

cache[id(updates)] = status

node_impl = get_node_impl(node)
node_dict = node_impl.node_dict(node)
updates_dict = node_impl.node_dict(updates)
for name, value_updates in updates_dict.items():
# case 1: trying to update a Variable, skip
if is_state_leaf(value_updates):
continue
elif is_node(value_updates):
# case 2: updating an existing subgraph
if name in node_dict:
_graph_update_static(
node_dict[name],
value_updates,
cache,
_StaticModuleStatus.UPDATED,
(*path, name),
)
else:
# case 3: adding a new subgraph
if isinstance(node_impl, PytreeNodeImpl):
raise ValueError(
f'Cannot set key {name!r} on immutable node of '
f'type {type(node).__name__}'
)

# check if the subgraph is already in the cache
if id(value_updates) in cache:
# if its in the cache, check its status is not NEW
if cache[id(value_updates)] is not _StaticModuleStatus.NEW:
raise ValueError(
f'Trying to add a new node at path {name!r} but a '
'node with the same reference has been updated'
)
else:
cache[id(value_updates)] = _StaticModuleStatus.NEW

node_impl.set_key(node, name, value_updates)
else: # static field
if isinstance(node_impl, PytreeNodeImpl):
if name in node_dict and node_dict[name] == value_updates:
# if the value is the same, skip
continue
# if trying
raise ValueError(
f'Cannot update key {name!r} on immutable node of '
f'type {type(node).__name__}. Current value is {node_dict[name]!r}, '
f'new value is {value_updates!r}.'
)

node_impl.set_key(node, name, value_updates)


# --------------------------------------------------------
# UpdateContext
# --------------------------------------------------------
Expand Down
48 changes: 0 additions & 48 deletions flax/nnx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import jax.numpy as jnp
import optax

from flax.nnx.graph import Key
from flax.nnx.module import GraphDef, Module
from flax.nnx.proxy_caller import ApplyCaller
from flax.nnx.rnglib import Rngs
Expand Down Expand Up @@ -63,53 +62,6 @@ def __iter__(self) -> tp.Iterator[str]:
def __len__(self) -> int:
return len(vars(self))


class List(Module, tp.Generic[A]):
def __init__(self, elems: tp.Iterable[A], /):
i = 0
for i, value in enumerate(elems):
setattr(self, str(i), value)
self._length = i + 1

def __getitem__(self, key: int) -> A:
if key >= len(self) or key < -len(self):
raise IndexError(f'index {key} out of range for {self}')
if key < 0:
key = self._length + key
return getattr(self, str(key))

def __setitem__(self, key: int, value: A):
if key >= len(self):
raise IndexError(f'index {key} out of range for {self}')
setattr(self, str(key), value)

def __iter__(self) -> tp.Iterator[A]:
for i in range(len(self)):
yield getattr(self, str(i))

def __len__(self) -> int:
return self._length

def _graph_node_flatten(self):
nodes: list[tuple[Key, tp.Any]] = sorted(
(int(key), value)
for key, value in vars(self).items()
if key not in ('_object__state', '_length')
)
nodes.append(('_length', self._length))
return nodes, (type(self), self._object__state._initializing)

def _graph_node_set_key(self, key: Key, value: tp.Any):
if isinstance(key, int):
key = str(key)
return super()._graph_node_set_key(key, value)

def _graph_node_pop_key(self, key: Key):
if isinstance(key, int):
key = str(key)
return super()._graph_node_pop_key(key)


class Sequential(Module):
def __init__(self, *fns: tp.Callable[..., tp.Any]):
self.layers = list(fns)
Expand Down
68 changes: 24 additions & 44 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@
import jax
import jax.numpy as jnp

class List(nnx.Module):
def __init__(self, items):
self.items = list(items)

def __getitem__(self, idx):
return self.items[idx]

def __setitem__(self, idx, value):
self.items[idx] = value


class Dict(nnx.Module):
def __init__(self, *args, **kwargs):
self.items = dict(*args, **kwargs)

def __getitem__(self, key):
return self.items[key]

def __setitem__(self, key, value):
self.items[key] = value

class StatefulLinear(nnx.Module):
def __init__(self, din, dout, rngs):
Expand Down Expand Up @@ -54,8 +74,8 @@ def test_flatten(self):
assert g[3] in refmap

def test_unflatten(self):
a = nnx.Dict(a=1, b=nnx.Param(2))
g = nnx.List([a, 3, a, nnx.Param(4)])
a = Dict(a=1, b=nnx.Param(2))
g = List([a, 3, a, nnx.Param(4)])

graphdef, state = nnx.split(g)
g = nnx.merge(graphdef, state)
Expand All @@ -72,8 +92,8 @@ def test_unflatten_pytree(self):
assert g[0] is not g[2]

def test_unflatten_empty(self):
a = nnx.Dict({'a': 1, 'b': nnx.Param(2)})
g = nnx.List([a, 3, a, nnx.Param(4)])
a = Dict({'a': 1, 'b': nnx.Param(2)})
g = List([a, 3, a, nnx.Param(4)])

graphdef, state = nnx.split(g)

Expand All @@ -92,46 +112,6 @@ def test_update_dynamic(self):
assert g[0]['b'].value == 3
assert g[2]['b'].value == 3

def test_update_static(self):
a = nnx.Dict({'a': 1, 'b': nnx.Param(2)})
g = nnx.List([a, 3, a, nnx.Param(4)])

g2 = nnx.graph.clone(g)
g2[0]['a'] = 5

nnx.graph.graph_update_static(g, g2)

assert g[0]['a'] == 5
assert g[2]['a'] == 5

def test_update_static_inconsistent_types(self):
a = {'a': 1, 'b': nnx.Param(2)}
g = [a, 3, a, nnx.Param(4)]
g2 = [a, a, 3, nnx.Param(4)]

with self.assertRaisesRegex(
ValueError, 'Trying to update a node with a different type'
):
nnx.graph.graph_update_static(g, g2)

def test_update_static_add_new(self):
a = nnx.Dict({'a': 1, 'b': nnx.Param(2)})
b = nnx.List([5, 6])
g = nnx.List([a, 3, a, nnx.Param(4)])
g2 = nnx.List([a, 3, a, nnx.Param(4), b])

nnx.graph.graph_update_static(g, g2)

assert g[4][0] == 5
assert g[4][1] == 6

def test_update_static_add_shared_error(self):
a = nnx.Dict({'a': 1, 'b': nnx.Param(2)})
g = nnx.List([a, 3, a, nnx.Param(4)])
g2 = nnx.List([a, 3, a, nnx.Param(4), a])

with self.assertRaisesRegex(ValueError, 'Trying to add a new node at path'):
nnx.graph.graph_update_static(g, g2)

def test_module_list(self):
rngs = nnx.Rngs(0)
Expand Down
Loading

0 comments on commit 4f62541

Please sign in to comment.