Skip to content

Commit

Permalink
Add merge / finalize step when using OCDBT driver. Files will be firs…
Browse files Browse the repository at this point in the history
…t written to per-process subdirectories, which are later copied by reference to the main directory before the checkpoint is finalized.

PiperOrigin-RevId: 572392273
  • Loading branch information
cpgaffney1 authored and Flax Authors committed Oct 20, 2023
1 parent abab11f commit 1a420bd
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/checkpoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,14 @@ def test_overwrite_checkpoints(self, use_orbax):

os.chdir(os.path.dirname(tmp_dir))
rel_tmp_dir = './' + os.path.basename(tmp_dir)
checkpoints.save_checkpoint(rel_tmp_dir, test_object, 3, keep=1)
new_object = checkpoints.restore_checkpoint(rel_tmp_dir, test_object0)
check_eq(new_object, test_object)
if use_orbax:
with self.assertRaises(ValueError):
checkpoints.save_checkpoint(rel_tmp_dir, test_object, 3, keep=1)
else:
checkpoints.save_checkpoint(rel_tmp_dir, test_object, 3, keep=1)
new_object = checkpoints.restore_checkpoint(rel_tmp_dir, test_object0)
check_eq(new_object, test_object)

non_norm_dir_path = tmp_dir + '//'
checkpoints.save_checkpoint(non_norm_dir_path, test_object, 4, keep=1)
new_object = checkpoints.restore_checkpoint(non_norm_dir_path, test_object0)
Expand Down

0 comments on commit 1a420bd

Please sign in to comment.