diff --git a/flax/training/orbax_utils.py b/flax/training/orbax_utils.py index 19399cd333..be2ccc48da 100644 --- a/flax/training/orbax_utils.py +++ b/flax/training/orbax_utils.py @@ -14,6 +14,7 @@ """Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated.""" +import inspect from typing import Any, Optional import warnings @@ -68,19 +69,30 @@ def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any: """ def find_sharding(x): - if is_multi_device_array(x): + if isinstance(x, jax.Array): return x.sharding return None - # Simpler case: no multihost arrays + # Simpler case: no JAX arrays if not any( - jax.tree_util.tree_flatten(jax.tree_map(is_multi_device_array, target))[0] + jax.tree_util.tree_flatten(jax.tree_map(find_sharding, target))[0] ): return jax.tree_util.tree_map( lambda x: ocp.RestoreArgs(restore_type=np.ndarray), target ) - # Multihost arrays: find sharding from the given target + # JAX arrays: find sharding from the given target and create RestoreArgs + + # TODO(ivyzheng): remove after Orbax new release. + ocp_kwargs = {} + if ( + 'set_global_shape' + in inspect.signature( + ocp.checkpoint_utils.construct_restore_args + ).parameters + ): + ocp_kwargs['set_global_shape'] = False + sharding_tree = jax.tree_util.tree_map(find_sharding, target) if mesh is not None: warnings.warn( @@ -90,8 +102,11 @@ def find_sharding(x): ), DeprecationWarning, ) - axes_tree = jax.tree_util.tree_map(lambda s: s.spec, sharding_tree) - return ocp.checkpoint_utils.restore_args_from_target( - mesh, target, axes_tree - ) - return ocp.checkpoint_utils.construct_restore_args(target, sharding_tree) + + def substitute_embedding(s): + return jax.sharding.NamedSharding(mesh, s.spec) + + sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree) + return ocp.checkpoint_utils.construct_restore_args( + target, sharding_tree, **ocp_kwargs + )