Skip to content

Commit

Permalink
On top of #3217, make sure the shape compatibility works before the n…
Browse files Browse the repository at this point in the history
…ext Orbax release.

PiperOrigin-RevId: 551294857
  • Loading branch information
IvyZX authored and Flax Authors committed Jul 26, 2023
1 parent f103233 commit 783fb92
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
12 changes: 11 additions & 1 deletion flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated."""

import dataclasses
import inspect
from typing import Any, Optional
import warnings
Expand Down Expand Up @@ -107,6 +108,15 @@ def substitute_embedding(s):
return jax.sharding.NamedSharding(mesh, s.spec)

sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree)
return ocp.checkpoint_utils.construct_restore_args(
restore_args = ocp.checkpoint_utils.construct_restore_args(
target, sharding_tree, **ocp_kwargs
)
# TODO(ivyzheng): remove after Orbax new release.
if not ocp_kwargs:
restore_args = jax.tree_util.tree_map(
lambda ra: dataclasses.replace(ra, global_shape=None)
if isinstance(ra, ocp.ArrayRestoreArgs)
else ra,
restore_args,
)
return restore_args
11 changes: 11 additions & 0 deletions tests/checkpoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,17 @@ def test_auto_restore(self):
)
check_eq(restored, to_save)

@parameterized.parameters({'use_orbax': True}, {'use_orbax': False})
def test_smaller_target(self, use_orbax):
config.update('flax_use_orbax_checkpointing', use_orbax)
tmp_dir = self.create_tempdir().full_path
to_save = {'a': jnp.ones((16, 256, 1024))}
target = {'a': jnp.zeros((2, 3))}

checkpoints.save_checkpoint(tmp_dir, to_save, 0, keep=1)
new_object = checkpoints.restore_checkpoint(tmp_dir, target)
check_eq(new_object, to_save)

def test_convert_pre_linen(self):
params = checkpoints.convert_pre_linen({
'mod_0': {
Expand Down

0 comments on commit 783fb92

Please sign in to comment.