Skip to content

Commit

Permalink
Remove import checks that are no longer valid.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 572721758
  • Loading branch information
IvyZX authored and Flax Authors committed Oct 12, 2023
1 parent d0d0439 commit 23eb4e5
Showing 1 changed file with 19 additions and 34 deletions.
53 changes: 19 additions & 34 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,13 @@
from jax import monitoring
from jax import process_index
from jax import tree_util as jtu
from jax.experimental.array_serialization.serialization import get_tensorstore_spec
from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager
from jax.experimental.multihost_utils import sync_global_devices
import numpy as np
import orbax.checkpoint as ocp

_READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec'
_WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec'
_IMPORT_GDAM_SUCCESSFUL = False
try:
from jax.experimental.array_serialization.serialization import get_tensorstore_spec
from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager

_IMPORT_GDAM_SUCCESSFUL = True
except ImportError:
logging.warning(
'GlobalAsyncCheckpointManager is not imported correctly. '
'Checkpointing of GlobalDeviceArrays will not be available.'
'To use the feature, install tensorstore.'
)


# Single-group reg-exps for int or float numerical substrings.
Expand Down Expand Up @@ -262,7 +251,7 @@ def _restore_mpas(
target: Optional[Any],
ckpt_path: str,
step: Optional[Union[int, float]],
gda_manager: Optional[Any],
gda_manager: Optional[GlobalAsyncCheckpointManager],
allow_partial: bool = False,
):
"""Restore the multiprocess arrays given the target structure and type."""
Expand Down Expand Up @@ -740,7 +729,7 @@ def save_checkpoint_multiprocess(
overwrite: bool = False,
keep_every_n_steps: Optional[int] = None,
async_manager: Optional[AsyncManager] = None,
gda_manager: Optional[Any] = None,
gda_manager: Optional[GlobalAsyncCheckpointManager] = None,
orbax_checkpointer: Optional[ocp.Checkpointer] = None,
) -> str:
"""Save a checkpoint of the model in multi-process environment.
Expand Down Expand Up @@ -768,11 +757,9 @@ def save_checkpoint_multiprocess(
async_manager: if defined, the save will run without blocking the main
thread. Only works for single host. Note that an ongoing save will still
block subsequent saves, to make sure overwrite/keep logic works correctly.
gda_manager: required if target contains a JAX GlobalDeviceArray. Type
should be GlobalAsyncCheckpointManager (needs Tensorstore to be imported
correctly). Will save the GDAs to a separate subdirectory with postfix
"_gda" asynchronously. Same as async_manager, this will block subsequent
saves.
gda_manager: required if target contains a JAX GlobalDeviceArray. Will save
the GDAs to a separate subdirectory with postfix "_gda" asynchronously.
Same as async_manager, this will block subsequent saves.
orbax_checkpointer: if defined, the save will be done by Orbax. In the
future, all Flax checkpointing features will be migrated to Orbax,
and starting to use an `orbax_checkpointer` is recommended. Please
Expand Down Expand Up @@ -850,7 +837,7 @@ def save_checkpoint_multiprocess(
target = serialization.to_state_dict(target)
target, mpa_targets = _split_mp_arrays(target)
target = serialization.msgpack_serialize(target)
has_mpa = mpa_targets and _IMPORT_GDAM_SUCCESSFUL
has_mpa = mpa_targets

if not overwrite:
_check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore
Expand Down Expand Up @@ -989,7 +976,7 @@ def restore_checkpoint(
step: Optional[Union[int, float]] = None,
prefix: str = 'checkpoint_',
parallel: bool = True,
gda_manager: Optional[Any] = None,
gda_manager: Optional[GlobalAsyncCheckpointManager] = None,
allow_partial_mpa_restoration: bool = False,
orbax_checkpointer: Optional[ocp.Checkpointer] = None,
orbax_transforms: Optional[Dict] = None,
Expand All @@ -1014,9 +1001,8 @@ def restore_checkpoint(
prefix: str: name prefix of checkpoint files.
parallel: bool: whether to load seekable checkpoints in parallel, for speed.
gda_manager: required if checkpoint contains a multiprocess array
(GlobalDeviceArray or jax Array from pjit). Type should be
GlobalAsyncCheckpointManager (needs Tensorstore to be imported correctly).
Will read the arrays from the separate subdirectory with postfix "_gda".
(GlobalDeviceArray or jax Array from pjit). Will read the arrays from the
separate subdirectory with postfix "_gda".
allow_partial_mpa_restoration: If true, the given `target` doesn't have to
contain all valid multiprocess arrays. As a result, the restored Pytree
may have some MPAs not restored correctly. Use this if you cannot provide
Expand Down Expand Up @@ -1126,15 +1112,14 @@ def read_chunk(i):
checkpoint_contents = fp.read()

state_dict = serialization.msgpack_restore(checkpoint_contents)
if _IMPORT_GDAM_SUCCESSFUL:
state_dict = _restore_mpas(
state_dict,
target,
ckpt_path,
step,
gda_manager,
allow_partial_mpa_restoration,
)
state_dict = _restore_mpas(
state_dict,
target,
ckpt_path,
step,
gda_manager,
allow_partial_mpa_restoration,
)

if target is None:
restored_checkpoint = state_dict
Expand Down

0 comments on commit 23eb4e5

Please sign in to comment.