From 463ad6cbf8b34555df0355b96e0aea0257b8b441 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Wed, 13 Sep 2023 09:11:17 -0700 Subject: [PATCH] Turn OCDBT on by default PiperOrigin-RevId: 565069275 --- flax/training/checkpoints.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 84d8bfe4cf..858705c4e2 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -87,6 +87,7 @@ # Orbax main checkpoint file name. ORBAX_CKPT_FILENAME = 'checkpoint' +ORBAX_MANIFEST_OCDBT = 'manifest.odbt' PyTree = Any @@ -131,6 +132,12 @@ def _safe_remove(path: str): io.remove(path) +def _is_ocdbt_checkpoint(path: str) -> bool: + return io.exists(os.path.join(path, ORBAX_CKPT_FILENAME)) or io.exists( + os.path.join(path, ORBAX_MANIFEST_OCDBT) + ) + + class AsyncManager: """A simple object to track async checkpointing. @@ -1046,7 +1053,7 @@ def restore_checkpoint( return target if io.isdir(ckpt_dir): # This means the given dir is an orbax checkpoint. - if io.exists(os.path.join(ckpt_dir, ORBAX_CKPT_FILENAME)): + if _is_ocdbt_checkpoint(ckpt_dir): ckpt_path = ckpt_dir else: ckpt_path = latest_checkpoint(ckpt_dir, prefix) # type: ignore @@ -1059,7 +1066,7 @@ def restore_checkpoint( ckpt_path = ckpt_dir # Restore the checkpoint with Orbax if needed. - is_orbax = io.exists(os.path.join(ckpt_path, ORBAX_CKPT_FILENAME)) + is_orbax = _is_ocdbt_checkpoint(ckpt_path) ckpt_type = 'orbax' if is_orbax else 'legacy Flax' logging.info(f'Restoring {ckpt_type} checkpoint from {ckpt_path}') if is_orbax: