diff --git a/flax/training/orbax_utils.py b/flax/training/orbax_utils.py index be2ccc48da..87ba6cc14b 100644 --- a/flax/training/orbax_utils.py +++ b/flax/training/orbax_utils.py @@ -14,6 +14,7 @@ """Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated.""" +import dataclasses import inspect from typing import Any, Optional import warnings @@ -107,6 +108,15 @@ def substitute_embedding(s): return jax.sharding.NamedSharding(mesh, s.spec) sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree) - return ocp.checkpoint_utils.construct_restore_args( + restore_args = ocp.checkpoint_utils.construct_restore_args( target, sharding_tree, **ocp_kwargs ) + # TODO(ivyzheng): remove after Orbax new release. + if not ocp_kwargs: + restore_args = jax.tree_util.tree_map( + lambda ra: dataclasses.replace(ra, global_shape=None) + if isinstance(ra, ocp.ArrayRestoreArgs) + else ra, + restore_args, + ) + return restore_args