Skip to content

Commit

Permalink
Merge branch 'main' into jiemingz/ckpt_mem_fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ericharper authored Mar 23, 2024
2 parents 7fbf9e2 + 11b7a73 commit 55ca157
Show file tree
Hide file tree
Showing 3 changed files, with 30 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def setup(self, stage=None):

if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
self.setup_transformer_engine_tp_groups()
self.setup_transformer_engine_cp_groups()
self.setup_complete = True

def _build_dataset(self, data_cfg, is_train=True):
Expand Down
28 changes: 27 additions & 1 deletion nemo/utils/callbacks/nemo_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytorch_lightning
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
from pytorch_lightning.utilities import rank_zero_info

from nemo.collections.common.callbacks import EMA
Expand Down Expand Up @@ -454,3 +454,29 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
# delete markers
for marker_path in existing_marker_filepaths:
os.remove(marker_path)

def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
    """Decide whether the checkpoint file at ``previous`` may be deleted.

    Deletion is vetoed when any of the following hold:
      - ``previous`` and ``current`` are the same path (the new save already
        overwrote the old file, so there is nothing to remove).
      - ``previous`` is on the local filesystem but lives outside this
        callback's checkpoint directory.
      - ``previous`` is the checkpoint the Trainer resumed from (on a local
        filesystem), unless both it and ``current`` are ``-last.ckpt``
        checkpoints — in that case the fresh ``-last.ckpt`` supersedes it.

    Raises:
        ValueError: if ``self.dirpath`` is ``None``.
    """
    # Identical paths: the old checkpoint was already replaced in place.
    if previous == current:
        return False
    # Non-local (remote) checkpoints are always eligible for removal.
    if not _is_local_file_protocol(previous):
        return True

    prev_path = Path(previous).absolute()
    ckpt_path = trainer.ckpt_path
    resumed_from = Path(ckpt_path).absolute() if ckpt_path is not None else None

    if resumed_from is not None and prev_path == resumed_from:
        # Protect the resume checkpoint — except when a newly saved
        # `-last.ckpt` supersedes the `-last.ckpt` we resumed from.
        supersedes_last = str(current).endswith("-last.ckpt") and resumed_from.name.endswith("-last.ckpt")
        if not supersedes_last:
            return False

    if self.dirpath is None:
        raise ValueError(f"{self.__class__}.dirpath is None.")
    # Only checkpoints located under our own dirpath are ours to delete.
    return Path(self.dirpath).absolute() in prev_path.parents
5 changes: 2 additions & 3 deletions tests/core/test_exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,8 @@ def test_invalid_checkpoints_removed_from_topk(self, tmp_path):
test_trainer2.fit(model)

ckpt_filenames = {f.name for f in checkpoints_dir.rglob("*.ckpt") if f.is_file()}
# 3 top + 1 last + 1 resume ckpt since PTL >= 2.1 ensures to never delete the resume ckpt
# (https://github.com/Lightning-AI/pytorch-lightning/pull/18750)
assert len(ckpt_filenames) == 5
# 3 top + 1 last
assert len(ckpt_filenames) == 4
assert 'epoch=9-last.ckpt' in ckpt_filenames
assert 'epoch=8.ckpt' in ckpt_filenames
assert 'epoch=7.ckpt' in ckpt_filenames
Expand Down

0 comments on commit 55ca157

Please sign in to comment.