From 3741278de6427e61f93b7958ac20b5d224c1886b Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Wed, 26 Jul 2023 16:54:41 -0700 Subject: [PATCH] On top of #3217, make sure the shape compatibility works before the next Orbax release. PiperOrigin-RevId: 551354501 --- flax/training/orbax_utils.py | 12 +++++++++++- tests/checkpoints_test.py | 11 +++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/flax/training/orbax_utils.py b/flax/training/orbax_utils.py index be2ccc48da..87ba6cc14b 100644 --- a/flax/training/orbax_utils.py +++ b/flax/training/orbax_utils.py @@ -14,6 +14,7 @@ """Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated.""" +import dataclasses import inspect from typing import Any, Optional import warnings @@ -107,6 +108,15 @@ def substitute_embedding(s): return jax.sharding.NamedSharding(mesh, s.spec) sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree) - return ocp.checkpoint_utils.construct_restore_args( + restore_args = ocp.checkpoint_utils.construct_restore_args( target, sharding_tree, **ocp_kwargs ) + # TODO(ivyzheng): remove after Orbax new release. + if not ocp_kwargs: + restore_args = jax.tree_util.tree_map( + lambda ra: dataclasses.replace(ra, global_shape=None) + if isinstance(ra, ocp.ArrayRestoreArgs) + else ra, + restore_args, + ) + return restore_args diff --git a/tests/checkpoints_test.py b/tests/checkpoints_test.py index d273c3e051..b49d715561 100644 --- a/tests/checkpoints_test.py +++ b/tests/checkpoints_test.py @@ -461,6 +461,17 @@ def test_auto_restore(self): ) check_eq(restored, to_save) + @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) + def test_smaller_target(self, use_orbax): + config.update('flax_use_orbax_checkpointing', use_orbax) + tmp_dir = self.create_tempdir().full_path + to_save = {'a': jnp.ones((16, 256, 1024))} + target = {'a': jnp.zeros((2, 3))} + + checkpoints.save_checkpoint(tmp_dir, to_save, 0, keep=1) + new_object = checkpoints.restore_checkpoint(tmp_dir, target) + check_eq(new_object, to_save) + def test_convert_pre_linen(self): params = checkpoints.convert_pre_linen({ 'mod_0': {