diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 5f9f28d957..9063bc8196 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -28,6 +28,8 @@ V = tp.TypeVar('V') FlatState = dict[PathParts, V] +ExtractValueFn = tp.Callable[[tp.Any], tp.Any] +SetValueFn = tp.Callable[[V, tp.Any], V] class NestedStateRepr(reprlib.Representable): @@ -158,6 +160,28 @@ def from_flat_path( nested_state = traversals.unflatten_mapping(flat_state) return cls(nested_state) + def to_pure_dict(self, + extract_fn: ExtractValueFn | None = None + ) -> dict[str, tp.Any]: + # Works for nnx.Variable and nnx.VariableState + if extract_fn is None: + extract_fn = lambda x: x.value if hasattr(x, 'value') else x + flat_values = {k: extract_fn(x) for k, x in self.flat_state().items()} + return traversals.unflatten_mapping(flat_values) + + def replace_by_pure_dict(self, + pure_dict: dict[str, tp.Any], + replace_fn: SetValueFn | None = None): + # Works for nnx.Variable and nnx.VariableState + if replace_fn is None: + replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v + current_flat = self.flat_state() + for kp, v in traversals.flatten_mapping(pure_dict).items(): + if kp not in current_flat: + raise ValueError(f'key in pure_dict not available in state: {kp}') + current_flat[kp] = replace_fn(current_flat[kp], v) + self.update(traversals.unflatten_mapping(current_flat)) + @tp.overload def split(self, first: filterlib.Filter, /) -> State[K, V]: ... diff --git a/tests/nnx/state_test.py b/tests/nnx/state_test.py index e1884134ff..3cfde22e0d 100644 --- a/tests/nnx/state_test.py +++ b/tests/nnx/state_test.py @@ -15,6 +15,8 @@ from absl.testing import absltest from flax import nnx +import jax +from jax import numpy as jnp class StateTest(absltest.TestCase): @@ -75,6 +77,21 @@ def __init__(self, *, rngs: nnx.Rngs): assert module.layers[1].kernel.value.shape == (2, 3) assert state.layers[1].kernel.value.shape == (2, 3) + def test_pure_dict(self): + module = nnx.Linear(4, 5, rngs=nnx.Rngs(0)) + state = nnx.state(module) + pure_dict = state.to_pure_dict() + assert isinstance(pure_dict, dict) + assert isinstance(pure_dict['kernel'], jax.Array) + assert isinstance(pure_dict['bias'], jax.Array) + state.replace_by_pure_dict(jax.tree.map(jnp.zeros_like, pure_dict)) + assert isinstance(state, nnx.State) + assert isinstance(state.kernel, nnx.VariableState) + assert jnp.array_equal(state.kernel.value, jnp.zeros((4, 5))) + assert state.kernel.type == nnx.Param + nnx.update(module, state) + assert jnp.array_equal(module(jnp.ones((3, 4))), jnp.zeros((3, 5))) + if __name__ == '__main__': absltest.main() \ No newline at end of file