diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index 89c7cdb17aef..5d1b3bede468 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -18,7 +18,7 @@
 import shutil
 from collections import OrderedDict
 from contextlib import ExitStack, contextmanager
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -40,10 +40,9 @@
 import torch.distributed
 from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device
+from megatron.core.dist_checkpointing.core import maybe_load_config
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
-from megatron.core.dist_checkpointing.core import maybe_load_config
-
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop
 from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher
@@ -674,7 +673,9 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
             and self.trainer.state.fn == TrainerFn.FITTING
         ):
             if self.lightning_module.optimizers(use_pl_optimizer=False):
-                sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True, metadata=metadata)]
+                sharded_state_dict["optimizer"] = [
+                    self.optimizer_sharded_state_dict(is_loading=True, metadata=metadata)
+                ]
 
         checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict)
 