Skip to content

Commit

Permalink
On top of #3217, make sure the shape compatibility works before the n…
Browse files Browse the repository at this point in the history
…ext Orbax release.

PiperOrigin-RevId: 551294857
  • Loading branch information
IvyZX authored and Flax Authors committed Jul 26, 2023
1 parent 719217b commit 1a77965
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated."""

import dataclasses
import inspect
from typing import Any, Optional
import warnings
Expand Down Expand Up @@ -107,6 +108,15 @@ def substitute_embedding(s):
return jax.sharding.NamedSharding(mesh, s.spec)

sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree)
return ocp.checkpoint_utils.construct_restore_args(
restore_args = ocp.checkpoint_utils.construct_restore_args(
target, sharding_tree, **ocp_kwargs
)
# TODO(ivyzheng): remove after Orbax new release.
if not ocp_kwargs:
restore_args = jax.tree_util.tree_map(
lambda ra: dataclasses.replace(ra, global_shape=None)
if isinstance(ra, ocp.ArrayRestoreArgs)
else ra,
restore_args,
)
return restore_args

0 comments on commit 1a77965

Please sign in to comment.