From 02f093230d41e9265ee6c8f257b4f3ec74595548 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 13 Nov 2024 14:21:40 -0800 Subject: [PATCH] Configure no restart validation loop in nl.Trainer (#11029) * Configure no restart validation loop in nl.Trainer Signed-off-by: Hemil Desai * fix Signed-off-by: Hemil Desai * Skip validation whenever restarting=True Signed-off-by: Hemil Desai * PR feedback Signed-off-by: Hemil Desai * Apply isort and black reformatting Signed-off-by: hemildesai --------- Signed-off-by: Hemil Desai Signed-off-by: hemildesai Co-authored-by: hemildesai --- nemo/collections/llm/api.py | 10 ++++++++- nemo/lightning/__init__.py | 3 ++- nemo/lightning/pytorch/trainer.py | 37 ++++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 13f25eb21087..fdceff5d959e 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -25,7 +25,14 @@ from typing_extensions import Annotated import nemo.lightning as nl -from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io +from nemo.lightning import ( + AutoResume, + NeMoLogger, + OptimizerModule, + Trainer, + configure_no_restart_validation_training_loop, + io, +) from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform from nemo.utils import logging @@ -680,6 +687,7 @@ def _setup( tokenizer: Optional[TokenizerType], model_transform: Optional[Union[PEFT, ModelTransform, Callable]], ) -> Any: # Return type is Any because app_state's type is not specified + configure_no_restart_validation_training_loop(trainer) _log = log or NeMoLogger() if resume and isinstance(model_transform, PEFT) and _log.ckpt: logging.info("Disabling try_restore_best_ckpt restoration for adapters") diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 2cc720e148d4..91d3b3f936d0 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -33,7 +33,7 @@ from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy from nemo.lightning.pytorch.strategies.utils import RestoreConfig -from nemo.lightning.pytorch.trainer import Trainer +from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop from nemo.lightning.resume import AutoResume @@ -66,6 +66,7 @@ def _is_slurm_interactive_mode(): "ModelCheckpoint", "OptimizerModule", "Trainer", + "configure_no_restart_validation_training_loop", "get_vocab_size", "teardown", ] diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index 0d71c49bf198..c97c59ef524d 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy import fiddle as fdl import pytorch_lightning as pl +from pytorch_lightning.loops import _TrainingEpochLoop +from pytorch_lightning.loops.fetchers import _DataFetcher from typing_extensions import Self from nemo.lightning.fabric.conversion import to_fabric @@ -23,8 +26,40 @@ from nemo.lightning.io.mixin import IOMixin, serialization, track_io -class Trainer(pl.Trainer, IOMixin): +class NoValOnRestartTrainingLoop(_TrainingEpochLoop): + """ + Extend the PTL Epoch loop to skip validation when restarting. + This happens when resuming a checkpoint that has already run validation, but loading restores + the training state before validation has run. + """ + + def _should_check_val_fx(self, data_fetcher) -> bool: + if self.skip_val_on_restart: + return False + return super()._should_check_val_fx(data_fetcher) + + def load_state_dict(self, state_dict: dict, prefix: str = "") -> None: + super().load_state_dict(state_dict, prefix) + + self.skip_val_on_restart = True + + def advance(self, data_fetcher: _DataFetcher) -> None: + super().advance(data_fetcher) + + self.skip_val_on_restart = False + +def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None: + if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop): + warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning) + return + + ## Pass trainer object to avoid trainer getting overwritten as None + loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps) + trainer.fit_loop.epoch_loop = loop + + +class Trainer(pl.Trainer, IOMixin): def add_io(self, obj): """Recurse to the leaves of a container and add io functionality to non-serializable leaves""" if isinstance(obj, (dict, list)):