Skip to content

Commit

Permalink
Merge pull request #4317 from google:improve-state-loading
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688513764
  • Loading branch information
Flax Authors committed Oct 22, 2024
2 parents 092ba1d + 059743c commit 8360b7c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,17 @@ def to_pure_dict(self,
def replace_by_pure_dict(self,
pure_dict: dict[str, tp.Any],
replace_fn: SetValueFn | None = None):
def try_convert_int(x):
  """Return ``int(x)`` if ``x`` converts to an integer, else ``x`` unchanged.

  Applied to each component of a flattened key path so that integer-like
  components (e.g. the string ``'0'``) compare equal to the integer keys
  used in the current flat state.
  NOTE(review): presumably needed because a serialization round trip turns
  integer keys into strings -- confirm against the checkpoint format.

  Args:
    x: a single key-path component of any type.

  Returns:
    ``int(x)`` when the conversion succeeds, otherwise ``x`` itself.
  """
  try:
    return int(x)
  except (TypeError, ValueError):
    # ValueError: a non-numeric string such as 'kernel'.
    # TypeError: a type int() cannot accept at all (e.g. None or a tuple);
    # such components are passed through instead of aborting the whole load.
    return x
# Works for nnx.Variable and nnx.VariableState
if replace_fn is None:
replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v
current_flat = self.flat_state()
for kp, v in traversals.flatten_mapping(pure_dict).items():
kp = tuple(map(try_convert_int, kp))
if kp not in current_flat:
raise ValueError(f'key in pure_dict not available in state: {kp}')
current_flat[kp] = replace_fn(current_flat[kp], v)
Expand Down
40 changes: 40 additions & 0 deletions tests/nnx/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import typing as tp

from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

from flax import nnx

Expand Down Expand Up @@ -259,6 +261,44 @@ def __call__(self, x):

assert 'y' in intermediates

def test_replace_by_pure_dict(self):
  """Round-trips a model's pure-dict state through an Orbax checkpoint.

  Builds a small model whose layers live in a Python list, saves its state
  as a pure dict, restores it, loads it into an abstract copy of the model
  with ``replace_by_pure_dict`` and checks that the rebuilt model produces
  the same output as the original one.
  """

  class MLPs(nnx.Module):
    """Stack of four bias-free ``dim x dim`` Linear layers kept in a list."""

    def __init__(self, dim, rngs: nnx.Rngs):
      self.layers = []
      for _ in range(4):
        self.layers.append(nnx.Linear(dim, dim, rngs=rngs, use_bias=False))

    def __call__(self, x):
      for layer in self.layers:
        x = layer(x)
      return x

  model = MLPs(4, rngs=nnx.Rngs(0))
  x = jax.random.normal(jax.random.key(42), (3, 4))
  # Keep the reference output so the restored weights can be verified by
  # value below, not just by shape.
  expected = model(x)
  assert expected.shape == (3, 4)

  _, state = nnx.split(model)
  pure_dict_state = state.to_pure_dict()
  nnx.display(pure_dict_state)

  with tempfile.TemporaryDirectory() as tmpdir:
    ckpt_dir = ocp.test_utils.erase_and_create_empty(
      tmpdir + '/my-checkpoints/'
    )
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)

    # Restore as a pure dictionary.
    restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
    nnx.display(restored_pure_dict)

    # Rebuild the model from an abstract (shape-only) instance plus the
    # restored values.
    abstract_model = nnx.eval_shape(lambda: MLPs(4, rngs=nnx.Rngs(0)))
    graphdef, abstract_state = nnx.split(abstract_model)
    abstract_state.replace_by_pure_dict(restored_pure_dict)
    model = nnx.merge(graphdef, abstract_state)

    restored_out = model(x)
    assert restored_out.shape == (3, 4)  # The model still works!
    # A shape check alone would pass even with wrong weights; compare the
    # actual outputs to confirm the checkpointed values were loaded.
    np.testing.assert_allclose(restored_out, expected, rtol=1e-6, atol=1e-6)


# Script entry point: hand control to the absltest runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
  absltest.main()

0 comments on commit 8360b7c

Please sign in to comment.