Skip to content

Commit

Permalink
Turn OCDBT on by default
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565069275
  • Loading branch information
ChromeHearts authored and Flax Authors committed Sep 15, 2023
1 parent 654ae1a commit 463ad6c
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@

# Orbax main checkpoint file name.
ORBAX_CKPT_FILENAME = 'checkpoint'
ORBAX_MANIFEST_OCDBT = 'manifest.odbt'

PyTree = Any

Expand Down Expand Up @@ -131,6 +132,12 @@ def _safe_remove(path: str):
io.remove(path)


def _is_ocdbt_checkpoint(path: str) -> bool:
return io.exists(os.path.join(path, ORBAX_CKPT_FILENAME)) or io.exists(
os.path.join(path, ORBAX_MANIFEST_OCDBT)
)


class AsyncManager:
"""A simple object to track async checkpointing.
Expand Down Expand Up @@ -1046,7 +1053,7 @@ def restore_checkpoint(
return target
if io.isdir(ckpt_dir):
# This means the given dir is an orbax checkpoint.
if io.exists(os.path.join(ckpt_dir, ORBAX_CKPT_FILENAME)):
if _is_ocdbt_checkpoint(ckpt_dir):
ckpt_path = ckpt_dir
else:
ckpt_path = latest_checkpoint(ckpt_dir, prefix) # type: ignore
Expand All @@ -1059,7 +1066,7 @@ def restore_checkpoint(
ckpt_path = ckpt_dir

# Restore the checkpoint with Orbax if needed.
is_orbax = io.exists(os.path.join(ckpt_path, ORBAX_CKPT_FILENAME))
is_orbax = _is_ocdbt_checkpoint(ckpt_path)
ckpt_type = 'orbax' if is_orbax else 'legacy Flax'
logging.info(f'Restoring {ckpt_type} checkpoint from {ckpt_path}')
if is_orbax:
Expand Down

0 comments on commit 463ad6c

Please sign in to comment.