diff --git a/.github/workflows/_test_template.yml b/.github/workflows/_test_template.yml
index bd80e88ee964..4e186ca29028 100644
--- a/.github/workflows/_test_template.yml
+++ b/.github/workflows/_test_template.yml
@@ -59,7 +59,7 @@ jobs:
         (
           set -e
-          docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.SCRIPT }}'
+          docker run --rm --runtime=nvidia --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.SCRIPT }}'
         ) 2> >(tee err.log)
         EXIT_CODE=$?
@@ -73,4 +73,4 @@ jobs:
     - name: after_script
       if: always() && inputs.AFTER_SCRIPT != ':'
       run: |
-        docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.AFTER_SCRIPT }}'
\ No newline at end of file
+        docker run --rm --runtime=nvidia --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.AFTER_SCRIPT }}'
\ No newline at end of file
diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 20c0df66c005..855cfc75db44 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -112,6 +112,7 @@ jobs:
         # Basic Import Checks
         python -c "import nemo.collections.asr as nemo_asr"
         python -c "import nemo.collections.nlp as nemo_nlp"
+        python -c "import nemo.collections.nlp as nemo_nlp; nemo_nlp.modules.get_tokenizer_list()"
         python -c "import nemo.collections.tts as nemo_tts"
         python setup.py style
@@ -5245,13 +5246,19 @@ jobs:
        NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \
        --devices=2 \
        --max-steps=3 \
-       --experiment-dir=tests/collections/llm/t5_pretrain_results \
+       --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \
        --data-path=/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document \
-       --index-mapping-dir=tests/collections/llm/t5_index_mappings
+       --index-mapping-dir=tests/collections/llm/t5_index_mappings/${{ github.run_id }}
+       NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \
+       --devices=2 \
+       --max-steps=6 \
+       --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \
+       --data-path=/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document \
+       --index-mapping-dir=tests/collections/llm/t5_index_mappings/${{ github.run_id }}
      AFTER_SCRIPT: |
-       rm -rf tests/collections/llm/t5_pretrain_results
-       rm -rf tests/collections/llm/t5_index_mappings
+       rm -rf tests/collections/llm/t5_pretrain_results/${{ github.run_id }}
+       rm -rf tests/collections/llm/t5_index_mappings/${{ github.run_id }}
  Nemo_CICD_Test:
    needs:
@@ -5402,7 +5409,7 @@ jobs:
        echo "FAILED=$FAILED" >> $GITHUB_OUTPUT
        # Mark as successful if no job was cancelled:
-       SUCCESS=${{ !contains(needs.*.result, 'cancelled') }}
+       SUCCESS=${{ !contains(needs.*.result, 'cancelled') && !contains(needs.*.result, 'skipped') }}
        echo "SUCCESS=$SUCCESS" >> $GITHUB_OUTPUT
        # This should depend on all the tests so we block/unblock based on all tests passing
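The new smoke-test line in the cicd-main.yml hunk above goes beyond a bare import and actually touches the NLP collection's tokenizer registry. A minimal sketch of the same check run as a standalone script (the script itself is illustrative; only the nemo_nlp.modules.get_tokenizer_list() call is taken from the workflow):

    import nemo.collections.asr as nemo_asr
    import nemo.collections.nlp as nemo_nlp
    import nemo.collections.tts as nemo_tts

    # Calling get_tokenizer_list() forces the lazily imported tokenizer utilities to load,
    # which catches broken imports that a bare `import nemo.collections.nlp` would miss.
    print(len(nemo_nlp.modules.get_tokenizer_list()))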
diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py
index 446e40714460..eb905d3e91b0 100644
--- a/examples/asr/transcribe_speech_parallel.py
+++ b/examples/asr/transcribe_speech_parallel.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -82,6 +82,7 @@
 from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
 from nemo.collections.asr.metrics.wer import word_error_rate
 from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
+from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
 from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig
 from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
 from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
@@ -93,7 +94,9 @@
 @dataclass
 class ParallelTranscriptionConfig:
     model: Optional[str] = None  # name
-    predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4)
+    predict_ds: ASRDatasetConfig = ASRDatasetConfig(
+        return_sample_id=True, num_workers=4, min_duration=0, max_duration=40
+    )
     output_path: str = MISSING
     # when return_predictions is enabled, the prediction call would keep all the predictions in memory and return them when prediction is done
@@ -157,10 +160,24 @@ def main(cfg: ParallelTranscriptionConfig):
     if isinstance(model, EncDecHybridRNNTCTCModel) and cfg.decoder_type is not None:
         model.change_decoding_strategy(decoder_type=cfg.decoder_type)
 
-    trainer = ptl.Trainer(**cfg.trainer)
-    cfg.predict_ds.return_sample_id = True
     cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)
+
+    if isinstance(model, EncDecMultiTaskModel):
+        cfg.trainer.use_distributed_sampler = False
+        OmegaConf.set_struct(cfg.predict_ds, False)
+        cfg.predict_ds.use_lhotse = True
+        cfg.predict_ds.lang_field = "target_lang"
+        OmegaConf.set_struct(cfg.predict_ds, True)
+
+    trainer = ptl.Trainer(**cfg.trainer)
+
+    if isinstance(model, EncDecMultiTaskModel):
+        OmegaConf.set_struct(cfg.predict_ds, False)
+        cfg.predict_ds.global_rank = trainer.global_rank
+        cfg.predict_ds.world_size = trainer.world_size
+        OmegaConf.set_struct(cfg.predict_ds, True)
+
     data_loader = model._setup_dataloader_from_config(cfg.predict_ds)
 
     os.makedirs(cfg.output_path, exist_ok=True)
diff --git a/examples/nlp/language_modeling/megatron_gpt_distillation.py b/examples/nlp/language_modeling/megatron_gpt_distillation.py
index b3ecdcfc5522..dc8614be23b2 100644
--- a/examples/nlp/language_modeling/megatron_gpt_distillation.py
+++ b/examples/nlp/language_modeling/megatron_gpt_distillation.py
@@ -33,12 +33,12 @@
 from importlib.metadata import version
 from typing import Tuple
 
+import packaging
 import torch
 import torch.nn.functional as F
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.parallel_state import get_tensor_model_parallel_group
 from megatron.core.transformer import TransformerConfig
-from pkg_resources import packaging
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
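The distillation example above swaps the deprecated `from pkg_resources import packaging` import for the standalone `packaging` distribution (the same change recurs in several files later in this patch). A minimal sketch of the version-gating pattern this enables, using `torch` purely as an illustrative package name:

    from importlib.metadata import version

    from packaging.version import Version  # provided by the standalone `packaging` package

    # Compare an installed distribution's version against a threshold without pkg_resources.
    if Version(version("torch")) >= Version("2.0"):
        ...  # enable the newer code path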
diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py
index 7ad6560b4401..c63c73323797 100644
--- a/nemo/collections/asr/data/audio_to_text_dataset.py
+++ b/nemo/collections/asr/data/audio_to_text_dataset.py
@@ -859,16 +859,30 @@ def write_on_batch_end(
         batch_idx: int,
         dataloader_idx: int,
     ):
+        import lhotse
+
         for sample_id, transcribed_text in prediction:
             item = {}
-            sample = self.dataset.get_manifest_sample(sample_id)
-            item["audio_filepath"] = sample.audio_file
-            item["offset"] = sample.offset
-            item["duration"] = sample.duration
-            item["text"] = sample.text_raw
-            item["pred_text"] = transcribed_text
-            self.outf.write(json.dumps(item) + "\n")
-            self.samples_num += 1
+            if isinstance(sample_id, lhotse.cut.Cut):
+                sample = sample_id
+                if isinstance(sample, lhotse.cut.MixedCut):
+                    sample = sample.first_non_padding_cut
+                item["audio_filepath"] = sample.recording.sources[0].source
+                item["offset"] = sample.start
+                item["duration"] = sample.duration
+                item["text"] = sample.supervisions[0].text
+                item["pred_text"] = transcribed_text
+                self.outf.write(json.dumps(item) + "\n")
+                self.samples_num += 1
+            else:
+                sample = self.dataset.get_manifest_sample(sample_id)
+                item["audio_filepath"] = sample.audio_file
+                item["offset"] = sample.offset
+                item["duration"] = sample.duration
+                item["text"] = sample.text_raw
+                item["pred_text"] = transcribed_text
+                self.outf.write(json.dumps(item) + "\n")
+                self.samples_num += 1
         return
 
     def close_output_file(self):
diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
index f7c6b6adff7f..2dd4ae2980f1 100644
--- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
+++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Callable, Union
+from typing import Callable, Optional, Union
 
 import torch.utils.data
 from lhotse import CutSet
@@ -32,6 +32,7 @@ class PromptedAudioToTextMiniBatch:
     prompt_lens: torch.Tensor
     prompted_transcript: torch.Tensor
     prompted_transcript_lens: torch.Tensor
+    cuts: Optional[CutSet] = None
 
     def get_decoder_inputs_outputs(self) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -98,6 +99,7 @@ def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
             prompt_lens=prompt_lens,
             prompted_transcript=prompts_with_answers,
             prompted_transcript_lens=prompts_with_answers_lens,
+            cuts=cuts.drop_in_memory_data(),
         )
 
     def _collate_tokens(self, tokens: list[Union[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
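The dataset change above threads the lhotse CutSet through the mini-batch so downstream consumers can recover per-sample metadata. A minimal sketch of the idea with a hypothetical stand-in for the real dataclass (only the `cuts` field and `drop_in_memory_data()` come from the diff):

    from dataclasses import dataclass
    from typing import Optional

    import torch
    from lhotse import CutSet


    @dataclass
    class MiniBatchWithCuts:  # illustrative stand-in for PromptedAudioToTextMiniBatch
        audio: torch.Tensor
        audio_lens: torch.Tensor
        cuts: Optional[CutSet] = None


    def collate(cuts: CutSet, audio: torch.Tensor, audio_lens: torch.Tensor) -> MiniBatchWithCuts:
        # drop_in_memory_data() strips in-memory audio payloads but keeps paths, offsets and
        # supervisions, so the batch stays light while the metadata survives to predict_step.
        return MiniBatchWithCuts(audio=audio, audio_lens=audio_lens, cuts=cuts.drop_in_memory_data())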
" "Please set config.{train,validation,test}_ds.use_lhotse=True" ) + global_rank = config.get("global_rank", self.global_rank) + world_size = config.get("world_size", self.world_size) return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + global_rank=global_rank, + world_size=world_size, dataset=PromptedAudioToTextLhotseDataset( tokenizer=self.tokenizer, prompt_format_fn=get_prompt_format_fn(self.prompt_format), @@ -1042,8 +1044,7 @@ def predict_step( decoder_input_ids=batch.prompt, return_hypotheses=False, )[0] - - return text + return list(zip(batch.cuts, text)) @property def adapter_module_names(self) -> List[str]: diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 07bc4f3960d3..bb7598421c33 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -18,11 +18,11 @@ from typing import Any, Optional import numpy as np +import packaging import torch import torch.nn.functional as F from einops import rearrange, reduce, repeat from omegaconf import DictConfig, ListConfig, OmegaConf -from pkg_resources import packaging from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPVisionModel, SiglipVisionModel diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py index b04ff248a326..e9fb1833fc08 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py @@ -15,8 +15,8 @@ from importlib.metadata import version from typing import Any, Callable, Optional +import packaging import torch -from pkg_resources import packaging from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults from nemo.collections.nlp.parts import utils_funcs diff --git a/nemo/collections/nlp/modules/common/lm_utils.py b/nemo/collections/nlp/modules/common/lm_utils.py index a6a5fd57892e..af6fc9ecb0a7 100644 --- a/nemo/collections/nlp/modules/common/lm_utils.py +++ b/nemo/collections/nlp/modules/common/lm_utils.py @@ -102,6 +102,16 @@ def get_lm_model( pretrain_model_name = '' if cfg.get('language_model') and cfg.language_model.get('pretrained_model_name', ''): pretrain_model_name = cfg.language_model.get('pretrained_model_name', '') + + from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel + + def get_megatron_pretrained_bert_models() -> List[str]: + + all_pretrained_megatron_bert_models = [ + model.pretrained_model_name for model in MegatronBertModel.list_available_models() + ] + return all_pretrained_megatron_bert_models + all_pretrained_megatron_bert_models = get_megatron_pretrained_bert_models() if ( cfg.tokenizer is not None @@ -175,13 +185,13 @@ def get_transformer( config_dict={ '_target_': 'transformers.BertConfig', 'hidden_size': 1536 - }) + }) Args: library (str, optional): Can be 'nemo', 'huggingface', or 'megatron'. Defaults to 'nemo'. model_name (Optional[str], optional): Named model architecture from the chosen library. Defaults to None. - pretrained (bool, optional): Use True to get pretrained weights. + pretrained (bool, optional): Use True to get pretrained weights. 
diff --git a/nemo/collections/nlp/modules/common/lm_utils.py b/nemo/collections/nlp/modules/common/lm_utils.py
index a6a5fd57892e..af6fc9ecb0a7 100644
--- a/nemo/collections/nlp/modules/common/lm_utils.py
+++ b/nemo/collections/nlp/modules/common/lm_utils.py
@@ -102,6 +102,16 @@ def get_lm_model(
     pretrain_model_name = ''
     if cfg.get('language_model') and cfg.language_model.get('pretrained_model_name', ''):
         pretrain_model_name = cfg.language_model.get('pretrained_model_name', '')
+
+    from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
+
+    def get_megatron_pretrained_bert_models() -> List[str]:
+
+        all_pretrained_megatron_bert_models = [
+            model.pretrained_model_name for model in MegatronBertModel.list_available_models()
+        ]
+        return all_pretrained_megatron_bert_models
+
     all_pretrained_megatron_bert_models = get_megatron_pretrained_bert_models()
     if (
         cfg.tokenizer is not None
@@ -175,13 +185,13 @@ def get_transformer(
             config_dict={
                 '_target_': 'transformers.BertConfig',
                 'hidden_size': 1536
-            })
+        })
 
     Args:
         library (str, optional): Can be 'nemo', 'huggingface', or 'megatron'. Defaults to 'nemo'.
         model_name (Optional[str], optional): Named model architecture from the chosen library. Defaults to None.
-        pretrained (bool, optional): Use True to get pretrained weights.
+        pretrained (bool, optional): Use True to get pretrained weights.
            False will use the same architecture but with randomly initialized weights. Defaults to False.
        config_dict (Optional[dict], optional): Use for custom configuration of transformer. Defaults to None.
diff --git a/nemo/collections/nlp/modules/common/megatron/__init__.py b/nemo/collections/nlp/modules/common/megatron/__init__.py
index 4422a4e2bb78..2214c7f02059 100644
--- a/nemo/collections/nlp/modules/common/megatron/__init__.py
+++ b/nemo/collections/nlp/modules/common/megatron/__init__.py
@@ -12,7 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nemo.collections.nlp.modules.common.megatron.megatron_utils import (
-    get_megatron_checkpoint,
-    get_megatron_lm_models_list,
-)
+from .megatron_utils import get_megatron_checkpoint, get_megatron_lm_models_list
diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py
index a52607c01b7d..c1b4e3023e42 100644
--- a/nemo/collections/nlp/modules/common/megatron/attention.py
+++ b/nemo/collections/nlp/modules/common/megatron/attention.py
@@ -67,10 +67,9 @@
 try:
     # Flash Attention Triton
-    import pkg_resources
     from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
-except (ImportError, ModuleNotFoundError, pkg_resources.DistributionNotFound):
+except (ImportError, ModuleNotFoundError):
     flash_attn_func_triton = None
@@ -202,7 +201,12 @@ def __init__(
         else:
             assert attention_type == AttnType.cross_attn
             self.query = tensor_parallel.ColumnParallelLinear(
-                hidden_size, projection_size, config=config, gather_output=False, init_method=init_method, bias=bias,
+                hidden_size,
+                projection_size,
+                config=config,
+                gather_output=False,
+                init_method=init_method,
+                bias=bias,
             )
 
             self.key_value = tensor_parallel.ColumnParallelLinear(
@@ -336,7 +340,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
             """[s, b, num_splits * np * hn]
             -->(view) [s, b, num_splits, np, hn]
             -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
+            -->(view) [s, b, np * num_splits * hn]"""
 
             intermediate_shape = input_shape[:-1] + (
                 num_splits,
@@ -350,7 +354,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
             """[s, b, np * hn * num_splits]
             -->(view) [s, b, np, hn, num_splits]
             -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
+            -->(view) [s, b, np * num_splits * hn]"""
 
             intermediate_shape = input_shape[:-1] + (
                 self.num_attention_heads_per_partition,
@@ -535,7 +539,10 @@ def forward(
             )
             v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
             context_layer = flash_attn_with_kvcache(
-                q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
+                q=q,
+                k_cache=k,
+                v_cache=v,
+                causal=self.attn_mask_type == AttnMaskType.causal,
             )
             context_layer = rearrange(context_layer, 'b sq np hn -> sq b (np hn)')
 
@@ -742,9 +749,9 @@ def forward(
 
 class CoreAttention(MegatronModule):
-    """ Region where selective activation recomputation is applied.
-    See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
-    https://arxiv.org/pdf/2205.05198.pdf for more details.
+    """Region where selective activation recomputation is applied.
+    See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
+    https://arxiv.org/pdf/2205.05198.pdf for more details.
     """
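The first hunk of the attention.py diff above drops pkg_resources from the optional-import guard for the Triton flash-attention kernel. The trimmed pattern in isolation, as a minimal sketch (the fallback branch is illustrative):

    try:
        from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
    except (ImportError, ModuleNotFoundError):
        flash_attn_func_triton = None  # feature is simply disabled when the dependency is absent

    # Callers only need a None check; no pkg_resources.DistributionNotFound handling is required.
    if flash_attn_func_triton is None:
        ...  # fall back to the non-Triton attention path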
""" @@ -994,10 +1001,21 @@ def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, a if attention_bias is not None: return self.flash_attention_triton( - query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_bias, + is_causal, ) else: - return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask, is_causal,) + return self.flash_attention_cuda( + query_layer, + key_layer, + value_layer, + attention_mask, + is_causal, + ) def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask, is_causal): batch_size, seqlen, nheads, _ = query_layer.shape @@ -1071,7 +1089,13 @@ def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_ if attention_bias.shape[3] == attention_mask_kv.shape[3]: attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min) - context_layer = flash_attn_func_triton(query_layer, key_layer, value_layer, attention_bias, is_causal,) + context_layer = flash_attn_func_triton( + query_layer, + key_layer, + value_layer, + attention_bias, + is_causal, + ) # [b, sq, np, hn] -> [b, np, sq, hn] context_layer = context_layer.permute(0, 2, 1, 3) diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index 56496d56bc07..4e6f9e15b839 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os.path from dataclasses import MISSING, dataclass from typing import Dict, List, Optional from nemo.utils import logging +from .huggingface.huggingface_utils import get_huggingface_pretrained_lm_models_list + __all__ = ['get_tokenizer', 'get_tokenizer_list'] @@ -32,7 +33,7 @@ def get_tokenizer_list() -> List[str]: """ Returns all all supported tokenizer names """ - s = set(get_pretrained_lm_models_list()) + s = set(get_huggingface_pretrained_lm_models_list(include_external=False)) s.update(set(get_huggingface_pretrained_lm_models_list(include_external=True))) return ["sentencepiece", "char", "word"] + list(s) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 6aee365a3f60..486c9a4fe79c 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -81,11 +81,17 @@ def __init__( self.save_context_on_train_end = save_context_on_train_end self.save_optim_on_train_end = save_optim_on_train_end + ## stores the next -last checkpoint to be saved, used only when save_last = 'link' + ## this is needed because when using symlinks, we need to update the non-last checkpoint's + ## last_model_path to point to the corresponding -last version + self.future_last_model_path = "" + # Checkpoints which removal is deferred until async save is done. # Each element of `deferred_ckpts_to_remove` is a growing list # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` # is called, the last element is frozen and a new element is added. self.deferred_ckpts_to_remove: List[List[str]] = [] + self.ckpts_to_link: Dict[str, str] = {} # Call the parent class constructor with the remaining kwargs. 
         super().__init__(
@@ -240,6 +246,13 @@ def __is_ckpt_ok(ckpt_path: str) -> bool:
                 self.best_model_path = ""
                 self.best_model_score = None
 
+    def state_dict(self):
+        state = super().state_dict()
+        ## if using symlinks, overwrite last_model_path to avoid off-by-one issues
+        if self.save_last == "link":
+            state["last_model_path"] = self.future_last_model_path
+        return state
+
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         super().load_state_dict(state_dict)
         self._remove_invalid_entries_from_topk()
@@ -397,6 +410,25 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
 
         return monitor_candidates
 
+    def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str, override_async=False) -> None:
+
+        ## check to see whether this step has already been saved as top_k
+        ## in which case we can create a symlink
+        ## otherwise, we have to save the checkpoint
+        saved_current_step = str(ckpt_to_dir(linkpath)).replace("-last", "") == str(ckpt_to_dir(filepath))
+        if not saved_current_step:
+            self._save_checkpoint(trainer, linkpath)
+            return
+
+        ## linking will happen as part of the finalize fn
+        if self.async_save and not override_async:
+            self.ckpts_to_link[str(filepath)] = str(linkpath)
+            return
+
+        filepath = ckpt_to_dir(filepath)
+        linkpath = ckpt_to_dir(linkpath)
+        super()._link_checkpoint(trainer, filepath, linkpath)
+
     def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
         from nemo.utils.get_rank import is_global_rank_zero
 
@@ -408,6 +440,13 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
 
         self._last_global_step_saved = trainer.global_step
 
+        ## manually update last_model_path so symlink is up-to-date
+        ## should only be done when using a symlink
+        if self.save_last == "link":
+            self.future_last_model_path = str(ckpt_to_dir(filepath))
+            if not str(ckpt_to_dir(filepath)).endswith("last"):
+                self.future_last_model_path += "-last.ckpt"
+
         if ema_callback is not None:
             if self.async_save:
                 raise ValueError('async_save with EMA not supported')
@@ -452,6 +491,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
                 TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context")
 
             if self.async_save:
+                self._last_checkpoint_saved = filepath
                 logging.info(f'Scheduled async checkpoint save for {filepath}')
             else:
                 finalize_fn()
@@ -479,6 +519,9 @@ def _cb():
 
             logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.')
 
+            if str(filepath) in self.ckpts_to_link:
+                self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(filepath), override_async=True)
+
             # Remove checkpoints marked for removal by `self._remove_checkpoint`
             # For each finalization there is exactly one entry in self.deferred_ckpts_to_remove
             assert self.deferred_ckpts_to_remove
diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py
index 5f24d988396b..9ee1e84c7c62 100644
--- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py
+++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -205,8 +205,12 @@ def current_epoch_step(self) -> int:
     @override
     def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
         # Taken from MegatronStrategy
+        ckpt = ckpt_to_dir(filepath)
         if self.is_global_zero:
-            shutil.rmtree(ckpt_to_dir(filepath))
+            if os.path.islink(ckpt):
+                os.unlink(ckpt)
+            else:
+                shutil.rmtree(ckpt)
 
     @override
     def save_checkpoint(
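The callback and strategy changes above make save_last="link" materialize the "-last" checkpoint as a symlink whenever the same step has already been saved as a top-k checkpoint, and teach removal to handle both cases. A minimal sketch of what that amounts to on disk; the paths and helpers are illustrative, not the callback's actual code:

    import os
    import shutil


    def link_last(top_k_ckpt_dir: str) -> str:
        # Point "<ckpt>-last" at the already-saved top-k checkpoint directory.
        last_dir = top_k_ckpt_dir + "-last"
        if os.path.islink(last_dir):
            os.unlink(last_dir)
        os.symlink(top_k_ckpt_dir, last_dir)
        return last_dir


    def remove_ckpt(path: str) -> None:
        # Removal must distinguish the two cases, as the strategies now do:
        # unlink a symlinked -last directory, rmtree a real checkpoint directory.
        if os.path.islink(path):
            os.unlink(path)
        else:
            shutil.rmtree(path)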
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index 6c0d7c8f6b04..d1e2c7dbae57 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -723,8 +723,12 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any], selective_res
             _optimizer_to_device(optimizer, self.root_device)
 
     def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
+        ckpt = ckpt_to_dir(filepath)
         if self.is_global_zero:
-            shutil.rmtree(ckpt_to_dir(filepath))
+            if os.path.islink(ckpt):
+                os.unlink(ckpt)
+            else:
+                shutil.rmtree(ckpt)
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
         assert self.megatron_parallel is not None
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 0bd6208f11c7..eda898ea233f 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -5,7 +5,7 @@
 onnx>=1.7.0
 python-dateutil
 ruamel.yaml
 scikit-learn
-setuptools>=65.5.1
+setuptools>=70.0.0
 tensorboard
 text-unidecode
 torch
diff --git a/tests/collections/llm/megatron_t5_pretraining.py b/tests/collections/llm/megatron_t5_pretraining.py
index 407a1d3ab96e..67cb33e8a69f 100644
--- a/tests/collections/llm/megatron_t5_pretraining.py
+++ b/tests/collections/llm/megatron_t5_pretraining.py
@@ -77,6 +77,7 @@ def get_args():
     )
     checkpoint_callback = ModelCheckpoint(
         every_n_train_steps=5000,
+        save_optim_on_train_end=True,
     )
     callbacks = [checkpoint_callback]
 
diff --git a/tests/collections/nlp/test_flash_attention.py b/tests/collections/nlp/test_flash_attention.py
index 4bd740011b24..f5585ddc1636 100644
--- a/tests/collections/nlp/test_flash_attention.py
+++ b/tests/collections/nlp/test_flash_attention.py
@@ -39,7 +39,6 @@
 HAVE_FA = False
 
 try:
-    import pkg_resources
     import triton
 
     HAVE_TRITON = True
@@ -80,7 +79,13 @@ def setup_class(cls):
         MB_SIZE = 4
         GB_SIZE = 8
         SEED = 1234
-        trainer = Trainer(strategy=NLPDDPStrategy(), devices=GPUS, accelerator='gpu', num_nodes=1, logger=None,)
+        trainer = Trainer(
+            strategy=NLPDDPStrategy(),
+            devices=GPUS,
+            accelerator='gpu',
+            num_nodes=1,
+            logger=None,
+        )
 
         initialize_model_parallel_for_nemo(
             world_size=trainer.world_size,
diff --git a/tests/lightning/pytorch/callbacks/test_model_checkpoint.py b/tests/lightning/pytorch/callbacks/test_model_checkpoint.py
new file mode 100644
index 000000000000..7e047515b2b5
--- /dev/null
+++ b/tests/lightning/pytorch/callbacks/test_model_checkpoint.py
@@ -0,0 +1,284 @@
+import os
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterator, Optional, Sequence, Tuple
+
+import megatron
+import pytest
+import pytorch_lightning as pl
+import torch
+from megatron.core import ModelParallelConfig, parallel_state
+from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+from torch import Tensor
+
+import nemo.lightning as nl
+from nemo.lightning.io.mixin import IOMixin
+from nemo.lightning.megatron_parallel import DataT, MegatronLossReduction, ReductionT
+from nemo.lightning.pytorch.plugins import MegatronDataSampler
+
+
+### model environment related utilities
+def _reset_megatron_parallel_state():
+    """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initialized model parallel in
+    nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo
+    """  # noqa: D205, D415
+    megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
+    # Clean up any process groups created in testing
+    torch.cuda.empty_cache()
+    if parallel_state.is_initialized():
+        parallel_state.destroy_model_parallel()
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+
+
+@contextmanager
+def reset_megatron_parallel_state() -> Iterator[None]:
+    """Puts you into a clean parallel state, and again tears it down at the end."""
+    try:
+        _reset_megatron_parallel_state()
+        yield
+    finally:
+        _reset_megatron_parallel_state()
+
+
+class RandomDataset(pl.LightningDataModule):
+    def __init__(self, size, length):
+        super().__init__()
+        self.len = length
+        self.data = torch.randn(length, size)
+        self.data_sampler = MegatronDataSampler(
+            seq_len=size,
+            micro_batch_size=2,
+            global_batch_size=2,
+            rampup_batch_size=None,
+        )
+
+    def __getitem__(self, index):
+        return self.data[index]
+
+    def __len__(self):
+        return self.len
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        return torch.utils.data.DataLoader(self.data, batch_size=2)
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        return torch.utils.data.DataLoader(self.data, batch_size=2)
+
+
+class PassThroughLossReduction(MegatronLossReduction):
+    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""
+
+    def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT]:
+
+        return forward_out, forward_out
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor:
+        """Works across micro-batches. (data on single gpu).
+
+        Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+        Args:
+            losses_reduced_per_micro_batch: a list of the outputs of forward
+
+        Returns:
+            A tensor that is the mean of the losses. (used for logging).
+ """ + mse_losses = torch.stack([loss for loss in losses_reduced_per_micro_batch]) + return mse_losses.mean() + + +class ExampleModel(pl.LightningModule, IOMixin): + def __init__(self, *args, **kwargs): + super().__init__() + + ## keeps track of number of validation steps + self.count = torch.zeros((1,)) + + def configure_model(self): + + class NestedModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.l1 = torch.nn.modules.Linear(in_features=32, out_features=32) + self.bn = torch.nn.BatchNorm1d(32) + self.model_type = "test" + self.validation_step_outputs = [] + + class DummyConfig(ModelParallelConfig): + calculate_per_token_loss: bool = False + fp8: bool = False + + self.config = DummyConfig() + + self.module = NestedModel() + + def forward(self, batch): + return self.l1(self.bn(batch)).sum() + + def train_dataloader(self): + dataset = RandomDataset(32, 16) + return torch.utils.data.DataLoader(dataset, batch_size=2) + + def val_dataloader(self): + dataset = RandomDataset(32, 16) + return torch.utils.data.DataLoader(dataset, batch_size=2) + + def test_dataloader(self): + dataset = RandomDataset(32, 16) + dl = torch.utils.data.DataLoader(dataset, batch_size=2) + self._test_names = ['test_{}_'.format(idx) for idx in range(len(dl))] + return dl + + def training_step(self, batch): + return self(batch) + + def validation_step(self, batch): + ## use a dummy validation loss to ensure that loss is decreasing at each step + ## which guarantees that the -last checkpoints will be symlinks if specified + self.count += 1 + self.validation_step_outputs.append(-self.count) + return -self.count + + def test_step(self, batch): + loss = self(batch) + self.test_step_outputs.append(loss) + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=1e-3) + + def on_validation_epoch_end(self): + self.log("val_loss", torch.stack(self.validation_step_outputs).mean()) + self.validation_step_outputs.clear() # free memory + + def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None: + pass + + def training_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + # This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss + return PassThroughLossReduction() + + def validation_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + return PassThroughLossReduction() + + +def setup_test(path, async_save=False, max_epochs=3): + model = ExampleModel() + + data = RandomDataset(32, 64) + + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + ) + + nemo_logger = nl.NeMoLogger( + log_dir=path, + use_datetime_version=False, + ) + + strategy = nl.MegatronStrategy( + ckpt_async_save=async_save, + replace_progress_bar=False, + ) + + trainer = nl.Trainer( + max_epochs=max_epochs, + devices=1, + val_check_interval=6, + log_every_n_steps=4, + callbacks=nl.ModelCheckpoint( + monitor="val_loss", + save_top_k=3, + save_on_train_epoch_end=True, + save_context_on_train_end=False, + filename=f'{{step}}-{{epoch}}-{{val_loss}}-{{consumed_samples}}', + save_last="link", + ), + strategy=strategy, + ) + nemo_logger.setup(trainer) + resume.setup(trainer) + + return data, model, trainer + + +def get_final_checkpoint(checkpoint_dir): + dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()] + last_checkpoints = [d for d in dist_checkpoints if d.match("*last")] + + assert len(last_checkpoints) == 1 ## should only have one -last checkpoint + final_ckpt = 
+    final_ckpt = last_checkpoints[0]
+
+    top_k_checkpoints = [d for d in dist_checkpoints if d not in last_checkpoints]
+
+    return final_ckpt, top_k_checkpoints
+
+
+class TestLinkCheckpoint:
+
+    @pytest.mark.unit
+    @pytest.mark.run_only_on("GPU")
+    def test_link_ckpt(self, tmpdir):
+        """Test to ensure that we always keep top_k checkpoints, even after resuming."""
+
+        with reset_megatron_parallel_state():
+            tmp_path = tmpdir / "link_ckpt_test"
+            data, model, trainer = setup_test(tmp_path, async_save=False)
+
+            trainer.fit(model, data)
+
+            checkpoint_dir = Path(tmp_path / "default" / "checkpoints")
+            final_ckpt, top_k_checkpoints = get_final_checkpoint(checkpoint_dir)
+            assert os.path.islink(final_ckpt)
+
+            ## make sure we're saving the expected number of checkpoints
+            assert len(top_k_checkpoints) == 3
+
+            link = final_ckpt.resolve()
+            assert str(final_ckpt).replace("-last", "") == str(link)
+
+    @pytest.mark.unit
+    @pytest.mark.run_only_on("GPU")
+    def test_link_ckpt_async(self, tmpdir):
+        """Test to ensure that we always keep top_k checkpoints, even after resuming."""
+
+        with reset_megatron_parallel_state():
+            tmp_path = tmpdir / "async_link_ckpt_test"
+            data, model, trainer = setup_test(tmp_path, async_save=True)
+
+            trainer.fit(model, data)
+
+            checkpoint_dir = Path(tmp_path / "default" / "checkpoints")
+            final_ckpt, top_k_checkpoints = get_final_checkpoint(checkpoint_dir)
+            assert os.path.islink(final_ckpt)
+            assert len(top_k_checkpoints) == 3
+
+            link = final_ckpt.resolve()
+            assert str(final_ckpt).replace("-last", "") == str(link)
+
+    @pytest.mark.unit
+    @pytest.mark.run_only_on("GPU")
+    def test_restore_async(self, tmpdir):
+        """Test to ensure that we always keep top_k checkpoints, even after resuming."""
+
+        with reset_megatron_parallel_state():
+            tmp_path = tmpdir / "async_link_ckpt_test"
+            data, model, trainer = setup_test(tmp_path, async_save=True, max_epochs=3)
+
+            trainer.fit(model, data)
+
+            ## reinitialize
+            data, model, trainer = setup_test(tmp_path, async_save=True, max_epochs=6)
+
+            trainer.fit(model, data)
+
+            checkpoint_dir = Path(tmp_path / "default" / "checkpoints")
+            final_ckpt, top_k_checkpoints = get_final_checkpoint(checkpoint_dir)
+            assert os.path.islink(final_ckpt)
+            assert len(top_k_checkpoints) == 3
+
+            epoch = str(final_ckpt).split('epoch=')[1][0]
+            assert int(epoch) == 5  ## make sure we're running the correct number of epochs
diff --git a/tutorials/multimodal/Multimodal Data Preparation.ipynb b/tutorials/multimodal/Multimodal Data Preparation.ipynb
index 713f001937b8..6befc239c640 100644
--- a/tutorials/multimodal/Multimodal Data Preparation.ipynb
+++ b/tutorials/multimodal/Multimodal Data Preparation.ipynb
@@ -475,9 +475,7 @@
    "id": "9d6804d4",
    "metadata": {},
    "outputs": [],
-   "source": [
-    "! wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/vae/diffusion_pytorch_model.bin"
-   ]
+   "source": "! wget https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/vae/diffusion_pytorch_model.bin"
 },
 {
    "attachments": {},
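The notebook cell above now pulls the VAE weights from the CompVis/stable-diffusion-v1-4 repository. A minimal sketch of an equivalent download via huggingface_hub instead of wget, assuming that package is installed (this is an alternative, not what the notebook runs):

    from huggingface_hub import hf_hub_download

    # Downloads into the local HF cache and returns the resolved path to the file.
    vae_path = hf_hub_download(
        repo_id="CompVis/stable-diffusion-v1-4",
        filename="vae/diffusion_pytorch_model.bin",
    )
    print(vae_path)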