Skip to content

Commit

Permalink
Merge pull request #4230 from IvyZX:pure-dict
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679230175
  • Loading branch information
Flax Authors committed Sep 26, 2024
2 parents 8b37d1a + 3b9dce1 commit 9f44f81
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
24 changes: 24 additions & 0 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
V = tp.TypeVar('V')

FlatState = dict[PathParts, V]
ExtractValueFn = tp.Callable[[tp.Any], tp.Any]
SetValueFn = tp.Callable[[V, tp.Any], V]


class NestedStateRepr(reprlib.Representable):
Expand Down Expand Up @@ -158,6 +160,28 @@ def from_flat_path(
nested_state = traversals.unflatten_mapping(flat_state)
return cls(nested_state)

def to_pure_dict(self,
extract_fn: ExtractValueFn | None = None
) -> dict[str, tp.Any]:
# Works for nnx.Variable and nnx.VariableState
if extract_fn is None:
extract_fn = lambda x: x.value if hasattr(x, 'value') else x
flat_values = {k: extract_fn(x) for k, x in self.flat_state().items()}
return traversals.unflatten_mapping(flat_values)

def replace_by_pure_dict(self,
pure_dict: dict[str, tp.Any],
replace_fn: SetValueFn | None = None):
# Works for nnx.Variable and nnx.VariableState
if replace_fn is None:
replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v
current_flat = self.flat_state()
for kp, v in traversals.flatten_mapping(pure_dict).items():
if kp not in current_flat:
raise ValueError(f'key in pure_dict not available in state: {kp}')
current_flat[kp] = replace_fn(current_flat[kp], v)
self.update(traversals.unflatten_mapping(current_flat))

@tp.overload
def split(self, first: filterlib.Filter, /) -> State[K, V]: ...

Expand Down
17 changes: 17 additions & 0 deletions tests/nnx/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from absl.testing import absltest

from flax import nnx
import jax
from jax import numpy as jnp


class StateTest(absltest.TestCase):
Expand Down Expand Up @@ -75,6 +77,21 @@ def __init__(self, *, rngs: nnx.Rngs):
assert module.layers[1].kernel.value.shape == (2, 3)
assert state.layers[1].kernel.value.shape == (2, 3)

def test_pure_dict(self):
module = nnx.Linear(4, 5, rngs=nnx.Rngs(0))
state = nnx.state(module)
pure_dict = state.to_pure_dict()
assert isinstance(pure_dict, dict)
assert isinstance(pure_dict['kernel'], jax.Array)
assert isinstance(pure_dict['bias'], jax.Array)
state.replace_by_pure_dict(jax.tree.map(jnp.zeros_like, pure_dict))
assert isinstance(state, nnx.State)
assert isinstance(state.kernel, nnx.VariableState)
assert jnp.array_equal(state.kernel.value, jnp.zeros((4, 5)))
assert state.kernel.type == nnx.Param
nnx.update(module, state)
assert jnp.array_equal(module(jnp.ones((3, 4))), jnp.zeros((3, 5)))


if __name__ == '__main__':
absltest.main()

0 comments on commit 9f44f81

Please sign in to comment.