Skip to content

Commit

Permalink
Let Flax-Orbax to not port the shape of target arrays when they por…
Browse files Browse the repository at this point in the history
…t the `target` shardings.

This allow people to continue using Flax checkpointing API with target pytrees of desired sharding but smaller shapes, avoiding memory burdens.

No impact if user is using native Orbax.

PiperOrigin-RevId: 551289243
  • Loading branch information
IvyZX authored and Flax Authors committed Jul 26, 2023
1 parent 2497f82 commit 719217b
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated."""

import inspect
from typing import Any, Optional
import warnings

Expand Down Expand Up @@ -68,19 +69,30 @@ def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any:
"""

def find_sharding(x):
if is_multi_device_array(x):
if isinstance(x, jax.Array):
return x.sharding
return None

# Simpler case: no multihost arrays
# Simpler case: no JAX arrays
if not any(
jax.tree_util.tree_flatten(jax.tree_map(is_multi_device_array, target))[0]
jax.tree_util.tree_flatten(jax.tree_map(find_sharding, target))[0]
):
return jax.tree_util.tree_map(
lambda x: ocp.RestoreArgs(restore_type=np.ndarray), target
)

# Multihost arrays: find sharding from the given target
# JAX arrays: find sharding from the given target and create RestoreArgs

# TODO(ivyzheng): remove after Orbax new release.
ocp_kwargs = {}
if (
'set_global_shape'
in inspect.signature(
ocp.checkpoint_utils.construct_restore_args
).parameters
):
ocp_kwargs['set_global_shape'] = False

sharding_tree = jax.tree_util.tree_map(find_sharding, target)
if mesh is not None:
warnings.warn(
Expand All @@ -90,8 +102,11 @@ def find_sharding(x):
),
DeprecationWarning,
)
axes_tree = jax.tree_util.tree_map(lambda s: s.spec, sharding_tree)
return ocp.checkpoint_utils.restore_args_from_target(
mesh, target, axes_tree
)
return ocp.checkpoint_utils.construct_restore_args(target, sharding_tree)

def substitute_embedding(s):
return jax.sharding.NamedSharding(mesh, s.spec)

sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree)
return ocp.checkpoint_utils.construct_restore_args(
target, sharding_tree, **ocp_kwargs
)

0 comments on commit 719217b

Please sign in to comment.