diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 59461fc46b..3e27937d36 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -179,11 +179,17 @@ def to_pure_dict(self, def replace_by_pure_dict(self, pure_dict: dict[str, tp.Any], replace_fn: SetValueFn | None = None): + def try_convert_int(x): + try: + return int(x) + except ValueError: + return x # Works for nnx.Variable and nnx.VariableState if replace_fn is None: replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v current_flat = self.flat_state() for kp, v in traversals.flatten_mapping(pure_dict).items(): + kp = tuple(map(try_convert_int, kp)) if kp not in current_flat: raise ValueError(f'key in pure_dict not available in state: {kp}') current_flat[kp] = replace_fn(current_flat[kp], v) diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 1742e379cb..7b572f4b18 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import typing as tp from absl.testing import absltest import jax import jax.numpy as jnp import numpy as np +import orbax.checkpoint as ocp from flax import nnx @@ -259,6 +261,44 @@ def __call__(self, x): assert 'y' in intermediates + def test_replace_by_pure_dict(self): + class MLPs(nnx.Module): + def __init__(self, dim, rngs: nnx.Rngs): + self.layers = [] + for _ in range(4): + self.layers.append(nnx.Linear(dim, dim, rngs=rngs, use_bias=False)) + + def __call__(self, x): + for layer in self.layers: + x = layer(x) + return x + + model = MLPs(4, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(42), (3, 4)) + assert model(x).shape == (3, 4) + + _, state = nnx.split(model) + pure_dict_state = state.to_pure_dict() + nnx.display(pure_dict_state) + + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_dir = ocp.test_utils.erase_and_create_empty( + tmpdir + '/my-checkpoints/' + ) + checkpointer = ocp.StandardCheckpointer() + # checkpointer.save(ckpt_dir / 'state', state) + checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state) + + # Restore as a pure dictionary. + restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') + nnx.display(restored_pure_dict) + + abstract_model = nnx.eval_shape(lambda: MLPs(4, rngs=nnx.Rngs(0))) + graphdef, abstract_state = nnx.split(abstract_model) + abstract_state.replace_by_pure_dict(restored_pure_dict) + model = nnx.merge(graphdef, abstract_state) + assert model(x).shape == (3, 4) # The model still works! + if __name__ == '__main__': absltest.main()