diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index e41e96cfa794..61f809312bef 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -2943,7 +2943,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_t5_pretraining.py \ + python examples/nlp/language_modeling/megatron_t5_pretraining.py \ trainer.devices=2 \ trainer.log_every_n_steps=1 \ trainer.max_epochs=null \ @@ -2975,7 +2975,7 @@ jobs: +model.data.data_impl_kwargs.workers=null \ +model.data.data_impl_kwargs.sort_dataset_paths=False - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_t5_pretraining.py \ + python examples/nlp/language_modeling/megatron_t5_pretraining.py \ trainer.devices=2 \ trainer.log_every_n_steps=1 \ trainer.max_epochs=null \ @@ -3398,8 +3398,8 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_t5_eval.py \ - --model_file /home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + python examples/nlp/language_modeling/megatron_t5_eval.py \ + --model_file /home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m_padding_attnmasktype.nemo \ --prompt "How do I fix my GPU memory issue? I am seeing out of memory." \ --tensor_model_parallel_size 1 @@ -3410,7 +3410,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/tuning/megatron_t5_finetuning.py \ + python examples/nlp/language_modeling/tuning/megatron_t5_finetuning.py \ trainer.devices=2 \ trainer.log_every_n_steps=1 \ trainer.max_epochs=9999 \ @@ -3421,7 +3421,7 @@ jobs: exp_manager.exp_dir=/tmp/nlp_mcore_t5_lora_tuning_tp2 \ model.pipeline_model_parallel_size=1 \ model.tensor_model_parallel_size=2 \ - model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m_padding_attnmasktype.nemo \ model.peft.peft_scheme=lora \ model.answer_only_loss=True \ model.micro_batch_size=1 \ @@ -3433,8 +3433,8 @@ jobs: model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ model.data.validation_ds.names=[quarel] - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/tuning/megatron_t5_generate.py \ - model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + python examples/nlp/language_modeling/tuning/megatron_t5_generate.py \ + model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m_padding_attnmasktype.nemo \ model.peft.restore_from_path=/tmp/nlp_mcore_t5_lora_tuning_tp2/megatron_t5_peft_lora_tuning/checkpoints/megatron_t5_peft_lora_tuning.nemo \ model.peft.restore_from_ckpt_name=null \ model.peft.restore_from_hparams_path=null \ @@ -3852,14 +3852,14 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \ + python tests/collections/llm/megatron_t5_pretraining.py \ --devices=2 \ --max-steps=3 \ --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \ --data-path=/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document \ --index-mapping-dir=tests/collections/llm/t5_index_mappings/${{ github.run_id }} - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \ + python tests/collections/llm/megatron_t5_pretraining.py \ --devices=2 \ --max-steps=6 \ --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \ @@ -3876,11 +3876,11 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_finetuning.py \ + python tests/collections/llm/megatron_t5_finetuning.py \ --devices=2 \ --max-steps=250 \ --experiment-dir=tests/collections/llm/t5_finetune_results/${{ github.run_id }} \ - --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_150steps + --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_padding_attnmasktype_150steps AFTER_SCRIPT: | rm -rf tests/collections/llm/t5_finetune_results/${{ github.run_id }} @@ -3891,12 +3891,12 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_finetuning.py \ + python tests/collections/llm/megatron_t5_finetuning.py \ --devices=2 \ --max-steps=250 \ --peft=lora \ --experiment-dir=tests/collections/llm/t5_peft_results/${{ github.run_id }} \ - --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_150steps + --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_padding_attnmasktype_150steps AFTER_SCRIPT: | rm -rf tests/collections/llm/t5_peft_results/${{ github.run_id }} diff --git a/.github/workflows/monitor-vms.yml b/.github/workflows/monitor-vms.yml index 6795f87abf68..0bb54524847a 100644 --- a/.github/workflows/monitor-vms.yml +++ b/.github/workflows/monitor-vms.yml @@ -27,7 +27,7 @@ jobs: | jq -c '[ .runners[] | select(.status == "online") - | select(.name | contains("gpu")) + | select(.name | contains("cpu") | not) | { "vm": .name, "n_gpus": [ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c11629776c40..03474251f995 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -23,7 +23,7 @@ on: jobs: release: - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.10.0 + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.12.3 with: release-ref: ${{ inputs.release-ref }} image-name: nemo_container @@ -39,3 +39,4 @@ jobs: TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} SLACK_RELEASE_ENDPOINT: ${{ secrets.SLACK_RELEASE_ENDPOINT }} + PAT: ${{ secrets.PAT }} diff --git a/.secrets.baseline b/.secrets.baseline index c26f70775c5a..1e4832b40075 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -90,6 +90,10 @@ { "path": "detect_secrets.filters.allowlist.is_line_allowlisted" }, + { + "path": "detect_secrets.filters.common.is_baseline_file", + "filename": ".secrets.baseline" + }, { "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies", "min_level": 2 @@ -273,7 +277,7 @@ "filename": "scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py", "hashed_secret": "e0308bd21bffc156d79208f9ecf130370a015002", "is_verified": false, - "line_number": 460 + "line_number": 471 } ], "scripts/dataset_processing/nlp/intent_and_slot/assistant_utils.py": [ @@ -2083,5 +2087,5 @@ } ] }, - "generated_at": "2024-10-25T13:43:17Z" + "generated_at": "2024-11-14T09:37:19Z" } diff --git a/Dockerfile.ci b/Dockerfile.ci index 5858f0aadf5b..e1b78547325a 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -54,7 +54,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.19.0 -ARG MCORE_TAG=aded519cfb1de2abf96f36ca059f992294b7876f +ARG MCORE_TAG=c1728c12f1f1cdbb786e52f1ffe512295d76bef3 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ diff --git a/docs/source/nlp/distillation.rst b/docs/source/nlp/distillation.rst deleted file mode 100644 index 22b2f3dd8a1c..000000000000 --- a/docs/source/nlp/distillation.rst +++ /dev/null @@ -1,58 +0,0 @@ -.. _megatron_distillation: - -Distillation -========================== - -Knowledge Distillation (KD) --------------------------------- - -KD involves using information from an existing trained model to train a second (usually smaller, faster) model, thereby "distilling" knowledge from one to the other. - -Distillation has two primary benefits: faster convergence and higher end accuracy than traditional training. - -In NeMo, distillation is enabled by the `NVIDIA TensorRT Model Optimizer (ModelOpt) `_ library -- a library to optimize deep-learning models for inference on GPUs. - -The logits-distillation process consists of the following steps: - -1. Loading both student and teacher model checkpoints (must support same parallelism strategy, if any) -2. Training until convergence, where forward passes are run on both models (and backward only on student), performing a specific loss function between the logits. -3. Saving the final student model. - - -Example -^^^^^^^ -The example below shows how to run the distillation script for LLama models. - -The script must be launched correctly with the number of processes equal to tensor parallelism. This is achieved with the ``torchrun`` command below: - -.. code-block:: bash - - STUDENT_CKPT="path/to/student.nemo" # can also be None (will use default architecture found in examples/nlp/language_modeling/conf/megatron_llama_distill.yaml) - TEACHER_CKPT="path/to/teacher.nemo" - TOKENIZER="path/to/tokenizer.model" - DATA_PATHS="[1.0,path/to/tokenized/data]" - FINAL_SAVE_FILE="final_checkpoint.nemo" - TP=4 - - NPROC=$TP - launch_config="torchrun --nproc_per_node=$NPROC" - - ${launch_config} examples/nlp/language_modeling/megatron_gpt_distillation.py \ - model.restore_from_path=$STUDENT_CKPT \ - model.kd_teacher_restore_from_path=$TEACHER_CKPT \ - model.tensor_model_parallel_size=$TP \ - model.tokenizer.model=$TOKENIZER \ - model.data.data_prefix=$DATA_PATHS \ - model.nemo_path=$FINAL_SAVE_FILE \ - trainer.precision=bf16 \ - trainer.devices=$NPROC - -For large models, the command can be used in multi-node setting. For example, this can be done with `NeMo Framework Launcher `_ using Slurm. - - -Limitations -^^^^^^^^^^^ -* Only Megatron Core-based GPT models are supported -* Only logit-pair distillation is supported for now -* Pipeline parallelism not yet supported -* FSDP strategy not yet supported diff --git a/docs/source/nlp/nemo_megatron/model_distillation/drop_layers.rst b/docs/source/nlp/nemo_megatron/model_distillation/drop_layers.rst deleted file mode 100644 index 3dc008945cc9..000000000000 --- a/docs/source/nlp/nemo_megatron/model_distillation/drop_layers.rst +++ /dev/null @@ -1,67 +0,0 @@ -.. _drop_layers: - -Drop Model Layers ------------------ - -To trim the model layers, use the following script: - -.. code-block:: bash - - python -m torch.distributed.launch --nproc_per_node= * \ - /NeMo/examples/nlp/language_modeling/megatron_gpt_drop_layers.py \ - --path_to_nemo /path/to/model.nemo \ - --path_to_save /path/to/save/trimmed_model.nemo \ - --tensor_model_parallel_size \ - --pipeline_model_parallel_size \ - --gpus_per_node \ - --drop_layers 1 2 3 4 - -**Note:** layer indices start from 1. - -To save trimmed model in ``zarr`` checkpoint format, add the following flag to the command above: - -.. code-block:: bash - - --zarr - -**Note:** the ``zarr`` checkpoint format is deprecated. - -Validate Trimmed Model ----------------------- - -To validate the trimmed model, use the following script: - -.. code-block:: bash - - python /NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - --config-path=/path/to/folder/with/model/config \ - --config-name=model_config.yaml \ - trainer.limit_val_batches= \ - model.restore_from_path=/path/to/trimmed_model.nemo \ - model.skip_train=True \ - model.data.data_impl=mock \ - model.data.data_prefix=[] - -To use a specific dataset instead of a mock dataset, modify the ``model.data`` parameters as follows: - -.. code-block:: bash - - model.data.data_impl=mmap \ - model.data.data_prefix=["path/to/datafile1", "path/to/datafile2"] - -Validate Original Model ------------------------ - -To validate the original model without specific layers, use the following script: - -.. code-block:: bash - - python /NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - --config-path=/path/to/folder/with/model/config \ - --config-name=model_config.yaml \ - trainer.limit_val_batches= \ - model.restore_from_path=/path/to/original_model.nemo \ - model.skip_train=True \ - model.data.data_impl=mock \ - model.data.data_prefix=[] \ - model.drop_layers=[1,2,3,4] diff --git a/docs/source/nlp/punctuation_and_capitalization.rst b/docs/source/nlp/punctuation_and_capitalization.rst index 4be0d2151d8e..d67332eb00c1 100755 --- a/docs/source/nlp/punctuation_and_capitalization.rst +++ b/docs/source/nlp/punctuation_and_capitalization.rst @@ -240,7 +240,7 @@ An example of a config file is - trainer config - - Parameters of - `pytorch_lightning.Trainer `_. + `lightning.pytorch.Trainer `_. * - **exp_manager** - exp manager config - diff --git a/docs/source/starthere/fundamentals.rst b/docs/source/starthere/fundamentals.rst index e3014e0f5a03..f486bf3d6e49 100644 --- a/docs/source/starthere/fundamentals.rst +++ b/docs/source/starthere/fundamentals.rst @@ -116,7 +116,7 @@ Below is an example training script for our ``ExampleEncDecModel`` model. We hig :linenos: :emphasize-lines: 10, 11, 12 - import pytorch_lightning as pl + import lightning.pytorch as pl from nemo.collections.path_to_model_class import ExampleEncDecModel from nemo.core.config import hydra_runner diff --git a/examples/asr/asr_adapters/eval_asr_adapter.py b/examples/asr/asr_adapters/eval_asr_adapter.py index bc5947f26aaf..b35cf33a6c0e 100644 --- a/examples/asr/asr_adapters/eval_asr_adapter.py +++ b/examples/asr/asr_adapters/eval_asr_adapter.py @@ -36,7 +36,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf, open_dict from nemo.collections.asr.models import ASRModel diff --git a/examples/asr/asr_adapters/train_asr_adapter.py b/examples/asr/asr_adapters/train_asr_adapter.py index 3f82ef8fe554..253672e3eb89 100644 --- a/examples/asr/asr_adapters/train_asr_adapter.py +++ b/examples/asr/asr_adapters/train_asr_adapter.py @@ -84,7 +84,7 @@ import os from dataclasses import is_dataclass -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf, open_dict from nemo.collections.asr.models import ASRModel diff --git a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py index 8188bcced14d..1e63a9d820be 100644 --- a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py +++ b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py @@ -49,7 +49,7 @@ from dataclasses import dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py index 87370d278f98..ccea94f41f83 100644 --- a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py +++ b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py @@ -42,7 +42,7 @@ from dataclasses import dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py index e6e84cdfa6c4..c31fa2b9d812 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py @@ -64,7 +64,7 @@ from dataclasses import dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf, open_dict diff --git a/examples/asr/asr_ctc/speech_to_text_ctc.py b/examples/asr/asr_ctc/speech_to_text_ctc.py index 87b1b11633f7..ccdf3a5e09ea 100644 --- a/examples/asr/asr_ctc/speech_to_text_ctc.py +++ b/examples/asr/asr_ctc/speech_to_text_ctc.py @@ -68,7 +68,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecCTCModel diff --git a/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py b/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py index b4e3be5f650a..997cd6e52d5b 100644 --- a/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py +++ b/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py @@ -64,7 +64,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py index 796005a8fcee..ffda4c554a49 100644 --- a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py @@ -58,7 +58,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py index 423e005d8f02..02f43f93e2c7 100644 --- a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py @@ -69,7 +69,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecHybridRNNTCTCModel diff --git a/examples/asr/asr_transducer/speech_to_text_rnnt.py b/examples/asr/asr_transducer/speech_to_text_rnnt.py index 5b4f1e8a985d..2fab3ac137e6 100644 --- a/examples/asr/asr_transducer/speech_to_text_rnnt.py +++ b/examples/asr/asr_transducer/speech_to_text_rnnt.py @@ -67,7 +67,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecRNNTModel diff --git a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py index 1fffea55686f..d18313acc9a6 100644 --- a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py +++ b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py @@ -59,7 +59,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecRNNTBPEModel diff --git a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py index b435d418fda2..acd7a8632822 100644 --- a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py +++ b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py @@ -49,7 +49,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel diff --git a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py index 99bc41ba966b..c1692cf6234f 100644 --- a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py +++ b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py @@ -45,7 +45,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel diff --git a/examples/asr/conf/asr_adapters/asr_adaptation.yaml b/examples/asr/conf/asr_adapters/asr_adaptation.yaml index b9a2a003217e..bae166d18782 100644 --- a/examples/asr/conf/asr_adapters/asr_adaptation.yaml +++ b/examples/asr/conf/asr_adapters/asr_adaptation.yaml @@ -182,7 +182,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: null diff --git a/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml index 958e6d23375c..d03b2eacfec4 100644 --- a/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml +++ b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml @@ -182,7 +182,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: null diff --git a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml index 3b5717efddf9..1ae64a341e16 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml @@ -81,7 +81,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml index f111573f21eb..c044d3c8d7a8 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml @@ -145,7 +145,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml index 4c80d2f2e9d4..564f4b176e64 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml @@ -172,7 +172,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml index 0796a60260a1..6962c03ebe60 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml @@ -177,7 +177,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml index 4edcc38396fa..1531bf380b6d 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml @@ -228,7 +228,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml index 97b64ef93402..4cb508b0aff3 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml @@ -234,7 +234,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml index ea6094380856..fd5f34aa43cb 100644 --- a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml @@ -198,7 +198,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml index 9e2c1a876864..deb7b7ca613a 100644 --- a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml +++ b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml @@ -251,7 +251,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml index daef1ed67a9f..6d89a6a52dfb 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml @@ -245,7 +245,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml index 96aee4af1803..7e6b9c4aa7b4 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml @@ -250,7 +250,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml index 4ba55e368bb9..12a21c6fba6c 100644 --- a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml @@ -224,7 +224,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml index ed2ad8ca9c0d..65f657b5416e 100644 --- a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml +++ b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml @@ -229,7 +229,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml index 773a500ef2db..df511883ce80 100644 --- a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml @@ -169,7 +169,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 @@ -204,4 +204,4 @@ exp_manager: create_wandb_logger: false wandb_logger_kwargs: name: null - project: null \ No newline at end of file + project: null diff --git a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml index fec2a2839efa..0218136cbdbd 100644 --- a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml +++ b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml @@ -223,7 +223,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml index 3d1a8c8bdf47..50446dfd9467 100644 --- a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml +++ b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml @@ -249,7 +249,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/experimental/k2/align_speech_parallel.py b/examples/asr/experimental/k2/align_speech_parallel.py index abfffa0cdfdb..cf07fb998e95 100644 --- a/examples/asr/experimental/k2/align_speech_parallel.py +++ b/examples/asr/experimental/k2/align_speech_parallel.py @@ -77,7 +77,7 @@ from dataclasses import dataclass, field, is_dataclass from typing import Optional -import pytorch_lightning as ptl +import lightning.pytorch as ptl import torch from omegaconf import MISSING, OmegaConf diff --git a/examples/asr/experimental/k2/speech_to_text_bpe.py b/examples/asr/experimental/k2/speech_to_text_bpe.py index ee3924c7b8ac..8a941200770f 100644 --- a/examples/asr/experimental/k2/speech_to_text_bpe.py +++ b/examples/asr/experimental/k2/speech_to_text_bpe.py @@ -74,7 +74,7 @@ model.graph_module_cfg.background_cfg.intersect_pruned=False \ model.graph_module_cfg.background_cfg.boost_coeff=0.0 """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.configs.k2_sequence_models_config import EncDecK2SeqModelConfig diff --git a/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py b/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py index a0031fba082d..973be0cbd477 100644 --- a/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py +++ b/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py @@ -63,7 +63,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecK2RnntSeqModelBPE diff --git a/examples/asr/experimental/structured/speech_to_text_hybrid.py b/examples/asr/experimental/structured/speech_to_text_hybrid.py index 26530631498f..e6126c47305f 100644 --- a/examples/asr/experimental/structured/speech_to_text_hybrid.py +++ b/examples/asr/experimental/structured/speech_to_text_hybrid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.asr.models import EncDecCTCModel, configs from nemo.core.config import hydra_runner diff --git a/examples/asr/experimental/structured/speech_to_text_structured.py b/examples/asr/experimental/structured/speech_to_text_structured.py index 366c6d831a7d..55934c00322e 100644 --- a/examples/asr/experimental/structured/speech_to_text_structured.py +++ b/examples/asr/experimental/structured/speech_to_text_structured.py @@ -14,7 +14,7 @@ from dataclasses import asdict -import pytorch_lightning as pl +import lightning.pytorch as pl import nemo.collections.asr as nemo_asr from nemo.collections.asr.models import EncDecCTCModel, configs @@ -64,7 +64,13 @@ ), # ... repeat 14 more times nemo_asr.modules.conv_asr.JasperEncoderConfig( - filters=1024, repeat=1, kernel=[1], stride=[1], dilation=[1], dropout=cfg.model.dropout, residual=False, + filters=1024, + repeat=1, + kernel=[1], + stride=[1], + dilation=[1], + dropout=cfg.model.dropout, + residual=False, ), ] diff --git a/examples/asr/experimental/structured/speech_to_text_structured_v2.py b/examples/asr/experimental/structured/speech_to_text_structured_v2.py index e8a865a9877a..146da425fb9b 100644 --- a/examples/asr/experimental/structured/speech_to_text_structured_v2.py +++ b/examples/asr/experimental/structured/speech_to_text_structured_v2.py @@ -14,7 +14,7 @@ from dataclasses import asdict -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.asr.models import EncDecCTCModel, configs from nemo.core.config import modelPT, optimizers, schedulers diff --git a/examples/asr/speech_classification/speech_to_frame_label.py b/examples/asr/speech_classification/speech_to_frame_label.py index 04fcbdd1b61c..39a8e4415de5 100644 --- a/examples/asr/speech_classification/speech_to_frame_label.py +++ b/examples/asr/speech_classification/speech_to_frame_label.py @@ -39,7 +39,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.classification_models import EncDecFrameClassificationModel diff --git a/examples/asr/speech_classification/speech_to_label.py b/examples/asr/speech_classification/speech_to_label.py index b3deb5a4e7e5..810d2b5e7bdf 100644 --- a/examples/asr/speech_classification/speech_to_label.py +++ b/examples/asr/speech_classification/speech_to_label.py @@ -143,7 +143,7 @@ https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speech_classification/results.html# """ -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/speech_multitask/speech_to_text_aed.py b/examples/asr/speech_multitask/speech_to_text_aed.py index 0c13e5289d86..943ecee59bfc 100644 --- a/examples/asr/speech_multitask/speech_to_text_aed.py +++ b/examples/asr/speech_multitask/speech_to_text_aed.py @@ -50,7 +50,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecMultiTaskModel diff --git a/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py b/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py index 3a256c7ab2d3..8bd56aa63450 100644 --- a/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py +++ b/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py @@ -14,7 +14,7 @@ from collections import OrderedDict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py index 1ea88d696643..e1c740e66412 100644 --- a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py +++ b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.ssl_models import EncDecDenoiseMaskedTokenPredModel diff --git a/examples/asr/speech_pretraining/speech_pre_training.py b/examples/asr/speech_pretraining/speech_pre_training.py index cec9444096c3..0c94099442a6 100644 --- a/examples/asr/speech_pretraining/speech_pre_training.py +++ b/examples/asr/speech_pretraining/speech_pre_training.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel diff --git a/examples/asr/speech_to_text_finetune.py b/examples/asr/speech_to_text_finetune.py index 36a7bdc3bbdc..6b53446622ee 100644 --- a/examples/asr/speech_to_text_finetune.py +++ b/examples/asr/speech_to_text_finetune.py @@ -54,7 +54,7 @@ https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations """ import time -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import ASRModel diff --git a/examples/asr/speech_translation/speech_to_text_transformer.py b/examples/asr/speech_translation/speech_to_text_transformer.py index ac4dc4334164..bb7e0b3e4461 100644 --- a/examples/asr/speech_translation/speech_to_text_transformer.py +++ b/examples/asr/speech_translation/speech_to_text_transformer.py @@ -40,7 +40,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecTransfModelBPE diff --git a/examples/asr/speech_translation/translate_speech.py b/examples/asr/speech_translation/translate_speech.py index 53599e1b3511..76c8c096527f 100644 --- a/examples/asr/speech_translation/translate_speech.py +++ b/examples/asr/speech_translation/translate_speech.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, is_dataclass from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index a543fcf5e252..f1d61edc990e 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field, is_dataclass from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf, open_dict diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index eb905d3e91b0..bdf54ea67f7d 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -75,7 +75,7 @@ from dataclasses import dataclass, is_dataclass from typing import Optional -import pytorch_lightning as ptl +import lightning.pytorch as ptl import torch from omegaconf import MISSING, OmegaConf diff --git a/examples/audio/audio_to_audio_train.py b/examples/audio/audio_to_audio_train.py index cef46dcf20b6..4d71e75176c9 100644 --- a/examples/audio/audio_to_audio_train.py +++ b/examples/audio/audio_to_audio_train.py @@ -28,7 +28,7 @@ """ from enum import Enum -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/audio/process_audio.py b/examples/audio/process_audio.py index ec88bda34954..8657d53ef957 100644 --- a/examples/audio/process_audio.py +++ b/examples/audio/process_audio.py @@ -20,7 +20,7 @@ from pathlib import Path from typing import List, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py index 97f21d6c253e..5b24c22ab79d 100644 --- a/examples/llm/peft/hf.py +++ b/examples/llm/peft/hf.py @@ -13,7 +13,7 @@ # limitations under the License. import fiddle as fdl -from pytorch_lightning.loggers import WandbLogger +from lightning.pytorch.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py index 7d4cde7866a2..20f915f87287 100644 --- a/examples/llm/sft/hf.py +++ b/examples/llm/sft/hf.py @@ -13,8 +13,8 @@ # limitations under the License. import fiddle as fdl -import pytorch_lightning as pl -from pytorch_lightning.loggers import WandbLogger +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger from torch.utils.data import DataLoader from nemo import lightning as nl diff --git a/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py index d02b737c750a..874d62dc63c9 100644 --- a/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py +++ b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py @@ -34,10 +34,10 @@ from collections import OrderedDict import torch +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from llava import LlavaLlamaForCausalLM from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from safetensors import safe_open from transformers import LlamaTokenizer diff --git a/examples/multimodal/speech_llm/export/extract_salm_weights.py b/examples/multimodal/speech_llm/export/extract_salm_weights.py index 0698a411110e..24c7aec3bb4d 100644 --- a/examples/multimodal/speech_llm/export/extract_salm_weights.py +++ b/examples/multimodal/speech_llm/export/extract_salm_weights.py @@ -18,9 +18,9 @@ import tempfile import torch +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import dist_checkpointing from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/multimodal/text_to_image/controlnet/controlnet_train.py b/examples/multimodal/text_to_image/controlnet/controlnet_train.py index 2bb8b66cac1a..14e7e62a1cc7 100644 --- a/examples/multimodal/text_to_image/controlnet/controlnet_train.py +++ b/examples/multimodal/text_to_image/controlnet/controlnet_train.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.multimodal.models.text_to_image.controlnet.controlnet import MegatronControlNet from nemo.collections.multimodal.models.text_to_image.controlnet.util import ImageLogger diff --git a/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py b/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py index cebf159eb870..c50ad439eaec 100644 --- a/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py +++ b/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py @@ -27,10 +27,10 @@ from argparse import ArgumentParser import torch -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.multimodal.models.text_to_image.controlnet.controlnet import MegatronControlNet from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import MegatronDiffusionEngine diff --git a/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py b/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py index 52f0aa2940d2..e1d050f83939 100644 --- a/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py +++ b/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline @@ -48,7 +48,10 @@ def model_cfg_modifier(model_cfg): plugins = [] plugins.append(TorchElasticEnvironment()) - strategy = NLPDDPStrategy(no_ddp_communication_hook=True, find_unused_parameters=False,) + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) model = MegatronLatentDiffusion(model_cfg, trainer=trainer) diff --git a/examples/multimodal/text_to_image/imagen/generate_fid_images.py b/examples/multimodal/text_to_image/imagen/generate_fid_images.py index ea743e3e1d06..7d2df372b545 100644 --- a/examples/multimodal/text_to_image/imagen/generate_fid_images.py +++ b/examples/multimodal/text_to_image/imagen/generate_fid_images.py @@ -15,7 +15,7 @@ import os import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ImagenPipeline from nemo.core.config import hydra_runner @@ -79,7 +79,10 @@ def main(cfg): seeds = [local_task_id * chunk_size + batch_idx * batch_size + idx for idx in range(len(batch_captions))] with torch.no_grad(): images, all_res_images, *_ = pipeline( - prompts=batch_captions, seed=seeds, single_batch_mode=True, classifier_free_guidance=current_node_cfg, + prompts=batch_captions, + seed=seeds, + single_batch_mode=True, + classifier_free_guidance=current_node_cfg, ) if cfg.fid.save_all_res: diff --git a/examples/multimodal/text_to_image/imagen/imagen_generate_images.py b/examples/multimodal/text_to_image/imagen/imagen_generate_images.py index bc002052a989..06b324367a52 100644 --- a/examples/multimodal/text_to_image/imagen/imagen_generate_images.py +++ b/examples/multimodal/text_to_image/imagen/imagen_generate_images.py @@ -16,8 +16,8 @@ import pickle import torch +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ( ImagenPipeline, @@ -65,7 +65,11 @@ def main(inference_config): seed = batch_idx + chuncksize with torch.no_grad(): - images, all_res_images, throughput = pipeline(prompts=batch_captions, seed=seeds, single_batch_mode=True,) + images, all_res_images, throughput = pipeline( + prompts=batch_captions, + seed=seeds, + single_batch_mode=True, + ) for outpath, one_res in zip(outpaths, all_res_images): for idx, (caption, image) in enumerate(zip(batch_captions, one_res[0])): diff --git a/examples/multimodal/text_to_image/imagen/imagen_infer.py b/examples/multimodal/text_to_image/imagen/imagen_infer.py index 0fb291729596..9ce680cf4b09 100644 --- a/examples/multimodal/text_to_image/imagen/imagen_infer.py +++ b/examples/multimodal/text_to_image/imagen/imagen_infer.py @@ -14,8 +14,8 @@ import os +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ( ImagenPipeline, diff --git a/examples/multimodal/text_to_image/imagen/imagen_training.py b/examples/multimodal/text_to_image/imagen/imagen_training.py index 23c1c9c1a1d7..211299156b69 100644 --- a/examples/multimodal/text_to_image/imagen/imagen_training.py +++ b/examples/multimodal/text_to_image/imagen/imagen_training.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf.omegaconf import OmegaConf, open_dict from torch._dynamo import disable diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py index 0877d4eb4b2f..0d83a8daab9f 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline @@ -45,7 +45,10 @@ def model_cfg_modifier(model_cfg): plugins = [] plugins.append(TorchElasticEnvironment()) - strategy = NLPDDPStrategy(no_ddp_communication_hook=True, find_unused_parameters=False,) + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) model = MegatronLatentDiffusion(model_cfg, trainer=trainer) diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py index 44412aee0d14..4ef22b69aa64 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py @@ -13,10 +13,11 @@ # limitations under the License. import sys + import torch import torch._dynamo.config as dynamo_config +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import MegatronDiffusionEngine from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder diff --git a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py index 178140aac828..abc987e07097 100644 --- a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py +++ b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py @@ -45,9 +45,9 @@ import einops import open_clip import torch +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPModel from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel diff --git a/examples/multimodal/x_to_nerf/benchmark_callback.py b/examples/multimodal/x_to_nerf/benchmark_callback.py index fd7d5afdc5bc..2db78d1d385a 100644 --- a/examples/multimodal/x_to_nerf/benchmark_callback.py +++ b/examples/multimodal/x_to_nerf/benchmark_callback.py @@ -15,7 +15,7 @@ import time from typing import Optional -from pytorch_lightning import Callback, LightningModule, Trainer +from lightning.pytorch import Callback, LightningModule, Trainer from nemo.utils import logging diff --git a/examples/multimodal/x_to_nerf/data.py b/examples/multimodal/x_to_nerf/data.py index fe7c47abc64b..b8dfd3aa536b 100644 --- a/examples/multimodal/x_to_nerf/data.py +++ b/examples/multimodal/x_to_nerf/data.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf.omegaconf import DictConfig from torch.utils.data import DataLoader diff --git a/examples/multimodal/x_to_nerf/main.py b/examples/multimodal/x_to_nerf/main.py index 5d7f616a3165..f3c8a6949867 100644 --- a/examples/multimodal/x_to_nerf/main.py +++ b/examples/multimodal/x_to_nerf/main.py @@ -13,8 +13,8 @@ # limitations under the License. from hydra.utils import get_class, instantiate +from lightning.pytorch import Trainer, seed_everything from omegaconf.omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer, seed_everything from nemo.core.config import hydra_runner from nemo.utils import logging diff --git a/examples/multimodal_autoregressive/README.md b/examples/multimodal_autoregressive/README.md new file mode 100644 index 000000000000..5934074a7d17 --- /dev/null +++ b/examples/multimodal_autoregressive/README.md @@ -0,0 +1,3 @@ +### MULTIMODAL AUTOREGRESSIVE GENERTION + +For information on how to get started with autoregressive generation for multimodal datasets using discrete tokenizers follow this [guide](nemo/collections/multimodal_autoregressive/data/README.md) diff --git a/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml new file mode 100644 index 000000000000..806800c96155 --- /dev/null +++ b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml @@ -0,0 +1,36 @@ +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + add_BOS: True # add the bos token at the begining of the prompt + tokens_to_generate: 30 # The minimum length of the sequence to be generated. + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: ["<|extra_204|>"] # generation will stop when one of these tokens is generated + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: bf16 # 16, 32, or bf16 + use_distributed_sampler: False + +tensor_model_parallel_size: -1 +pipeline_model_parallel_size: -1 +pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory +image_encoder: Cosmos-Tokenizer-DV8x16x16 +gpt_model_file: null # GPT nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading +captions: # prompts for GPT inference + - "a drawing of a green pokemon with red eyes" + - "a red pokemon with green eyes" + - "a cartoon fish with a big smile" +images_output_path: null # Path to the directory to store the output images + diff --git a/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml new file mode 100644 index 000000000000..c392f5dcc5c2 --- /dev/null +++ b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml @@ -0,0 +1,32 @@ +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + add_BOS: False # add the bos token at the begining of the prompt + tokens_to_generate: 30 # The minimum length of the sequence to be generated. + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: ["<|extra_204|>"] # generation will stop when one of these tokens is generated + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: bf16 # 16, 32, or bf16 + use_distributed_sampler: False + +tensor_model_parallel_size: -1 +pipeline_model_parallel_size: -1 +pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory +gpt_model_file: null # GPT nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading +images_path: # prompts for GPT inference + - "/path/to/image1" + - "/path/to/image2" diff --git a/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py new file mode 100644 index 000000000000..ae8dddb29553 --- /dev/null +++ b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py @@ -0,0 +1,196 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import math +import os +import re + +import torch +import torchvision +from examples.nlp.language_modeling.megatron_gpt_eval import ( + load_model_from_config, + remove_padded_prompts, + round_to_mult, +) +from pytorch_lightning.trainer.trainer import Trainer + +# pylint: disable=line-too-long +from nemo.collections.common.video_tokenizers.cosmos_tokenizer import CausalVideoTokenizer +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy +from nemo.core.config import hydra_runner + +""" +This is the script to run multimodal autoregresssive text generation. + +Make sure you install tiktoken==0.6.0 + +Usage: + Assume the model has TP=1, PP=1 in the following use cases. + a. run greedy inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] + + b. run greedy inference from a PTL checkpoint file: + python megatron_mm_autoregresssive_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT_FILE \ + checkpoint_name=CHECKPOINT_FILE_NAME \ + hparams_file=HPARAMS_FILE \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] + + c. run top_p inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=False \ + inference.top_k=0 \ + inference.top_p=0.9 \ + inference.repetition_penalty=1.2 \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] + + d. If you don't need to generate tokens and need model to compute logprobs: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.compute_logprob=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] +""" + + +def to_img(tokens_string, image_tokenizer): + """Converts visual tokens to images + + Given input visual tokens, we extract the indices, pass it to the decoder to get the image + """ + visual_token_pattern = r"<\|visual token (\d+)\|>" + visual_tokens = [int(match) for match in re.findall(visual_token_pattern, tokens_string)] + # We assume image is square. So if 64 tokensa are present, we reshape it to 8x8 and then pass it to decoder + dim = int(math.sqrt(len(visual_tokens))) + visual_tokens_tensor = torch.tensor(visual_tokens[: dim * dim]) + # Decoder accepts input of the following format [bs, channel_dim, h, w] + visual_tokens_tensor_reshaped = visual_tokens_tensor.reshape((dim, dim)).unsqueeze(0).unsqueeze(0) + visual_tokens_final = visual_tokens_tensor_reshaped.to(image_tokenizer._device) + img = image_tokenizer.decode(visual_tokens_final) + + # Convert from bf16 to 16 and to format [channel_dim, h, w] + image = torchvision.transforms.functional.to_pil_image(img.float().squeeze()) + return image + + +def load_prompts(cfg): + """Function to return the prompts passed into the model""" + prompts = [] + for caption in cfg.captions: + prompt = f'You are a helpful assistant. Draw a picture for the caption given by the user. USER: {caption}. ASSISTANT: ' + prompts.append(prompt) + return prompts + + +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") + + +@hydra_runner(config_path="conf", config_name="megatron_mm_ar_inference_image_generation") +def main(cfg) -> None: + """Main function""" + + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + + image_tokenizer = CausalVideoTokenizer.from_pretrained( + tokenizer_type=cfg.image_encoder, load_encoder=False, load_decoder=True, load_full_model=False + ) + + model = load_model_from_config(trainer, cfg) + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + length_params: LengthParam = { + "max_length": cfg.inference.tokens_to_generate, + "min_length": cfg.inference.min_tokens_to_generate, + } + + sampling_params: SamplingParam = { + "use_greedy": cfg.inference.greedy, + "temperature": cfg.inference.temperature, + "top_k": cfg.inference.top_k, + "top_p": cfg.inference.top_p, + "repetition_penalty": cfg.inference.repetition_penalty, + "add_BOS": cfg.inference.add_BOS, + "all_probs": cfg.inference.all_probs, + "compute_logprob": cfg.inference.compute_logprob, + "end_strings": cfg.inference.end_strings, + } + + prompts = [] + with torch.no_grad(): + prompts = load_prompts(cfg) + + fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True) + if fp8_enabled and len(prompts) > 0: + padded_len = round_to_mult(len(prompts), 8) + nb_paddings = padded_len - len(prompts) + if nb_paddings > 0: + nb_paddings += [''] * nb_paddings + + # First method of running text generation, call model.generate method + response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params) + + if fp8_enabled: + response = remove_padded_prompts(response, nb_paddings) + + output_tokens_strings = response['sentences'] + for idx, output_token_string in enumerate(output_tokens_strings): + image = to_img(output_token_string, image_tokenizer) + image.save(os.path.join(cfg.images_output_path, f'{idx}.jpg')) + + print(f'Images saved to {cfg.images_output_path}') + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py new file mode 100644 index 000000000000..4aea4d9898ae --- /dev/null +++ b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime + +import torch +import torchvision +from examples.nlp.language_modeling.megatron_gpt_eval import ( + RequestDataSet, + load_model_from_config, + remove_padded_prompts, + round_to_mult, +) +from omegaconf import OmegaConf +from PIL import Image +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader +from transformers import AutoModel, AutoTokenizer + +# pylint: disable=line-too-long +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy +from nemo.core.config import hydra_runner + +""" +This is the script to run multimodal autoregresssive text generation. + +Make sure you install tiktoken==0.6.0 + +Usage: + Assume the model has TP=1, PP=1 in the following use cases. + a. run greedy inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] + + b. run greedy inference from a PTL checkpoint file: + python megatron_mm_autoregresssive_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT_FILE \ + checkpoint_name=CHECKPOINT_FILE_NAME \ + hparams_file=HPARAMS_FILE \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] + + c. run top_p inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=False \ + inference.top_k=0 \ + inference.top_p=0.9 \ + inference.repetition_penalty=1.2 \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] + + d. If you don't need to generate tokens and need model to compute logprobs: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.compute_logprob=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] +""" + +EMU_HUB = "BAAI/Emu3-Gen" +VQ_HUB = "BAAI/Emu3-VisionTokenizer" + + +def to_imgstr(image_tokens, tokenizer): + """Convert integer image tokens to visual tokens string""" + image_tokens = image_tokens.cpu().numpy().tolist() + image_token_str = [ + ['<|visual token {token_id:0>6d}|>'.format(token_id=token_id) for token_id in token_row] + for token_row in image_tokens + ] + image_row_str = ["".join(token_row) for token_row in image_token_str] + imgstr = tokenizer.eol_token.join(image_row_str) + return imgstr + + +def load_prompts(cfg, image_tokenizer, tokenizer): + """Function to generate prompts + + The prompts generated here are fed to the model. + """ + prompts = [] + text = "Please describe the image" + for image_path in cfg.images_path: + image = Image.open(image_path) + image_tensor = torchvision.transforms.functional.pil_to_tensor(image).unsqueeze(0) + image_tokens = image_tokenizer.encode(image_tensor.to(image_tokenizer.device, image_tokenizer.dtype)) + bs, h, w = image_tokens.shape + imgstr = to_imgstr(image_tokens[0], tokenizer=tokenizer) + image_prompt = ( + tokenizer.boi_token + + f'{h}*{w}' + + tokenizer.img_token + + imgstr + + tokenizer.eol_token + + tokenizer.eof_token + + tokenizer.eoi_token + ) + prompt = f'{tokenizer.bos_token}You are a helpful assistant. USER: {image_prompt}{text}. ASSISTANT:' + prompts.append(prompt) + return prompts + + +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") + + +@hydra_runner(config_path="conf", config_name="megatron_mm_ar_inference_vision_understanding") +def main(cfg) -> None: + """Main function""" + + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + + tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True) + image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda", trust_remote_code=True).eval() + + model = load_model_from_config(trainer, cfg) + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + length_params: LengthParam = { + "max_length": cfg.inference.tokens_to_generate, + "min_length": cfg.inference.min_tokens_to_generate, + } + + sampling_params: SamplingParam = { + "use_greedy": cfg.inference.greedy, + "temperature": cfg.inference.temperature, + "top_k": cfg.inference.top_k, + "top_p": cfg.inference.top_p, + "repetition_penalty": cfg.inference.repetition_penalty, + "add_BOS": cfg.inference.add_BOS, + "all_probs": cfg.inference.all_probs, + "compute_logprob": cfg.inference.compute_logprob, + "end_strings": cfg.inference.end_strings, + } + + prompts = [] + with torch.no_grad(): + prompts = load_prompts(cfg, image_tokenizer, tokenizer) + + fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True) + if fp8_enabled and len(prompts) > 0: + padded_len = round_to_mult(len(prompts), 8) + nb_paddings = padded_len - len(prompts) + if nb_paddings > 0: + nb_paddings += [''] * nb_paddings + + # First method of running text generation, call model.generate method + response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params) + + if fp8_enabled: + response = remove_padded_prompts(response, nb_paddings) + print("***************************") + print(response) + print("***************************") + + # Second method of running text generation, call trainer.predict [recommended] + bs = 8 if fp8_enabled else 2 + ds = RequestDataSet(prompts) + request_dl = DataLoader(dataset=ds, batch_size=bs) + config = OmegaConf.to_container(cfg.inference) + model.set_inference_config(config) + response = trainer.predict(model, request_dl) + + if fp8_enabled: + response[-1] = remove_padded_prompts(response[-1], nb_paddings) + print("***************************") + print(response) + print("***************************") + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/nlp/dialogue/dialogue.py b/examples/nlp/dialogue/dialogue.py index 578895a2ad43..3f4c5581eb5a 100644 --- a/examples/nlp/dialogue/dialogue.py +++ b/examples/nlp/dialogue/dialogue.py @@ -42,7 +42,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.dialogue.dialogue_gpt_classification_model import DialogueGPTClassificationModel diff --git a/examples/nlp/duplex_text_normalization/helpers.py b/examples/nlp/duplex_text_normalization/helpers.py index 6c1cfe37b90d..d9b8780fd787 100644 --- a/examples/nlp/duplex_text_normalization/helpers.py +++ b/examples/nlp/duplex_text_normalization/helpers.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.data.text_normalization import constants @@ -29,7 +29,7 @@ def instantiate_model_and_trainer(cfg: DictConfig, model_name: str, do_training: bool): - """ Function for instantiating a model and a trainer + """Function for instantiating a model and a trainer Args: cfg: The config used to instantiate the model and the trainer. model_name: A str indicates whether the model to be instantiated is a tagger or a decoder (i.e., model_name should be either TAGGER_MODEL or DECODER_MODEL). diff --git a/examples/nlp/entity_linking/self_alignment_pretraining.py b/examples/nlp/entity_linking/self_alignment_pretraining.py index a1ac1ac327cb..58b20f384d04 100644 --- a/examples/nlp/entity_linking/self_alignment_pretraining.py +++ b/examples/nlp/entity_linking/self_alignment_pretraining.py @@ -16,8 +16,8 @@ # Please see tutorial at Nemo/tutorials/nlp/Entity_Linking_Medical.ipynb for # more information on entity linking and self alignment pretraining. +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.models import EntityLinkingModel from nemo.core.config import hydra_runner diff --git a/examples/nlp/glue_benchmark/glue_benchmark.py b/examples/nlp/glue_benchmark/glue_benchmark.py index 3cb5f8e4af3e..28efb9520fbd 100644 --- a/examples/nlp/glue_benchmark/glue_benchmark.py +++ b/examples/nlp/glue_benchmark/glue_benchmark.py @@ -35,7 +35,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import GLUEModel diff --git a/examples/nlp/information_retrieval/bert_dpr.py b/examples/nlp/information_retrieval/bert_dpr.py index 2d9cd962ff34..4fc791da04fd 100644 --- a/examples/nlp/information_retrieval/bert_dpr.py +++ b/examples/nlp/information_retrieval/bert_dpr.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import BertDPRModel diff --git a/examples/nlp/information_retrieval/bert_joint_ir.py b/examples/nlp/information_retrieval/bert_joint_ir.py index 1bb164e580d1..f95cdd04e036 100644 --- a/examples/nlp/information_retrieval/bert_joint_ir.py +++ b/examples/nlp/information_retrieval/bert_joint_ir.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import BertJointIRModel diff --git a/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py b/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py index e1fe28cc892f..9cb5cb5d3d19 100644 --- a/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py +++ b/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py @@ -15,8 +15,8 @@ from collections.abc import MutableMapping import torch.multiprocessing as mp +from lightning.pytorch.loggers import WandbLogger from omegaconf.omegaconf import OmegaConf -from pytorch_lightning.loggers import WandbLogger from nemo.collections.nlp.models.information_retrieval.megatron_gpt_embedding_model import MegatronGPTEmbeddingModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py index cf65840bb843..be89e5bf5c43 100644 --- a/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py @@ -15,8 +15,8 @@ from collections.abc import MutableMapping import torch.multiprocessing as mp +from lightning.pytorch.loggers import WandbLogger from omegaconf.omegaconf import OmegaConf -from pytorch_lightning.loggers import WandbLogger from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder diff --git a/examples/nlp/intent_slot_classification/intent_slot_classification.py b/examples/nlp/intent_slot_classification/intent_slot_classification.py index a112ea7785f5..2025f48f330f 100644 --- a/examples/nlp/intent_slot_classification/intent_slot_classification.py +++ b/examples/nlp/intent_slot_classification/intent_slot_classification.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import IntentSlotClassificationModel diff --git a/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py b/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py index 2441885e2ed2..232aa7d4d230 100644 --- a/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py +++ b/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py @@ -27,7 +27,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import MultiLabelIntentSlotClassificationModel diff --git a/examples/nlp/language_modeling/bert_pretraining.py b/examples/nlp/language_modeling/bert_pretraining.py index 75d0a1072e69..7cff43f7fc73 100644 --- a/examples/nlp/language_modeling/bert_pretraining.py +++ b/examples/nlp/language_modeling/bert_pretraining.py @@ -13,9 +13,9 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.strategies import DDPStrategy from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.strategies import DDPStrategy from nemo.collections.nlp.models.language_modeling import BERTLMModel from nemo.core.config import hydra_runner diff --git a/examples/nlp/language_modeling/mamba_change_num_partition.py b/examples/nlp/language_modeling/mamba_change_num_partition.py index ced2b43cd312..349543de8e59 100644 --- a/examples/nlp/language_modeling/mamba_change_num_partition.py +++ b/examples/nlp/language_modeling/mamba_change_num_partition.py @@ -19,8 +19,8 @@ from argparse import ArgumentParser import torch +from lightning.pytorch import Trainer from omegaconf import open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel from nemo.collections.nlp.parts.nlp_overrides import ( diff --git a/examples/nlp/language_modeling/megatron_bart_pretraining.py b/examples/nlp/language_modeling/megatron_bart_pretraining.py index e45b5e04ca45..a6dd6f183d72 100644 --- a/examples/nlp/language_modeling/megatron_bart_pretraining.py +++ b/examples/nlp/language_modeling/megatron_bart_pretraining.py @@ -13,11 +13,11 @@ # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.parts.nlp_overrides import ( @@ -48,7 +48,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_change_num_partitions.py b/examples/nlp/language_modeling/megatron_change_num_partitions.py index c035346e3bf1..49d1ef0dcb57 100644 --- a/examples/nlp/language_modeling/megatron_change_num_partitions.py +++ b/examples/nlp/language_modeling/megatron_change_num_partitions.py @@ -21,8 +21,8 @@ import torch import torch.nn as nn +from lightning.pytorch import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.parts.nlp_overrides import ( NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, @@ -922,7 +922,7 @@ def main(): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=tmp_cfg.get('native_amp_init_scale', 2 ** 32), + init_scale=tmp_cfg.get('native_amp_init_scale', 2**32), growth_interval=tmp_cfg.get('native_amp_growth_interval', 1000), hysteresis=tmp_cfg.get('hysteresis', 2), ) @@ -943,7 +943,10 @@ def main(): if tp_size < 0 or pp_size < 0: logging.info(f"Loading model config from {args.model_file} to get TP and PP size") model_config_internal = cls.restore_from( - restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"), return_config=True, + restore_path=args.model_file, + trainer=trainer, + map_location=torch.device("cpu"), + return_config=True, ) tp_size = model_config_internal.get('tensor_model_parallel_size', 1) @@ -1137,7 +1140,9 @@ def main(): else: model = cls.load_from_checkpoint( - checkpoint_path=checkpoint_path, trainer=trainer, map_location=torch.device("cpu"), + checkpoint_path=checkpoint_path, + trainer=trainer, + map_location=torch.device("cpu"), ) model.freeze() diff --git a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py index c81119489582..b46f8f459ff0 100644 --- a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py @@ -32,10 +32,10 @@ import torch from genericpath import isdir +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import parallel_state from omegaconf import OmegaConf, open_dict -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel diff --git a/examples/nlp/language_modeling/megatron_export.py b/examples/nlp/language_modeling/megatron_export.py index bf9157884bfc..b511a415d9b1 100644 --- a/examples/nlp/language_modeling/megatron_export.py +++ b/examples/nlp/language_modeling/megatron_export.py @@ -28,8 +28,8 @@ import os +from lightning.pytorch import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel diff --git a/examples/nlp/language_modeling/megatron_gpt_distillation.py b/examples/nlp/language_modeling/megatron_gpt_distillation.py index dc8614be23b2..c00470c5c81e 100644 --- a/examples/nlp/language_modeling/megatron_gpt_distillation.py +++ b/examples/nlp/language_modeling/megatron_gpt_distillation.py @@ -19,8 +19,8 @@ import modelopt.torch.distill as mtd import modelopt.torch.opt as mto import torch.multiprocessing as mp +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer try: from megatron.core import parallel_state, tensor_parallel diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index b9b0d2973094..4dbbee78e898 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -20,8 +20,8 @@ from functools import partial import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py b/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py index 988a5f8588ff..ceb32d75f495 100644 --- a/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py @@ -16,6 +16,7 @@ import os from argparse import Namespace +from lightning.pytorch.trainer.trainer import Trainer from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.engines.mcore_engine import MCoreEngine from megatron.core.inference.inference_model_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper @@ -23,7 +24,6 @@ SimpleTextGenerationController, ) from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel diff --git a/examples/nlp/language_modeling/megatron_gpt_prune.py b/examples/nlp/language_modeling/megatron_gpt_prune.py index de12b861a1c0..44992873f362 100644 --- a/examples/nlp/language_modeling/megatron_gpt_prune.py +++ b/examples/nlp/language_modeling/megatron_gpt_prune.py @@ -16,8 +16,8 @@ import torch import torch.multiprocessing as mp from datasets import load_dataset +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from tqdm import tqdm from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_gpt_ptq.py b/examples/nlp/language_modeling/megatron_gpt_ptq.py index e41becc2d8e0..0ac0822c5fbe 100644 --- a/examples/nlp/language_modeling/megatron_gpt_ptq.py +++ b/examples/nlp/language_modeling/megatron_gpt_ptq.py @@ -15,8 +15,8 @@ import torch import torch.multiprocessing as mp from datasets import load_dataset +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from tqdm import tqdm from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_gpt_test.py b/examples/nlp/language_modeling/megatron_gpt_test.py index 62a1d40dbaed..03bc6735e891 100644 --- a/examples/nlp/language_modeling/megatron_gpt_test.py +++ b/examples/nlp/language_modeling/megatron_gpt_test.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank @@ -38,7 +38,7 @@ def main(cfg) -> None: trainer = Trainer( plugins=[ NLPMixedPrecisionPlugin( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), ), ], @@ -46,7 +46,13 @@ def main(cfg) -> None: **cfg.trainer, ) elif cfg.trainer.precision in ['bf16', 'bf16-mixed']: - trainer = Trainer(plugins=[NLPNativeBfloat16PrecisionPlugin(),], strategy=NLPDDPStrategy(), **cfg.trainer,) + trainer = Trainer( + plugins=[ + NLPNativeBfloat16PrecisionPlugin(), + ], + strategy=NLPDDPStrategy(), + **cfg.trainer, + ) else: trainer = Trainer(plugins=[NLPPrecisionPlugin()], strategy=NLPDDPStrategy(), **cfg.trainer) @@ -55,7 +61,9 @@ def main(cfg) -> None: app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size) model = MegatronGPTModel.restore_from( - cfg.restore_from_path, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(), + cfg.restore_from_path, + trainer=trainer, + save_restore_connector=NLPSaveRestoreConnector(), ) # Note: most nemo models must have the data paths configured before instantiating the model diff --git a/examples/nlp/language_modeling/megatron_gpt_validate.py b/examples/nlp/language_modeling/megatron_gpt_validate.py index b5a61e627a14..fa0abb89421c 100644 --- a/examples/nlp/language_modeling/megatron_gpt_validate.py +++ b/examples/nlp/language_modeling/megatron_gpt_validate.py @@ -15,8 +15,8 @@ import os import tempfile +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel @@ -140,7 +140,9 @@ def main(cfg) -> None: with tempfile.NamedTemporaryFile(suffix='.yaml') as f: OmegaConf.save(config=pretrained_cfg, f=f.name) model = MegatronGPTModel.load_from_checkpoint( - checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name, + checkpoint_path=checkpoint_path, + trainer=trainer, + hparams_file=f.name, ) else: raise ValueError("need at least a nemo file or checkpoint dir") diff --git a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py index 72252a03d5be..64ba2a51bb71 100644 --- a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py @@ -42,12 +42,12 @@ from typing import Any, Optional import torch -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml +from lightning.pytorch.trainer.trainer import Trainer +from lightning.pytorch.utilities.migration import pl_legacy_patch from megatron.core import parallel_state -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml -from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities.migration import pl_legacy_patch from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_mamba_eval.py b/examples/nlp/language_modeling/megatron_mamba_eval.py index ed12e4b904ac..ba000e6bef63 100644 --- a/examples/nlp/language_modeling/megatron_mamba_eval.py +++ b/examples/nlp/language_modeling/megatron_mamba_eval.py @@ -20,8 +20,8 @@ from functools import partial import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel diff --git a/examples/nlp/language_modeling/megatron_retro_cal_shape.py b/examples/nlp/language_modeling/megatron_retro_cal_shape.py index a57a927d2a36..f790d9471964 100644 --- a/examples/nlp/language_modeling/megatron_retro_cal_shape.py +++ b/examples/nlp/language_modeling/megatron_retro_cal_shape.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.megatron.mup.shape import make_base_shapes @@ -46,7 +46,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_retro_eval.py b/examples/nlp/language_modeling/megatron_retro_eval.py index 89e3fe9c3ddb..ac946b2adf42 100644 --- a/examples/nlp/language_modeling/megatron_retro_eval.py +++ b/examples/nlp/language_modeling/megatron_retro_eval.py @@ -16,8 +16,8 @@ import os import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel @@ -60,7 +60,9 @@ def __init__(self, sentences, neighbors): self.sentences = sentences self.neighbors = neighbors - def __len__(self,): + def __len__( + self, + ): return len(self.sentences) def __getitem__(self, idx): diff --git a/examples/nlp/language_modeling/megatron_retro_eval_legacy.py b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py index 69222acedd34..c51a8f536cc1 100644 --- a/examples/nlp/language_modeling/megatron_retro_eval_legacy.py +++ b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py @@ -15,8 +15,8 @@ import os from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel @@ -69,7 +69,10 @@ def main(cfg) -> None: save_restore_connector.model_extracted_dir = model_path model_cfg = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector, + model_path, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, ) with open_dict(model_cfg): @@ -89,7 +92,10 @@ def main(cfg) -> None: cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0) model = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, + model_path, + trainer=trainer, + save_restore_connector=save_restore_connector, + override_config_path=model_cfg, ) length_params: LengthParam = { diff --git a/examples/nlp/language_modeling/megatron_retro_fine_tune.py b/examples/nlp/language_modeling/megatron_retro_fine_tune.py index 3fcaec156d9c..153a4b581135 100644 --- a/examples/nlp/language_modeling/megatron_retro_fine_tune.py +++ b/examples/nlp/language_modeling/megatron_retro_fine_tune.py @@ -15,12 +15,12 @@ import datetime import os +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks.timer import Timer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_retro_fine_tune_model import MegatronRetroFinetuneModel from nemo.collections.nlp.parts.nlp_overrides import ( @@ -87,7 +87,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) @@ -118,7 +118,9 @@ def main(cfg) -> None: # Override timer callback to a stateless one for idx, callback in enumerate(trainer.callbacks): if isinstance(callback, Timer): - trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,) + trainer.callbacks[idx] = StatelessTimer( + cfg.trainer.max_time, + ) # load existing or init new soft prompt GPT model if cfg.model.get("restore_path", None): diff --git a/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py b/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py index af6e22035def..775b75680ee9 100644 --- a/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py +++ b/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.megatron.mup.optim import MuAdam, MuAdamW @@ -52,7 +52,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py index 4653222b3438..298deafabc1c 100644 --- a/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py +++ b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py @@ -14,11 +14,11 @@ import os +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo @@ -51,7 +51,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_retro_qatask_eval.py b/examples/nlp/language_modeling/megatron_retro_qatask_eval.py index b99bcafbab02..4e47157d5150 100644 --- a/examples/nlp/language_modeling/megatron_retro_qatask_eval.py +++ b/examples/nlp/language_modeling/megatron_retro_qatask_eval.py @@ -17,8 +17,8 @@ import os import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.data.question_answering.input_example.qa_input_example import QAExample @@ -63,7 +63,9 @@ def __init__(self, sentences, neighbors): self.sentences = sentences self.neighbors = neighbors - def __len__(self,): + def __len__( + self, + ): return len(self.sentences) def __getitem__(self, idx): diff --git a/examples/nlp/language_modeling/megatron_t5_eval.py b/examples/nlp/language_modeling/megatron_t5_eval.py index 0b6ea54b6b99..57b48134101f 100644 --- a/examples/nlp/language_modeling/megatron_t5_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_eval.py @@ -17,8 +17,8 @@ from argparse import ArgumentParser import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader from nemo.collections.nlp.data.language_modeling.megatron.request_dataset import T5RequestDataset @@ -40,13 +40,22 @@ def main(): "--tokens_to_generate", type=int, default="16", required=False, help="How many tokens to add to prompt" ) parser.add_argument( - "--tensor_model_parallel_size", type=int, default=-1, required=False, + "--tensor_model_parallel_size", + type=int, + default=-1, + required=False, ) parser.add_argument( - "--pipeline_model_parallel_size", type=int, default=-1, required=False, + "--pipeline_model_parallel_size", + type=int, + default=-1, + required=False, ) parser.add_argument( - "--pipeline_model_parallel_split_rank", type=int, default=-1, required=False, + "--pipeline_model_parallel_split_rank", + type=int, + default=-1, + required=False, ) parser.add_argument("--precision", default="16", type=str, help="PyTorch Lightning Trainer precision flag") parser.add_argument("--decoder_starts_with_pad", action="store_true", help="Decoder starts with pad token") diff --git a/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py b/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py index 9e392d913171..4137213023ee 100644 --- a/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py +++ b/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py @@ -13,11 +13,11 @@ # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model from nemo.collections.nlp.parts.nlp_overrides import ( @@ -49,7 +49,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py index ba8ea6492da3..ae6e1744395d 100644 --- a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin from megatron_t5_seq2seq_finetune import load_from_checkpoint_dir, load_from_nemo, validate_checkpoint_loading_args from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model @@ -82,7 +82,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py b/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py index 2409e99ad951..5f63289be27a 100644 --- a/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py +++ b/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py @@ -16,10 +16,10 @@ import tempfile import torch.multiprocessing as mp +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model diff --git a/examples/nlp/language_modeling/transformer_lm.py b/examples/nlp/language_modeling/transformer_lm.py index caaa0e0d2935..3e97e28bb35e 100644 --- a/examples/nlp/language_modeling/transformer_lm.py +++ b/examples/nlp/language_modeling/transformer_lm.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.language_modeling import TransformerLMModel diff --git a/examples/nlp/language_modeling/upcycle_dense_to_moe.py b/examples/nlp/language_modeling/upcycle_dense_to_moe.py index a1f4b6000b6f..f4a5fc017d97 100644 --- a/examples/nlp/language_modeling/upcycle_dense_to_moe.py +++ b/examples/nlp/language_modeling/upcycle_dense_to_moe.py @@ -26,7 +26,7 @@ import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector diff --git a/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py b/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py index b1743e03188e..58c948f11458 100644 --- a/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py +++ b/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py @@ -15,8 +15,8 @@ from dataclasses import dataclass from typing import Optional +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data.machine_translation.preproc_mt_data import MTDataPreproc from nemo.collections.nlp.models.machine_translation.mt_enc_dec_bottleneck_model import MTBottleneckModel @@ -29,7 +29,6 @@ from nemo.utils.config_utils import update_model_config from nemo.utils.exp_manager import ExpManagerConfig, exp_manager - """ Usage: 1. If you need to start docker and install NeMo, otherwise skip this step: diff --git a/examples/nlp/machine_translation/enc_dec_nmt.py b/examples/nlp/machine_translation/enc_dec_nmt.py index 57b9f84c39ce..b901ba28a4db 100644 --- a/examples/nlp/machine_translation/enc_dec_nmt.py +++ b/examples/nlp/machine_translation/enc_dec_nmt.py @@ -15,8 +15,8 @@ from dataclasses import dataclass from typing import Optional +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data.machine_translation.preproc_mt_data import MTDataPreproc from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTEncDecModelConfig @@ -29,7 +29,6 @@ from nemo.utils.config_utils import update_model_config from nemo.utils.exp_manager import ExpManagerConfig, exp_manager - """ Usage: 1. If you need to start docker and install NeMo, otherwise skip this step: diff --git a/examples/nlp/machine_translation/enc_dec_nmt_finetune.py b/examples/nlp/machine_translation/enc_dec_nmt_finetune.py index 16a635d09dee..688461a7b491 100644 --- a/examples/nlp/machine_translation/enc_dec_nmt_finetune.py +++ b/examples/nlp/machine_translation/enc_dec_nmt_finetune.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from typing import Optional +from lightning.pytorch import Trainer from omegaconf import OmegaConf from omegaconf.omegaconf import MISSING -from pytorch_lightning import Trainer from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTEncDecModelConfig from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel @@ -29,7 +29,6 @@ from nemo.utils.config_utils import update_model_config from nemo.utils.exp_manager import ExpManagerConfig, exp_manager - """ Usage: python enc_dec_nmt_finetune.py \ diff --git a/examples/nlp/machine_translation/megatron_nmt_training.py b/examples/nlp/machine_translation/megatron_nmt_training.py index 7946500f92e9..5ff70a7a863c 100644 --- a/examples/nlp/machine_translation/megatron_nmt_training.py +++ b/examples/nlp/machine_translation/megatron_nmt_training.py @@ -14,11 +14,11 @@ import torch.multiprocessing as mp +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model @@ -53,7 +53,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py index fcf1fb8d1796..349155101a5d 100644 --- a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py +++ b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py @@ -24,8 +24,8 @@ import os +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel diff --git a/examples/nlp/question_answering/question_answering.py b/examples/nlp/question_answering/question_answering.py index fcde03582e5c..37bd43a4b0fb 100644 --- a/examples/nlp/question_answering/question_answering.py +++ b/examples/nlp/question_answering/question_answering.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.question_answering.qa_bert_model import BERTQAModel diff --git a/examples/nlp/spellchecking_asr_customization/helpers.py b/examples/nlp/spellchecking_asr_customization/helpers.py index 2db11b0e7d96..8e3957d34cc1 100644 --- a/examples/nlp/spellchecking_asr_customization/helpers.py +++ b/examples/nlp/spellchecking_asr_customization/helpers.py @@ -16,7 +16,7 @@ import os from typing import Tuple -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.models import SpellcheckingAsrCustomizationModel @@ -32,7 +32,7 @@ def instantiate_model_and_trainer( cfg: DictConfig, model_name: str, do_training: bool ) -> Tuple[pl.Trainer, SpellcheckingAsrCustomizationModel]: - """ Function for instantiating a model and a trainer + """Function for instantiating a model and a trainer Args: cfg: The config used to instantiate the model and the trainer. model_name: A str indicates the model direction, currently only 'itn'. diff --git a/examples/nlp/text2sparql/evaluate_text2sparql.py b/examples/nlp/text2sparql/evaluate_text2sparql.py index 52baa2a7e78c..774ced98e8ec 100644 --- a/examples/nlp/text2sparql/evaluate_text2sparql.py +++ b/examples/nlp/text2sparql/evaluate_text2sparql.py @@ -39,7 +39,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text2sparql import Text2SparqlModel diff --git a/examples/nlp/text2sparql/text2sparql.py b/examples/nlp/text2sparql/text2sparql.py index 1353a3967735..d70a7e616950 100644 --- a/examples/nlp/text2sparql/text2sparql.py +++ b/examples/nlp/text2sparql/text2sparql.py @@ -88,7 +88,7 @@ exp_manager.exp_dir=./NeMo_logs """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text2sparql import Text2SparqlModel diff --git a/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py b/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py index ab3322f552c1..cf9b6d8dd2e4 100644 --- a/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py +++ b/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py @@ -15,7 +15,7 @@ """ This script runs model parallel text classification evaluation. """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text_classification import TextClassificationModel diff --git a/examples/nlp/text_classification/text_classification_with_bert.py b/examples/nlp/text_classification/text_classification_with_bert.py index 01e8fae9bba5..a6c84b4e337a 100644 --- a/examples/nlp/text_classification/text_classification_with_bert.py +++ b/examples/nlp/text_classification/text_classification_with_bert.py @@ -95,7 +95,7 @@ eval_model.set_trainer(eval_trainer) eval_trainer.test(model=eval_model, verbose=False) """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text_classification import TextClassificationModel diff --git a/examples/nlp/text_normalization_as_tagging/helpers.py b/examples/nlp/text_normalization_as_tagging/helpers.py index 347b05b25fba..de74794f8f40 100644 --- a/examples/nlp/text_normalization_as_tagging/helpers.py +++ b/examples/nlp/text_normalization_as_tagging/helpers.py @@ -16,7 +16,7 @@ import os from typing import Tuple -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.models import ThutmoseTaggerModel @@ -31,7 +31,7 @@ def instantiate_model_and_trainer( cfg: DictConfig, model_name: str, do_training: bool ) -> Tuple[pl.Trainer, ThutmoseTaggerModel]: - """ Function for instantiating a model and a trainer + """Function for instantiating a model and a trainer Args: cfg: The config used to instantiate the model and the trainer. model_name: A str indicates the model direction, currently only 'itn'. diff --git a/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py b/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py index 149a9a4515e2..508e434bb598 100644 --- a/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py +++ b/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import DictConfig, OmegaConf diff --git a/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py b/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py index e983540a68b2..b16e1ecd0bdc 100644 --- a/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py +++ b/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import DictConfig, OmegaConf diff --git a/examples/nlp/token_classification/token_classification_evaluate.py b/examples/nlp/token_classification/token_classification_evaluate.py index b69212f59de4..764aa90c8593 100644 --- a/examples/nlp/token_classification/token_classification_evaluate.py +++ b/examples/nlp/token_classification/token_classification_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.models import TokenClassificationModel diff --git a/examples/nlp/token_classification/token_classification_train.py b/examples/nlp/token_classification/token_classification_train.py index 56c1487cf9c5..536327aff6da 100644 --- a/examples/nlp/token_classification/token_classification_train.py +++ b/examples/nlp/token_classification/token_classification_train.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import TokenClassificationModel diff --git a/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py b/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py index 5b91049e965d..4dbbf01c935e 100644 --- a/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py +++ b/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import ZeroShotIntentModel diff --git a/examples/slu/speech_intent_slot/eval_utils/inference.py b/examples/slu/speech_intent_slot/eval_utils/inference.py index 9bd76c76822d..241f6463ed76 100644 --- a/examples/slu/speech_intent_slot/eval_utils/inference.py +++ b/examples/slu/speech_intent_slot/eval_utils/inference.py @@ -21,7 +21,7 @@ from pathlib import Path from typing import List, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf from tqdm.auto import tqdm diff --git a/examples/slu/speech_intent_slot/speech_intent_slot_train.py b/examples/slu/speech_intent_slot/speech_intent_slot_train.py index a9999d4d4682..f8732ec757e1 100644 --- a/examples/slu/speech_intent_slot/speech_intent_slot_train.py +++ b/examples/slu/speech_intent_slot/speech_intent_slot_train.py @@ -66,7 +66,7 @@ from pathlib import Path -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py index 35077a5fe415..5c0f956c2e3c 100644 --- a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py +++ b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import ClusteringDiarizer from nemo.core.config import hydra_runner diff --git a/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py b/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py index 984b5ce93464..bc1db4dc1126 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import EncDecDiarLabelModel from nemo.core.config import hydra_runner diff --git a/examples/speaker_tasks/recognition/speaker_identification_infer.py b/examples/speaker_tasks/recognition/speaker_identification_infer.py index 90f930fcbfa6..7075a9f1f92a 100644 --- a/examples/speaker_tasks/recognition/speaker_identification_infer.py +++ b/examples/speaker_tasks/recognition/speaker_identification_infer.py @@ -16,8 +16,8 @@ import numpy as np import torch +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset from nemo.collections.asr.models import EncDecSpeakerLabelModel @@ -55,10 +55,18 @@ def main(cfg): speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path) enroll_embs, _, enroll_truelabels, _ = speaker_model.batch_inference( - enrollment_manifest, batch_size, sample_rate, device=device, + enrollment_manifest, + batch_size, + sample_rate, + device=device, ) - test_embs, _, _, _ = speaker_model.batch_inference(test_manifest, batch_size, sample_rate, device=device,) + test_embs, _, _, _ = speaker_model.batch_inference( + test_manifest, + batch_size, + sample_rate, + device=device, + ) # length normalize enroll_embs = enroll_embs / (np.linalg.norm(enroll_embs, ord=2, axis=-1, keepdims=True)) @@ -91,7 +99,12 @@ def main(cfg): "number of labels mis match. Make sure you trained or finetuned neural classifier with labels from enrollement manifest_filepath" ) - _, test_logits, _, _ = speaker_model.batch_inference(test_manifest, batch_size, sample_rate, device=device,) + _, test_logits, _, _ = speaker_model.batch_inference( + test_manifest, + batch_size, + sample_rate, + device=device, + ) matched_labels = test_logits.argmax(axis=-1) with open(test_manifest, 'rb') as f1, open(out_manifest, 'w', encoding='utf-8') as f2: diff --git a/examples/speaker_tasks/recognition/speaker_reco.py b/examples/speaker_tasks/recognition/speaker_reco.py index a8acd4de4a3f..ac5cb12ac836 100644 --- a/examples/speaker_tasks/recognition/speaker_reco.py +++ b/examples/speaker_tasks/recognition/speaker_reco.py @@ -14,10 +14,10 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import EncDecSpeakerLabelModel from nemo.core.config import hydra_runner diff --git a/examples/speaker_tasks/recognition/speaker_reco_finetune.py b/examples/speaker_tasks/recognition/speaker_reco_finetune.py index 884e5a60bc59..502d016a920d 100644 --- a/examples/speaker_tasks/recognition/speaker_reco_finetune.py +++ b/examples/speaker_tasks/recognition/speaker_reco_finetune.py @@ -14,10 +14,10 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import EncDecSpeakerLabelModel from nemo.core.config import hydra_runner diff --git a/examples/tts/aligner.py b/examples/tts/aligner.py index e32c0444ca68..939b8dbcf11f 100644 --- a/examples/tts/aligner.py +++ b/examples/tts/aligner.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import AlignerModel diff --git a/examples/tts/audio_codec.py b/examples/tts/audio_codec.py index 5fc4b6fd0afd..d875a3037ba3 100644 --- a/examples/tts/audio_codec.py +++ b/examples/tts/audio_codec.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.tts.models import AudioCodecModel diff --git a/examples/tts/fastpitch.py b/examples/tts/fastpitch.py index a8e6ecdc902d..7fd584b773e4 100644 --- a/examples/tts/fastpitch.py +++ b/examples/tts/fastpitch.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import FastPitchModel diff --git a/examples/tts/fastpitch_finetune.py b/examples/tts/fastpitch_finetune.py index 64b5e8b90625..9bdf704c514c 100644 --- a/examples/tts/fastpitch_finetune.py +++ b/examples/tts/fastpitch_finetune.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import FastPitchModel diff --git a/examples/tts/fastpitch_finetune_adapters.py b/examples/tts/fastpitch_finetune_adapters.py index 1361d63fb4cf..9b50d70ab15e 100644 --- a/examples/tts/fastpitch_finetune_adapters.py +++ b/examples/tts/fastpitch_finetune_adapters.py @@ -15,7 +15,7 @@ import os from dataclasses import is_dataclass -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf, open_dict from nemo.collections.common.callbacks import LogEpochTimeCallback diff --git a/examples/tts/fastpitch_ssl.py b/examples/tts/fastpitch_ssl.py index 1101ac1eeaf7..b92983a4bfb1 100644 --- a/examples/tts/fastpitch_ssl.py +++ b/examples/tts/fastpitch_ssl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import fastpitch_ssl, hifigan diff --git a/examples/tts/g2p/g2p_heteronym_classification_inference.py b/examples/tts/g2p/g2p_heteronym_classification_inference.py index 61262c41a340..89a563e9b683 100644 --- a/examples/tts/g2p/g2p_heteronym_classification_inference.py +++ b/examples/tts/g2p/g2p_heteronym_classification_inference.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, is_dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf @@ -56,9 +56,9 @@ class TranscriptionConfig: # path to .json manifest inference, if not provided, interactive mode will be enabled manifest: Optional[str] = None # Path to .json manifest - output_manifest: Optional[ - str - ] = "predictions.json" # Path to .json manifest to save prediction, will be saved in "pred_text" field + output_manifest: Optional[str] = ( + "predictions.json" # Path to .json manifest to save prediction, will be saved in "pred_text" field + ) grapheme_field: str = "text_graphemes" # name of the field in .json manifest for input grapheme text # mapping from wordid predicted by the model to phonemes, e.g., @@ -132,9 +132,10 @@ def main(cfg): save_errors = True correct = 0 total = 0 - with open(cfg.output_manifest, "r", encoding="utf-8") as f_preds, open( - cfg.errors_file, "w", encoding="utf-8" - ) as f_errors: + with ( + open(cfg.output_manifest, "r", encoding="utf-8") as f_preds, + open(cfg.errors_file, "w", encoding="utf-8") as f_errors, + ): for line in f_preds: line = json.loads(line) predictions = line["pred_wordid"] diff --git a/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py b/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py index 613865618501..f86a0a3934e4 100644 --- a/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py +++ b/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from nemo.collections.common.callbacks import LogEpochTimeCallback diff --git a/examples/tts/g2p/g2p_inference.py b/examples/tts/g2p/g2p_inference.py index e7bffa888653..a9da11fcffdb 100644 --- a/examples/tts/g2p/g2p_inference.py +++ b/examples/tts/g2p/g2p_inference.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, is_dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf from utils import get_metrics @@ -41,23 +41,23 @@ class TranscriptionConfig: # Required configs pretrained_model: str # Path to a .nemo file or Name of a pretrained model manifest_filepath: str # Path to .json manifest file - phoneme_field: Optional[ - str - ] = None # name of the field in manifest_filepath for ground truth phonemes, default during training "text" + phoneme_field: Optional[str] = ( + None # name of the field in manifest_filepath for ground truth phonemes, default during training "text" + ) grapheme_field: Optional[str] = "text_graphemes" # name of the field in manifest_filepath for input grapheme text # General configs - output_file: Optional[ - str - ] = None # Path to .json manifest file to save predictions, will be saved in "target_field" + output_file: Optional[str] = ( + None # Path to .json manifest file to save predictions, will be saved in "target_field" + ) pred_field: Optional[str] = "pred_text" # name of the field in the output_file to save predictions batch_size: int = 32 # Batch size to use for inference num_workers: int = 0 # Number of workers to use for DataLoader during inference # Config for heteronyms correction - pretrained_heteronyms_model: Optional[ - str - ] = None # Path to a .nemo file or a Name of a pretrained model to disambiguate heteronyms (Optional) + pretrained_heteronyms_model: Optional[str] = ( + None # Path to a .nemo file or a Name of a pretrained model to disambiguate heteronyms (Optional) + ) @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) diff --git a/examples/tts/g2p/g2p_train_and_evaluate.py b/examples/tts/g2p/g2p_train_and_evaluate.py index ff7b2b0675ea..319e1fb6a776 100644 --- a/examples/tts/g2p/g2p_train_and_evaluate.py +++ b/examples/tts/g2p/g2p_train_and_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from utils import get_model diff --git a/examples/tts/hifigan.py b/examples/tts/hifigan.py index 5c3406a2f24c..6cf5c7a5aac4 100644 --- a/examples/tts/hifigan.py +++ b/examples/tts/hifigan.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models import HifiGanModel from nemo.core.config import hydra_runner diff --git a/examples/tts/hifigan_finetune.py b/examples/tts/hifigan_finetune.py index f0e2513404fd..328e1f423903 100644 --- a/examples/tts/hifigan_finetune.py +++ b/examples/tts/hifigan_finetune.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models import HifiGanModel from nemo.core.config import hydra_runner diff --git a/examples/tts/mixer_tts.py b/examples/tts/mixer_tts.py index 61a188f53969..53f55d93bcda 100644 --- a/examples/tts/mixer_tts.py +++ b/examples/tts/mixer_tts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import MixerTTSModel diff --git a/examples/tts/radtts.py b/examples/tts/radtts.py index 09bf69a2d6e5..4b3b0e62da87 100644 --- a/examples/tts/radtts.py +++ b/examples/tts/radtts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models.radtts import RadTTSModel diff --git a/examples/tts/spectrogram_enhancer.py b/examples/tts/spectrogram_enhancer.py index 336729236d74..cd91ef3cb815 100644 --- a/examples/tts/spectrogram_enhancer.py +++ b/examples/tts/spectrogram_enhancer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models import SpectrogramEnhancerModel from nemo.core.config import hydra_runner diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml new file mode 100644 index 000000000000..8b37077bfdd5 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml @@ -0,0 +1,160 @@ +name: megatron_t5_speechllm_tts_inference +checkpoint_path: ??? + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 10000 + max_steps: -1 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 3 + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 2 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 16 + micro_batch_size: 16 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 1000 # Maximum number of timesteps to run inference for + + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + language_model_path: ??? # Path to the pretrained T5 language model .nemo file, always required + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + use_flash_attention: false + lm_vocab_size: 30000 + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? + max_seq_length: 1536 + sample_rate: 24000 + add_eos: true + add_bos: false + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30000 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 5e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml new file mode 100644 index 000000000000..1858edf9e667 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml @@ -0,0 +1,213 @@ +name: megatron_t5_speechllm_tts_inference +checkpoint_path: ??? + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 10000 + max_steps: -1 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 3 + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 2 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 16 + micro_batch_size: 16 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 1000 # Maximum number of timesteps to run inference for + + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? + max_seq_length: 1536 + sample_rate: 24000 + add_eos: true + add_bos: false + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30000 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 5e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml new file mode 100644 index 000000000000..8ad967d20538 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml @@ -0,0 +1,218 @@ +name: megatron_t5_speechllm +checkpoint_path: ??? + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 2000 # Maximum number of timesteps to run inference for + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: encoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: multi_transformer + n_transformers: 2 + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 6 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: false + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + encoder_type: ${model.frozen_model.encoder.arch} + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml new file mode 100644 index 000000000000..bd31f0712fdf --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml @@ -0,0 +1,161 @@ +name: megatron_t5_speechllm_medium + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 1000000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + language_model_path: ??? # Path to the pretrained T5 language model .nemo file, always required + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + use_flash_attention: false + lm_vocab_size: 30000 + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml new file mode 100644 index 000000000000..bf3f65ff9e00 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml @@ -0,0 +1,223 @@ +name: megatron_t5_speechllm + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: encoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: multi_transformer + n_transformers: 2 + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 6 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + encoder_type: ${model.frozen_model.encoder.arch} + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 1000 + constant_steps: 0 + min_lr: 1e-5 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml new file mode 100644 index 000000000000..d69bfb979182 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml @@ -0,0 +1,221 @@ +name: megatron_t5_speechllm + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + use_ipa: false + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 1000 + constant_steps: 0 + min_lr: 1e-5 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/tts/speechllm/megatron_t5_speechllm.py b/examples/tts/speechllm/megatron_t5_speechllm.py new file mode 100644 index 000000000000..c4ec1a77f944 --- /dev/null +++ b/examples/tts/speechllm/megatron_t5_speechllm.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_t5_speechllm_medium.yaml") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + # MegatronTrainerBuilder compat checks + if "gradient_as_bucket_view" not in cfg.model: + with open_dict(cfg): + cfg.model.gradient_as_bucket_view = False + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + # load existing or init new soft prompt T5 model + if cfg.model.get("restore_path", None) is not None: + logging.info(f"cfg.model.restore_path {cfg.model.restore_path}") + model = MegatronT5SpeechLMModel.restore_from( + cfg.model.restore_path, cfg.model, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector() + ) + else: + logging.info(f"cfg.model.restore_path is None") + model = MegatronT5SpeechLMModel(cfg.model, trainer=trainer) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/tts/speechllm/megatron_t5_speechllm_inference.py b/examples/tts/speechllm/megatron_t5_speechllm_inference.py new file mode 100644 index 000000000000..48d46952a993 --- /dev/null +++ b/examples/tts/speechllm/megatron_t5_speechllm_inference.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_t5_speechllm_inference.yaml") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + # MegatronTrainerBuilder compat checks + if "gradient_as_bucket_view" not in cfg.model: + with open_dict(cfg): + cfg.model.gradient_as_bucket_view = False + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + # load existing or init new soft prompt T5 model + checkpoint_path = cfg.get('checkpoint_path', None) + assert checkpoint_path is not None, "Please specify checkpoint_path in the config file" + model = MegatronT5SpeechLMModel.load_from_checkpoint( + checkpoint_path=checkpoint_path, trainer=trainer, cfg=cfg.model + ) + model.eval() + model = model.cuda() + trainer.test(model) + + +if __name__ == '__main__': + main() diff --git a/examples/tts/ssl_tts.py b/examples/tts/ssl_tts.py index a96dccb930ab..a50997a8f432 100644 --- a/examples/tts/ssl_tts.py +++ b/examples/tts/ssl_tts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import ssl_tts diff --git a/examples/tts/tacotron2.py b/examples/tts/tacotron2.py index a5446c35f775..6c4a15d98ef2 100755 --- a/examples/tts/tacotron2.py +++ b/examples/tts/tacotron2.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import Tacotron2Model diff --git a/examples/tts/tacotron2_finetune.py b/examples/tts/tacotron2_finetune.py index a0531f1f2801..f8d4d1dcaad0 100644 --- a/examples/tts/tacotron2_finetune.py +++ b/examples/tts/tacotron2_finetune.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import Tacotron2Model diff --git a/examples/tts/univnet.py b/examples/tts/univnet.py index 91aafa661842..ac6949405fd5 100644 --- a/examples/tts/univnet.py +++ b/examples/tts/univnet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import UnivNetModel diff --git a/examples/tts/vits.py b/examples/tts/vits.py index 75e0d827018a..6eeebd3ea15a 100644 --- a/examples/tts/vits.py +++ b/examples/tts/vits.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models.vits import VitsModel from nemo.core.config import hydra_runner diff --git a/examples/tts/waveglow.py b/examples/tts/waveglow.py index 66b13491abd4..3bcd008ab5e0 100755 --- a/examples/tts/waveglow.py +++ b/examples/tts/waveglow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import WaveGlowModel diff --git a/examples/vision/convert_ckpt_to_nemo.py b/examples/vision/convert_ckpt_to_nemo.py index 14876f6931f9..e0cf773f98c2 100644 --- a/examples/vision/convert_ckpt_to_nemo.py +++ b/examples/vision/convert_ckpt_to_nemo.py @@ -28,8 +28,8 @@ from argparse import ArgumentParser import torch -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.vision.models.megatron_vit_classification_models import MegatronVitClassificationModel diff --git a/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py b/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py index e827e4db73c7..f7c384809702 100644 --- a/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py +++ b/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py @@ -15,9 +15,9 @@ import os import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from torch.utils.data import DataLoader from tqdm import tqdm @@ -38,7 +38,8 @@ def main(cfg) -> None: plugins = [] strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, find_unused_parameters=False, # we don't use DDP for async grad allreduce + no_ddp_communication_hook=True, + find_unused_parameters=False, # we don't use DDP for async grad allreduce ) if cfg.get('cluster_type', None) == 'BCP': plugins.append(TorchElasticEnvironment()) @@ -82,7 +83,10 @@ def main(cfg) -> None: model.eval() val_transform = ClassificationTransform(model.cfg, (model.cfg.img_h, model.cfg.img_w), train=False) - val_data = ImageFolder(root=cfg.model.data.imagenet_val, transform=val_transform,) + val_data = ImageFolder( + root=cfg.model.data.imagenet_val, + transform=val_transform, + ) def dummy(): return @@ -91,12 +95,20 @@ def dummy(): trainer.strategy.launcher.launch(dummy, trainer=trainer) trainer.strategy.setup_environment() - test_loader = DataLoader(val_data, batch_size=cfg.model.micro_batch_size, num_workers=cfg.model.data.num_workers,) + test_loader = DataLoader( + val_data, + batch_size=cfg.model.micro_batch_size, + num_workers=cfg.model.data.num_workers, + ) autocast_dtype = torch_dtype_from_precision(trainer.precision) - with torch.no_grad(), torch.cuda.amp.autocast( - enabled=autocast_dtype in (torch.half, torch.bfloat16), dtype=autocast_dtype, + with ( + torch.no_grad(), + torch.cuda.amp.autocast( + enabled=autocast_dtype in (torch.half, torch.bfloat16), + dtype=autocast_dtype, + ), ): total = correct = 0.0 for tokens, labels in tqdm(test_loader): diff --git a/examples/vision/vision_transformer/megatron_vit_classification_infer.py b/examples/vision/vision_transformer/megatron_vit_classification_infer.py index a757eb7a1c1f..f50ccf1c325c 100644 --- a/examples/vision/vision_transformer/megatron_vit_classification_infer.py +++ b/examples/vision/vision_transformer/megatron_vit_classification_infer.py @@ -16,10 +16,10 @@ import os import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf.omegaconf import OmegaConf, open_dict from PIL import Image -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector @@ -63,7 +63,8 @@ def main(cfg) -> None: plugins = [] strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, find_unused_parameters=False, # we don't use DDP for async grad allreduce + no_ddp_communication_hook=True, + find_unused_parameters=False, # we don't use DDP for async grad allreduce ) if cfg.get('cluster_type', None) == 'BCP': plugins.append(TorchElasticEnvironment()) @@ -107,7 +108,10 @@ def main(cfg) -> None: model.eval() test_transform = ClassificationTransform(cfg.model, (model_cfg.img_h, model_cfg.img_w), train=False) - test_data = ImageFolderDataset(folder_path=cfg.data_path, transform=test_transform,) + test_data = ImageFolderDataset( + folder_path=cfg.data_path, + transform=test_transform, + ) test_loader = DataLoader(test_data, batch_size=8) def dummy(): @@ -119,8 +123,12 @@ def dummy(): autocast_dtype = torch_dtype_from_precision(trainer.precision) - with torch.no_grad(), torch.cuda.amp.autocast( - enabled=autocast_dtype in (torch.half, torch.bfloat16), dtype=autocast_dtype, + with ( + torch.no_grad(), + torch.cuda.amp.autocast( + enabled=autocast_dtype in (torch.half, torch.bfloat16), + dtype=autocast_dtype, + ), ): class_names = [] for tokens in test_loader: diff --git a/nemo/README.md b/nemo/README.md index a6025e77822a..ebc23f4d5803 100644 --- a/nemo/README.md +++ b/nemo/README.md @@ -2,7 +2,12 @@ NeMo (**Ne**ural **Mo**dules) is a toolkit for creating AI applications built ar **NeMo Core** provides common APIs all modules and models have to implement. -**NeMo Collections** +**NeMo 2.0 Collections** + +* LLM - A collection of data modules, models, configurations, and recipes for building training and parameter-efficient fine-tuning (PEFT) pipelines, including decoder-only models like those in the Llama, Gemma, and Mamba families. +* VLM - A collection of data modules, models, configurations, and recipes for training and PEFT pipelines in vision-language models. + +**NeMo 1.0 Collections** * ASR - collection of modules and models for building speech recognition networks * TTS - collection of modules and models for building speech synthesis networks diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index c63c73323797..76537a8b2b78 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -19,9 +19,9 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.callbacks import BasePredictionWriter from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf.listconfig import ListConfig -from pytorch_lightning.callbacks import BasePredictionWriter from torch.utils.data import ChainDataset from nemo.collections.asr.data import audio_to_text, audio_to_text_dali diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index f18fe02d2ed8..969966839dde 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -21,8 +21,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( @@ -62,7 +62,6 @@ from nemo.utils import logging, model_utils from nemo.utils.decorators import deprecated - __all__ = ['EncDecMultiTaskModel'] diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py index b49ef50583a7..f84ece6d24ce 100644 --- a/nemo/collections/asr/models/classification_models.py +++ b/nemo/collections/asr/models/classification_models.py @@ -21,8 +21,8 @@ from typing import Any, Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from torchmetrics import Accuracy from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py index ddcc269bedcc..1f03cec59af7 100644 --- a/nemo/collections/asr/models/clustering_diarizer.py +++ b/nemo/collections/asr/models/clustering_diarizer.py @@ -22,8 +22,8 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm from nemo.collections.asr.metrics.der import score_labels @@ -49,7 +49,6 @@ from nemo.core.classes import Model from nemo.utils import logging, model_utils - __all__ = ['ClusteringDiarizer'] _MODEL_CONFIG_YAML = "model_config.yaml" diff --git a/nemo/collections/asr/models/confidence_ensemble.py b/nemo/collections/asr/models/confidence_ensemble.py index c6b2846085af..932d221be0f8 100644 --- a/nemo/collections/asr/models/confidence_ensemble.py +++ b/nemo/collections/asr/models/confidence_ensemble.py @@ -18,8 +18,8 @@ import joblib import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 993c7dc6b298..3df6a7352c4d 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -18,8 +18,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset @@ -43,7 +43,6 @@ from nemo.utils import logging from nemo.utils.decorators import deprecated - __all__ = ['EncDecCTCModel'] diff --git a/nemo/collections/asr/models/hybrid_asr_tts_models.py b/nemo/collections/asr/models/hybrid_asr_tts_models.py index 628395e04f94..89a7e1289675 100644 --- a/nemo/collections/asr/models/hybrid_asr_tts_models.py +++ b/nemo/collections/asr/models/hybrid_asr_tts_models.py @@ -19,8 +19,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import torch +from lightning.pytorch import Trainer from omegaconf import MISSING, DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.nn.utils.rnn import pad_sequence from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs @@ -324,7 +324,9 @@ def __setattr__(self, name, value): return super().__setattr__(name, value) def setup_optimization( - self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None, + self, + optim_config: Optional[Union[DictConfig, Dict]] = None, + optim_kwargs: Optional[Dict[str, Any]] = None, ): """ Setup optimizer and scheduler. Ensure tts model is frozen. @@ -430,7 +432,8 @@ def _get_batch_spect(self, batch: Union[TextToTextBatch, TextOrAudioToTextBatch, elif isinstance(batch, TextOrAudioToTextBatch): tts_spectrogram, tts_spectrogram_len = self._get_tts_spectrogram(batch.tts_texts, batch.speakers) asr_spectrogram, asr_spectrogram_len = self.asr_model.preprocessor( - input_signal=batch.audio_signals, length=batch.audio_signal_lengths, + input_signal=batch.audio_signals, + length=batch.audio_signal_lengths, ) spectrogram = pad_sequence( diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 1d437a19a86b..7e8720ee3ad8 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 028073d7ca7f..34dd9aae5711 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -19,8 +19,8 @@ from typing import Any, List, Optional, Tuple import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs diff --git a/nemo/collections/asr/models/k2_sequence_models.py b/nemo/collections/asr/models/k2_sequence_models.py index 087e9e41b85d..b60d08afe635 100644 --- a/nemo/collections/asr/models/k2_sequence_models.py +++ b/nemo/collections/asr/models/k2_sequence_models.py @@ -14,8 +14,8 @@ from typing import List, Optional +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel @@ -76,7 +76,11 @@ def change_vocabulary(self, new_vocabulary: List[str]): @typecheck() def forward( - self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, ): """ Forward pass of the model. @@ -159,7 +163,11 @@ def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str): @typecheck() def forward( - self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, ): """ Forward pass of the model. diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 08c304e4c52c..37391879547b 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -24,8 +24,8 @@ import soundfile as sf import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from sklearn.metrics import roc_curve from torchmetrics import Accuracy from tqdm import tqdm diff --git a/nemo/collections/asr/models/msdd_models.py b/nemo/collections/asr/models/msdd_models.py index c88275dcacd3..d30411f01bcc 100644 --- a/nemo/collections/asr/models/msdd_models.py +++ b/nemo/collections/asr/models/msdd_models.py @@ -25,11 +25,11 @@ import numpy as np import torch from hydra.utils import instantiate +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, open_dict from pyannote.core import Annotation from pyannote.metrics.diarization import DiarizationErrorRate -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechMSDDInferDataset, AudioToSpeechMSDDTrainDataset diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 25890ec716c8..c92bcfaaef7a 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index ce3b6bc89bce..a6408b5e935e 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset diff --git a/nemo/collections/asr/models/ssl_models.py b/nemo/collections/asr/models/ssl_models.py index 633a00d73f5e..9150da7bf7c2 100644 --- a/nemo/collections/asr/models/ssl_models.py +++ b/nemo/collections/asr/models/ssl_models.py @@ -17,8 +17,8 @@ import torch import torch.nn as nn +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.asr.data import audio_to_text_dataset, ssl_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 089186e142bf..8d0f2b2223a3 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -22,8 +22,8 @@ import editdistance import torch import torch.distributed as dist +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from torchmetrics.text import SacreBLEUScore from tqdm.auto import tqdm diff --git a/nemo/collections/asr/parts/utils/wfst_utils.py b/nemo/collections/asr/parts/utils/wfst_utils.py index 31f394fb60ac..9dbb9fc751b2 100644 --- a/nemo/collections/asr/parts/utils/wfst_utils.py +++ b/nemo/collections/asr/parts/utils/wfst_utils.py @@ -32,7 +32,7 @@ import kaldifst # check that kaldifst package is not empty - # Note: pytorch_lightning.utilities.imports.package_available may not help here + # Note: lightning.pytorch.utilities.imports.package_available may not help here kaldifst.StdVectorFst() _KALDIFST_AVAILABLE = True except (ImportError, ModuleNotFoundError, AttributeError): diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index e1732c1658b7..60c16f756f58 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -22,8 +22,8 @@ import librosa import soundfile as sf import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from tqdm import tqdm from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index cd9f47b98096..8e2206afcef1 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -17,8 +17,8 @@ import einops import hydra import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.audio.models.audio_to_audio import AudioToAudioModel from nemo.core.classes.common import PretrainedModelInfo, typecheck diff --git a/nemo/collections/audio/parts/utils/callbacks.py b/nemo/collections/audio/parts/utils/callbacks.py index 093d5a11f419..ff975c93ecc7 100644 --- a/nemo/collections/audio/parts/utils/callbacks.py +++ b/nemo/collections/audio/parts/utils/callbacks.py @@ -16,10 +16,10 @@ import einops import torch -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.loggers.logger import Logger -from pytorch_lightning.loggers.wandb import WandbLogger +from lightning.pytorch import Callback, LightningModule, Trainer +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers.logger import Logger +from lightning.pytorch.loggers.wandb import WandbLogger from nemo.utils import logging from nemo.utils.decorators import experimental diff --git a/nemo/collections/common/callbacks/callbacks.py b/nemo/collections/common/callbacks/callbacks.py index 1a6c011c38df..754b33726faf 100644 --- a/nemo/collections/common/callbacks/callbacks.py +++ b/nemo/collections/common/callbacks/callbacks.py @@ -13,15 +13,14 @@ # limitations under the License. import time -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities import rank_zero_only # from sacrebleu import corpus_bleu class LogEpochTimeCallback(Callback): - """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log - """ + """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log""" @rank_zero_only def on_train_epoch_start(self, trainer, pl_module): diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 2f295bf67354..f866a2639d63 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -17,11 +17,11 @@ import threading from typing import Any, Dict, Iterable -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning import Callback -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.rank_zero import rank_zero_info +from lightning.pytorch import Callback +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.rank_zero import rank_zero_info class EMA(Callback): @@ -40,7 +40,11 @@ class EMA(Callback): """ def __init__( - self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False, + self, + decay: float, + validate_original_weights: bool = False, + every_n_steps: int = 1, + cpu_offload: bool = False, ): if not (0 <= decay <= 1): raise MisconfigurationException("EMA decay value must be between 0 and 1") @@ -149,7 +153,9 @@ def on_load_checkpoint( def ema_update(ema_model_tuple, current_model_tuple, decay): torch._foreach_mul_(ema_model_tuple, decay) torch._foreach_add_( - ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ema_model_tuple, + current_model_tuple, + alpha=(1.0 - decay), ) @@ -272,7 +278,13 @@ def update(self): if self.device.type == 'cpu': self.thread = threading.Thread( - target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,), + target=run_ema_update_cpu, + args=( + self.ema_params, + current_model_state, + self.decay, + self.stream, + ), ) self.thread.start() diff --git a/nemo/collections/common/metrics/perf_metrics.py b/nemo/collections/common/metrics/perf_metrics.py index d668d29c42ff..daad92ce95ea 100644 --- a/nemo/collections/common/metrics/perf_metrics.py +++ b/nemo/collections/common/metrics/perf_metrics.py @@ -15,7 +15,7 @@ from typing import Any, Dict, List, Optional import numpy as np -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from nemo.collections.common.parts.perf_metrics_utils import LLM_VOCAB_SIZE_MAP, read_tb_log from nemo.utils import logging diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index b16ac50e4d56..915f406a3e88 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -308,6 +308,132 @@ def __init__( super().__init__(data) +class InstructionTuningAudioText(_Collection): + """`AudioText` collector from asr structured json files.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='InstructionTuningText', + field_names='id context context_type context_duration question question_type answer answer_type answer_duration speaker', + ) + + def __init__( + self, + manifests_files: Union[str, List[str]], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + """Parse lists of audio files, durations and transcripts texts. + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `AudioText` constructor. + **kwargs: Kwargs to pass to `AudioText` constructor. + """ + + output_type = self.OUTPUT_TYPE + self.use_phoneme_tokenizer = use_phoneme_tokenizer + data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0 + if index_by_file_id: + self.mapping = {} + + for item in manifest.item_iter(manifests_files): + + id = item['id'] + context = item['context'] + context_duration = item['context_duration'] + context_type = item['context_type'] + question = item['question'] + question_type = item['question_type'] + speaker = item['speaker'] + answer = item['answer'] + answer_duration = item['answer_duration'] + answer_type = item['answer_type'] + task = item['task'] + + task = 'tts' if task is None else task + duration = answer_duration if task == 'tts' else context_duration + if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + # Check segment length + approx_context_len = min(self._get_len(context_type, context, context_duration) * 0.3, 400) + approx_question_len = self._get_len(question_type, question, None) + approx_answer_len = self._get_len(answer_type, answer, answer_duration) + + if ( + decoder_only_model and approx_context_len + approx_question_len + approx_answer_len >= max_seq_length + ) or (approx_context_len + approx_question_len >= max_seq_length or approx_answer_len >= max_seq_length): + duration_filtered += duration + num_filtered += 1 + continue + + total_duration += duration + data.append( + output_type( + id, + context, + context_type, + context_duration, + question, + question_type, + answer, + answer_type, + answer_duration, + speaker, + ) + ) + + if index_by_file_id: + file_id, _ = os.path.splitext(os.path.basename(context)) + if ".context" in file_id: + file_id = file_id[:-8] + if file_id not in self.mapping: + self.mapping[file_id] = [] + self.mapping[file_id].append(len(data) - 1) + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + super().__init__(data) + + def _get_len(self, field_type, data, duration_data): + if field_type == "SPEECH": + return duration_data * 76 # TODO: add explanation for the hardcoded value. + elif field_type == "TEXT": + if self.use_phoneme_tokenizer: + # Approx len is number of characters + return len(data) + else: + return len(data.split(' ')) + 3 # # TODO: add explanation for the hardcoded value. + elif field_type == "TOKENS": + return len(data) + 3 + else: + raise ValueError(f"Unknown field type {field_type}.") + + class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" diff --git a/nemo/collections/common/parts/preprocessing/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py index 1d49bd7c7019..e2ad08bd04c2 100644 --- a/nemo/collections/common/parts/preprocessing/manifest.py +++ b/nemo/collections/common/parts/preprocessing/manifest.py @@ -110,6 +110,8 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: item['audio_file'] = item.pop('audio_filename') elif 'audio_filepath' in item: item['audio_file'] = item.pop('audio_filepath') + elif 'context' in item: + item['audio_file'] = item['context'] # Video File if 'video_filename' in item: @@ -132,7 +134,9 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: item['video_file'] = get_full_path(audio_file=item['video_file'], manifest_file=manifest_file) # Duration. - if 'duration' not in item: + if 'context_duration' in item and 'duration' not in item: + item['duration'] = item['context_duration'] + elif 'duration' not in item: raise ValueError( f"Manifest file {manifest_file} has invalid json line structure: {line} without proper duration key." ) @@ -184,6 +188,15 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: orig_sr=item.get('orig_sample_rate', None), token_labels=item.get('token_labels', None), lang=item.get('lang', None), + context=item.get('context', None), + context_type=item.get('context_type', None), + context_duration=item.get('context_duration', None), + answer=item.get('answer', None), + answer_type=item.get('answer_type', None), + answer_duration=item.get('answer_duration', None), + question=item.get('question', None), + question_type=item.get('question_type', None), + task=item.get('task', None), ) return item @@ -247,7 +260,7 @@ def get_full_path( if ( (len(audio_file) < audio_file_len_limit) and not os.path.isabs(audio_file) - and not os.path.isfile(audio_file) + # and not os.path.isfile(audio_file) # Commented out because it slows down dataloading ): # If audio_file is not available and the path is not absolute, the full path is assumed # to be relative to the manifest file parent directory or data directory. diff --git a/nemo/collections/common/parts/ptl_overrides.py b/nemo/collections/common/parts/ptl_overrides.py index 0225ecd50fee..263c865f8270 100644 --- a/nemo/collections/common/parts/ptl_overrides.py +++ b/nemo/collections/common/parts/ptl_overrides.py @@ -13,11 +13,11 @@ # limitations under the License. import torch -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin class NeMoMixedPrecisionPlugin(MixedPrecisionPlugin): - def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None: + def __init__(self, init_scale: float = 2**32, growth_interval: int = 1000) -> None: super().__init__(precision=16) self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval) diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index a8ea949019c1..56a4b04dfe0f 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -25,7 +25,7 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging -__all__ = ['SentencePieceTokenizer', 'create_spt_model'] +__all__ = ['SentencePieceTokenizer', 'SentencePieceSpeechLLMTTSTokenizer', 'create_spt_model'] class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin): @@ -315,6 +315,14 @@ def vocab(self): return main_vocab + special_tokens +class SentencePieceSpeechLLMTTSTokenizer(SentencePieceTokenizer): + def add_phone_tokens_to_special_tokens(self): + for i, word in enumerate(self.vocab): + if word.startswith("p{"): + self.special_token_to_id[word] = i + self.id_to_special_token[i] = word + + def create_spt_model( data_file: str, vocab_size: int, diff --git a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py index 67a26609dd51..07747528363a 100644 --- a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py +++ b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py @@ -15,9 +15,9 @@ import logging from typing import Any, Dict, Literal +from lightning.pytorch.utilities.types import EVAL_DATALOADERS from megatron.core import parallel_state from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset -from pytorch_lightning.utilities.types import EVAL_DATALOADERS from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule diff --git a/nemo/collections/diffusion/data/diffusion_fake_datamodule.py b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py index 6cb686c1c305..a9fc7ad5b484 100644 --- a/nemo/collections/diffusion/data/diffusion_fake_datamodule.py +++ b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader from nemo.collections.diffusion.models.model import DiTConfig diff --git a/nemo/collections/diffusion/train.py b/nemo/collections/diffusion/train.py index 5428e0eeefa2..404602084b85 100644 --- a/nemo/collections/diffusion/train.py +++ b/nemo/collections/diffusion/train.py @@ -14,13 +14,13 @@ import os +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.enums import AttnMaskType -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 7bdcf88b8956..5d5f1983fea7 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -99,6 +99,9 @@ NemotronModel, NVIDIAMambaConfig8B, NVIDIAMambaHybridConfig8B, + Phi3Config, + Phi3ConfigMini, + Phi3Model, Qwen2Config, Qwen2Config1P5B, Qwen2Config7B, @@ -150,6 +153,9 @@ "Nemotron4Config15B", "Nemotron4Config340B", "NemotronConfig", + "Phi3Config", + "Phi3ConfigMini", + "Phi3Model", "SSMConfig", "BaseMambaConfig130M", "BaseMambaConfig370M", @@ -219,7 +225,7 @@ try: import nemo_run as run - from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, train, validate + from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, ptq, train, validate from nemo.collections.llm.recipes import * # noqa __all__.extend( @@ -231,6 +237,7 @@ "validate", "finetune", "generate", + "ptq", ] ) except ImportError as error: @@ -242,3 +249,10 @@ __all__.append("deploy") except ImportError as error: logging.warning(f"The deploy module could not be imported: {error}") + +try: + from nemo.collections.llm.api import evaluate + + __all__.append("evaluate") +except ImportError as error: + logging.warning(f"The evaluate module could not be imported: {error}") diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index fdceff5d959e..e7e660060f54 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os from copy import deepcopy from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Union +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from rich.console import Console from typing_extensions import Annotated import nemo.lightning as nl +from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig from nemo.lightning import ( AutoResume, NeMoLogger, @@ -68,7 +68,8 @@ def train( resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint. optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' + or an instance of TokenizerSpec. export (Optional[str]): Filename to save the exported checkpoint after training. model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. @@ -84,7 +85,7 @@ def train( >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) - >>> train(model, data, trainer, tokenizer="data") + >>> llm.train(model, data, trainer, tokenizer="data") PosixPath('/path/to/log_dir') """ app_state = _setup( @@ -186,7 +187,7 @@ def finetune( >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) - >>> finetune(model, data, trainer, peft=llm.peft.LoRA()]) + >>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA()]) PosixPath('/path/to/log_dir') """ @@ -224,7 +225,8 @@ def validate( resume (Optional[AutoResume]): Resume from a checkpoint for validation. optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' + or an instance of TokenizerSpec. model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. Returns: @@ -237,7 +239,7 @@ def validate( >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) - >>> validate(model, data, trainer, tokenizer="data") + >>> llm.validate(model, data, trainer, tokenizer="data") PosixPath('/path/to/log_dir') """ app_state = _setup( @@ -256,84 +258,67 @@ def validate( return app_state.exp_dir -def get_trtllm_deployable( - nemo_checkpoint, - model_type, - triton_model_repository, - num_gpus, - tensor_parallelism_size, - pipeline_parallelism_size, - max_input_len, - max_output_len, - max_batch_size, - dtype, -): - from nemo.export.tensorrt_llm import TensorRTLLM +@run.cli.entrypoint(name="ptq", namespace="llm") +def ptq( + nemo_checkpoint: str, + calib_tp: int = 1, + calib_pp: int = 1, + quantization_config: Annotated[Optional[QuantizationConfig], run.Config[QuantizationConfig]] = None, + export_config: Optional[Union[ExportConfig, run.Config[ExportConfig]]] = None, +) -> Path: + # TODO: Fix "nemo_run.cli.cli_parser.CLIException: An unexpected error occurred (Argument: , Context: {})" + """ + Applies Post-Training Quantization (PTQ) for a model using the specified quantization and export configs. It runs + calibration for a small dataset to collect scaling factors low-precision GEMMs used by desired quantization method. + This function produces TensorRT-LLM checkpoint ready for deployment using nemo.export and nemo.deploy modules + or direcly using TensorRT-LLM library. + The function can be used through the NeMo CLI in the following way: + ```bash + # Run calibration using tensor parallel set to 8 and export quantized checkpoint with tensor parallel equal 2 + nemo llm ptq nemo_checkpoint=/models/Llama-3-70B \ + export_config.path=/models/Llama-3-70B-FP8 \ + calib_tp=8 \ + export_config.inference_tensor_parallel=2 + # Choose different quantization method, for example, INT8 SmoothQuant + nemo llm ptq nemo_checkpoint=/models/Llama-3-8B \ + export_config.path=/models/Llama-3-8B-INT8_SQ \ + quantization_config.algorithm=int8_sq + ``` + Args: + nemo_checkpoint (str): The path to model to be quantized. + calib_tp (int): Calibration tensor parallelism. + calib_pp (int): Calibration pipeline parallelism. + quantization_config (QuantizationConfig): Configuration for quantization algorithm. + export_config (ExportConfig): Export configuration for TensorRT-LLM checkpoint. + Returns: + Path: The path where the quantized checkpoint has been saved after calibration. + """ + if export_config.path is None: + raise ValueError("The export_config.path needs to be specified, got None.") - if triton_model_repository is None: - trt_llm_path = "/tmp/trt_llm_model_dir/" - Path(trt_llm_path).mkdir(parents=True, exist_ok=True) - else: - trt_llm_path = triton_model_repository + from nemo.collections.llm import quantization - if nemo_checkpoint is None and triton_model_repository is None: - raise ValueError( - "The provided model repository is not a valid TensorRT-LLM model " - "directory. Please provide a --nemo_checkpoint or a TensorRT-LLM engine." - ) + quantizer = quantization.Quantizer(quantization_config, export_config) - if nemo_checkpoint is None and not os.path.isdir(triton_model_repository): - raise ValueError( - "The provided model repository is not a valid TensorRT-LLM model " - "directory. Please provide a --nemo_checkpoint or a valid TensorRT-LLM engine." - ) + model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp) - if nemo_checkpoint is not None and model_type is None: - raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") + model = quantizer.quantize(model) - trt_llm_exporter = TensorRTLLM( - model_dir=trt_llm_path, - load_model=(nemo_checkpoint is None), - ) + quantizer.export(model, nemo_checkpoint) - if nemo_checkpoint is not None: - try: - logging.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.") - trt_llm_exporter.export( - nemo_checkpoint_path=nemo_checkpoint, - model_type=model_type, - n_gpus=num_gpus, - tensor_parallelism_size=tensor_parallelism_size, - pipeline_parallelism_size=pipeline_parallelism_size, - max_input_len=max_input_len, - max_output_len=max_output_len, - max_batch_size=max_batch_size, - dtype=dtype, - ) - except Exception as error: - raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) - - return trt_llm_exporter - - -def store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response): - args_dict = { - "triton_service_ip": triton_http_address, - "triton_service_port": triton_port, - "triton_request_timeout": triton_request_timeout, - "openai_format_response": openai_format_response, - } - with open("nemo/deploy/service/config.json", "w") as f: - json.dump(args_dict, f) + console = Console() + console.print(f"[green]✓ PTQ succeded, quantized checkpoint exported to {export_config.path}[/green]") + + return export_config.path @run.cli.entrypoint(namespace="llm") def deploy( nemo_checkpoint: Path = None, model_type: str = "llama", - triton_model_name: str = "xxx", + triton_model_name: str = 'triton_model', triton_model_version: Optional[int] = 1, - triton_port: int = 8080, + triton_port: int = 8000, triton_http_address: str = "0.0.0.0", triton_request_timeout: int = 60, triton_model_repository: Path = None, @@ -344,21 +329,61 @@ def deploy( max_input_len: int = 256, max_output_len: int = 256, max_batch_size: int = 8, - start_rest_service: bool = False, + start_rest_service: bool = True, rest_service_http_address: str = "0.0.0.0", - rest_service_port: int = 8000, - openai_format_response: bool = False, + rest_service_port: int = 8080, + openai_format_response: bool = True, + output_generation_logits: bool = True, ): + """ + Deploys nemo model on a PyTriton server by converting the nemo ckpt to trtllm. + Also starts rest service that is used to send OpenAI API compatible input request + to the PyTiton server. + + Args: + nemo_checkpoint (Path): Path for nemo checkpoint. + model_type (str): Type of the model. Choices: gpt, llama, falcon, starcoder. Default: llama. + triton_model_name (str): Name for the model that gets deployed on PyTriton. Please ensure that the same model + name is passed to the evalute method for the model to be accessible while sending evalution requests. + Default: 'triton_model'. + triton_model_version (Optional[int]): Version for the triton model. Default: 1. + triton_port (int): Port for the PyTriton server. Default: 8000. + triton_http_address (str): HTTP address for the PyTriton server. Default: "0.0.0.0". + triton_request_timeout (int): Timeout in seconds for Triton server. Default: 60. + triton_model_repository (Path): Folder for the trt-llm conversion, trt-llm engine gets saved in this specified + path. If None, saves it in /tmp dir. Default: None. + num_gpus (int): Number of GPUs for export to trtllm and deploy. Default: 1. + tensor_parallelism_size (int): Tensor parallelism size. Default: 1. + pipeline_parallelism_size (int): Pipeline parallelism size. Default: 1. + dtype (str): dtype of the TensorRT-LLM model. Default: "bfloat16". + max_input_len (int): Max input length of the model. Default: 256. + max_output_len (int): Max output length of the model. Default: 256. + max_batch_size (int): Max batch size of the model. Default: 8. + start_rest_service (bool): Start rest service that is used to send evaluation requests to the PyTriton server. + Needs to be True to be able to run evaluation. Default: True. + rest_service_http_address (str): HTTP address for the rest service. Default: "0.0.0.0". + rest_service_port (int): Port for the rest service. Default: 8080. + openai_format_response (bool): Return the response from PyTriton server in OpenAI compatible format. Needs to + be True while running evaluation. Default: True. + output_generation_logits (bool): If True builds trtllm engine with gather_generation_logits set to True. + generation_logits are used to compute the logProb of the output token. Default: True. + """ + from nemo.collections.llm import deploy from nemo.deploy import DeployPyTriton + deploy.unset_environment_variables() if start_rest_service: if triton_port == rest_service_port: logging.error("REST service port and Triton server port cannot use the same port.") return - # Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py - store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response) - - triton_deployable = get_trtllm_deployable( + # Store triton ip, port and other args relevant for REST API as env vars to be accessible by rest_model_api.py + os.environ['TRITON_HTTP_ADDRESS'] = triton_http_address + os.environ['TRITON_PORT'] = str(triton_port) + os.environ['TRITON_REQUEST_TIMEOUT'] = str(triton_request_timeout) + os.environ['OPENAI_FORMAT_RESPONSE'] = str(openai_format_response) + os.environ['OUTPUT_GENERATION_LOGITS'] = str(output_generation_logits) + + triton_deployable = deploy.get_trtllm_deployable( nemo_checkpoint, model_type, triton_model_repository, @@ -369,6 +394,7 @@ def deploy( max_output_len, max_batch_size, dtype, + output_generation_logits, ) try: @@ -383,6 +409,7 @@ def deploy( logging.info("Triton deploy function will be called.") nm.deploy() + nm.run() except Exception as error: logging.error("Error message has occurred during deploy function. Error message: " + str(error)) return @@ -416,6 +443,81 @@ def deploy( nm.stop() +def evaluate( + nemo_checkpoint_path: Path, + url: str = "http://0.0.0.0:8080/v1", + model_name: str = "triton_model", + eval_task: str = "gsm8k", + num_fewshot: Optional[int] = None, + limit: Optional[Union[int, float]] = None, + bootstrap_iters: int = 100000, + # inference params + max_tokens_to_generate: Optional[int] = 256, + temperature: Optional[float] = 0.000000001, + top_p: Optional[float] = 0.0, + top_k: Optional[int] = 1, + add_bos: Optional[bool] = False, +): + """ + Evaluates nemo model deployed on PyTriton server (via trtllm) using lm-evaluation-harness + (https://github.com/EleutherAI/lm-evaluation-harness/tree/main). + + Args: + nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt + which is required to tokenize the evaluation input and output prompts. + url (str): rest service url and port that were used in the deploy method above in the format: + http://{rest_service_http}:{rest_service_port}. Post requests with evaluation input prompts + (from lm-eval-harness) are sent to this url which is then passed to the model deployed on PyTriton server. + The rest service url and port serve as the entry point to evaluate model deployed on PyTriton server. + model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as + triton_model_name passed to the deploy method above to be able to launch evaluation. Deafult: "triton_model". + eval_task (str): task to be evaluated on. For ex: "gsm8k", "gsm8k_cot", "mmlu", "lambada". Default: "gsm8k". + These are the tasks that are supported currently. Any other task of type generate_until or loglikelihood from + lm-evaluation-harness can be run, but only the above mentioned ones are tested. Tasks of type + loglikelihood_rolling are not supported yet. + num_fewshot (int): number of examples in few-shot context. Default: None. + limit (Union[int, float]): Limit the number of examples per task. If <1 (i.e float val between 0 and 1), limit + is a percentage of the total number of examples. If int say x, then run evaluation only on x number of samples + from the eval dataset. Default: None, which means eval is run the entire dataset. + bootstrap_iters (int): Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 + for no stderr calculations to be performed. Default: 100000. + # inference params + max_tokens_to_generate (int): max tokens to generate. Default: 256. + temperature: Optional[float]: float value between 0 and 1. temp of 0 indicates greedy decoding, where the token + with highest prob is chosen. Temperature can't be set to 0.0 currently, due to a bug with TRTLLM + (# TODO to be investigated). Hence using a very samll value as the default. Default: 0.000000001. + top_p: Optional[float]: float value between 0 and 1. limits to the top tokens within a certain probability. + top_p=0 means the model will only consider the single most likely token for the next prediction. Default: 0.0. + top_k: Optional[int]: limits to a certain number (K) of the top tokens to consider. top_k=1 means the model + will only consider the single most likely token for the next prediction. Default: 1 + add_bos: Optional[bool]: whether a special token representing the beginning of a sequence should be added when + encoding a string. Default: False since typically for CausalLM its set to False. If needed set add_bos to True. + """ + try: + # lm-evaluation-harness import + from lm_eval import evaluator + except ImportError: + raise ImportError( + "Please ensure that lm-evaluation-harness is installed in your env as it is required " "to run evaluations" + ) + + from nemo.collections.llm import evaluation + + # Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt. + tokenizer = io.load_context(nemo_checkpoint_path + '/context', subpath="model").tokenizer + # Wait for rest service to be ready before starting evaluation + evaluation.wait_for_rest_service(rest_url=f"{url}/v1/health") + # Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate + model = evaluation.NeMoFWLMEval( + model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos + ) + results = evaluator.simple_evaluate( + model=model, tasks=eval_task, limit=limit, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters + ) + + print("score", results['results'][eval_task]) + + @run.cli.entrypoint(name="import", namespace="llm") def import_ckpt( model: pl.LightningModule, diff --git a/nemo/collections/llm/deploy/__init__.py b/nemo/collections/llm/deploy/__init__.py new file mode 100644 index 000000000000..24c102bfa0d2 --- /dev/null +++ b/nemo/collections/llm/deploy/__init__.py @@ -0,0 +1,3 @@ +from nemo.collections.llm.deploy.base import get_trtllm_deployable, unset_environment_variables + +__all__ = ["unset_environment_variables", "get_trtllm_deployable"] diff --git a/nemo/collections/llm/deploy/base.py b/nemo/collections/llm/deploy/base.py new file mode 100644 index 000000000000..e21198f5884b --- /dev/null +++ b/nemo/collections/llm/deploy/base.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +from pathlib import Path + +from nemo.utils import logging + + +def unset_environment_variables() -> None: + """ + SLURM_, PMI_, PMIX_ Variables are needed to be unset for trtllm export to work + on clusters. This method takes care of unsetting these env variables + """ + logging.info("Unsetting all SLURM_, PMI_, PMIX_ Variables") + + # Function to unset variables with a specific prefix + def unset_vars_with_prefix(prefix): + unset_vars = [] + cmd = f"env | grep ^{prefix} | cut -d= -f1" + result = subprocess.run(cmd, shell=True, capture_output=True, text=True) + vars_to_unset = result.stdout.strip().split('\n') + for var in vars_to_unset: + if var: # Check if the variable name is not empty + os.environ.pop(var, None) + unset_vars.append(var) + return unset_vars + + # Collect all unset variables across all prefixes + all_unset_vars = [] + + # Unset variables for each prefix + for prefix in ['SLURM_', 'PMI_', 'PMIX_']: + unset_vars = unset_vars_with_prefix(prefix) + all_unset_vars.extend(unset_vars) + + if all_unset_vars: + logging.info(f"Unset env variables: {', '.join(all_unset_vars)}") + else: + logging.info("No env variables were unset.") + + +def get_trtllm_deployable( + nemo_checkpoint, + model_type, + triton_model_repository, + num_gpus, + tensor_parallelism_size, + pipeline_parallelism_size, + max_input_len, + max_output_len, + max_batch_size, + dtype, + output_generation_logits, +): + """ + Exports the nemo checkpoint to trtllm and returns trt_llm_exporter that is used to deploy on PyTriton. + """ + from nemo.export.tensorrt_llm import TensorRTLLM + + if triton_model_repository is None: + trt_llm_path = "/tmp/trt_llm_model_dir/" + Path(trt_llm_path).mkdir(parents=True, exist_ok=True) + else: + trt_llm_path = triton_model_repository + + if nemo_checkpoint is None and triton_model_repository is None: + raise ValueError( + "The provided model repository is not a valid TensorRT-LLM model " + "directory. Please provide a --nemo_checkpoint or a TensorRT-LLM engine." + ) + + if nemo_checkpoint is None and not os.path.isdir(triton_model_repository): + raise ValueError( + "The provided model repository is not a valid TensorRT-LLM model " + "directory. Please provide a --nemo_checkpoint or a valid TensorRT-LLM engine." + ) + + if nemo_checkpoint is not None and model_type is None: + raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") + + trt_llm_exporter = TensorRTLLM( + model_dir=trt_llm_path, + load_model=(nemo_checkpoint is None), + ) + + if nemo_checkpoint is not None: + try: + logging.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.") + trt_llm_exporter.export( + nemo_checkpoint_path=nemo_checkpoint, + model_type=model_type, + n_gpus=num_gpus, + tensor_parallelism_size=tensor_parallelism_size, + pipeline_parallelism_size=pipeline_parallelism_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + dtype=dtype, + gather_generation_logits=output_generation_logits, + ) + except Exception as error: + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) + + return trt_llm_exporter diff --git a/nemo/collections/llm/evaluation/__init__.py b/nemo/collections/llm/evaluation/__init__.py new file mode 100644 index 000000000000..3012689bb8da --- /dev/null +++ b/nemo/collections/llm/evaluation/__init__.py @@ -0,0 +1,3 @@ +from nemo.collections.llm.evaluation.base import NeMoFWLMEval, wait_for_rest_service + +__all__ = ["NeMoFWLMEval", "wait_for_rest_service"] diff --git a/nemo/collections/llm/evaluation/base.py b/nemo/collections/llm/evaluation/base.py new file mode 100644 index 000000000000..b1734d6f4d43 --- /dev/null +++ b/nemo/collections/llm/evaluation/base.py @@ -0,0 +1,210 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import requests +import torch +import torch.nn.functional as F +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from requests.exceptions import RequestException + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.utils import logging + + +class NeMoFWLMEval(LM): + """ + NeMoFWLMEval is a wrapper class subclassing lm_eval.api.model.LM class, that defines how lm_eval interfaces with + NeMo model deployed on PyTriton server. + Created based on: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.4/docs/model_guide.md + """ + + def __init__(self, model_name, api_url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos): + self.model_name = model_name + self.api_url = api_url + self.tokenizer = tokenizer + self.max_tokens_to_generate = max_tokens_to_generate + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.add_bos = add_bos + super().__init__() + + def _generate_tokens_logits(self, payload, return_text: bool = False, return_logits: bool = False): + """ + A private method that sends post request to the model on PyTriton server and returns either generated text or + logits. + """ + # send a post request to /v1/completions/ endpoint with the payload + response = requests.post(f"{self.api_url}/v1/completions/", json=payload) + response_data = response.json() + + if 'error' in response_data: + raise Exception(f"API Error: {response_data['error']}") + + # Assuming the response is in OpenAI format + if return_text: + # in case of generate_until tasks return just the text + return response_data['choices'][0]['text'] + + if return_logits: + # in case of loglikelihood tasks return the logits + return response_data['choices'][0]['generation_logits'] + + def tokenizer_type(self, tokenizer): + """ + Returns the type of the tokenizer. + """ + if isinstance(tokenizer, AutoTokenizer): + return "AutoTokenizer" + elif isinstance(tokenizer, SentencePieceTokenizer): + return "SentencePieceTokenizer" + else: + raise ValueError( + "Tokenizer type is not one of SentencePieceTokenizer or HF's AutoTokenizer. Please check " + "how to handle special tokens for this tokenizer" + ) + + def loglikelihood(self, requests: list[Instance]): + """ + Defines the loglikelihood request. Takes input requests of type list[Instance] where Instance is a dataclass + defined in lm_eval.api.instance. Each Instance conists of the input prompt, output prompt, request type(here + loglikelihood) and other relevant args like few shot samples. + """ + special_tokens_kwargs = {} + tokenizer_type = self.tokenizer_type(self.tokenizer) + if tokenizer_type == "SentencePieceTokenizer": + special_tokens_kwargs['add_bos'] = self.add_bos + elif tokenizer_type == "AutoTokenizer": + special_tokens_kwargs['add_special_tokens'] = self.add_bos + + results = [] + for request in requests: + # get the input prompt from the request + context = request.arguments[0] + # get the output prompt from the request + continuation = request.arguments[1] + # get encoded tokens of continuation + continuation_enc = self.tokenizer.tokenizer.encode(continuation, **special_tokens_kwargs) + # for SentencePeice consider the encoded tokens from the 2nd token since first encoded token is space. + if self.tokenizer_type(self.tokenizer) == "SentencePieceTokenizer": + continuation_enc = continuation_enc[1:] + num_cont_tokens = len(continuation_enc) + # Update self.max_tokens_to_generate with number of continuation tokens (or output tokens) in the request + self.max_tokens_to_generate = num_cont_tokens + # Create payload to query the model deployed on PyTriton server + payload = { + "model": self.model_name, + "prompt": context, + "max_tokens": self.max_tokens_to_generate, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + # Get the logits from the model + generation_logits = self._generate_tokens_logits(payload, return_logits=True) + # Convert generation_logits to torch tensor to easily get logprobs wo manual implementation of log_softmax + multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1) + # Convert encoded continuation tokens to torch tensor + cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0) + # Get the greedy token from the logits (i.e token with the highest prob) + greedy_tokens = multi_logits.argmax(dim=-1) + # Check if all greedy_tokens match the the actual continuation tokens + is_greedy = (greedy_tokens == cont_toks).all() + # Get the logits corresponding to the actual continuation tokens + logits = torch.gather(multi_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) + # result is tuple of logProb of generating the continuation token and is_greedy + result = (float(logits.sum()), bool(is_greedy)) + + results.append(result) + + return results + + def loglikelihood_rolling(self, requests: list[Instance]): + """ + Defines the loglikelihood_rolling request type. Yet to be implemented. + """ + pass + + def generate_until(self, inputs: list[Instance]): + """ + Defines the generate_until request type. Takes input requests of type list[Instance] where Instance is a + dataclass defined in lm_eval.api.instance. Each Instance conists of the input prompt, output prompt, request + type(here loglikelihood) and other relevant args like few shot samples. + """ + results = [] + for instance in inputs: + # Access the 'arguments' attribute of the Instance which contains the input prompt string + prompt = instance.arguments[0] + # Create payload to query the model deployed on PyTriton server + payload = { + "model": self.model_name, + "prompt": prompt, + "max_tokens": self.max_tokens_to_generate, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + # Get the text generated by the model + generated_text = self._generate_tokens_logits(payload, return_text=True) + + results.append(generated_text) + + return results + + +def wait_for_rest_service(rest_url, max_retries=60, retry_interval=2): + """ + Wait for REST service to be ready. + + Args: + rest_url (str): URL of the REST service's health endpoint + max_retries (int): Maximum number of retry attempts. Defaul: 60. + retry_interval (int): Time to wait between retries in seconds. Default: 2. + + Returns: + bool: True if rest service is ready, False otherwise + """ + + def check_service(url): + """ + Check if the service is ready by making a GET request to its health endpoint. + + Args: + url (str): URL of the service's health endpoint + + Returns: + bool: True if the service is ready, False otherwise + """ + try: + response = requests.get(url, timeout=5) + return response.status_code == 200 + except RequestException: + return False + + for _ in range(max_retries): + rest_ready = check_service(rest_url) + + if rest_ready: + logging.info("REST service is ready.") + return True + + logging.info(f"REST Service not ready yet. Retrying in {retry_interval} seconds...") + time.sleep(retry_interval) + + logging.info("Timeout: REST service did not become ready.") + return False diff --git a/nemo/collections/llm/gpt/data/api.py b/nemo/collections/llm/gpt/data/api.py index 74ecb5272ac2..2ebb30e781d1 100644 --- a/nemo/collections/llm/gpt/data/api.py +++ b/nemo/collections/llm/gpt/data/api.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl from nemo.collections.llm.gpt.data.dolly import DollyDataModule from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index 9d16ea8aa021..8fcef72f3bd9 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl from torch.utils.data import DataLoader from nemo.collections.common.tokenizers import AutoTokenizer diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 5c6b71c74797..46562b6e72c8 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.utils.data import DataLoader from nemo.lightning.pytorch.plugins import MegatronDataSampler diff --git a/nemo/collections/llm/gpt/data/mock.py b/nemo/collections/llm/gpt/data/mock.py index 5678597eda0b..f6b4e26ca355 100644 --- a/nemo/collections/llm/gpt/data/mock.py +++ b/nemo/collections/llm/gpt/data/mock.py @@ -14,10 +14,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index cfacde118b89..f659ce72796c 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -18,8 +18,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -import pytorch_lightning as pl -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from nemo.lightning.data import WrappedDataLoader diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index b42ceac564bc..152309536f5b 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -79,6 +79,7 @@ NemotronConfig, NemotronModel, ) +from nemo.collections.llm.gpt.model.phi3mini import Phi3Config, Phi3ConfigMini, Phi3Model from nemo.collections.llm.gpt.model.qwen2 import ( Qwen2Config, Qwen2Config1P5B, @@ -140,6 +141,9 @@ "Nemotron3Config22B", "Nemotron4Config340B", "NemotronModel", + "Phi3Config", + "Phi3ConfigMini", + "Phi3Model", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 6b158a33b226..8c3b47835ab1 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index eada3f4c3eb8..4d9e7ae026d5 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn.functional as F from transformers import AutoModelForCausalLM diff --git a/nemo/collections/llm/gpt/model/mistral.py b/nemo/collections/llm/gpt/model/mistral.py index a71042e2ba6f..0aa611b4454e 100644 --- a/nemo/collections/llm/gpt/model/mistral.py +++ b/nemo/collections/llm/gpt/model/mistral.py @@ -16,7 +16,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, List, Optional -import pytorch_lightning as pl import torch import torch.nn.functional as F from torch import nn diff --git a/nemo/collections/llm/gpt/model/phi3mini.py b/nemo/collections/llm/gpt/model/phi3mini.py new file mode 100644 index 000000000000..eb0b9c758dd7 --- /dev/null +++ b/nemo/collections/llm/gpt/model/phi3mini.py @@ -0,0 +1,258 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.lightning import OptimizerModule, io, teardown +from nemo.lightning.pytorch.utils import dtype_from_hf + + +@dataclass +class Phi3Config(GPTConfig): + # pylint: disable=C0115,C0116 + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + seq_length: int = 4096 + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + share_embeddings_and_output_weights: bool = False + + +@dataclass +class Phi3ConfigMini(Phi3Config): + # pylint: disable=C0115,C0116 + num_layers: int = 32 + hidden_size: int = 3072 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 32 + num_query_groups: int = 32 + rotary_base: float = 10000.0 + vocab_size: int = 32064 + + +class Phi3Model(GPTModel): + # pylint: disable=C0115,C0116 + def __init__( + self, + config: Optional[Phi3Config] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(config or Phi3Config(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + + +@io.model_importer(Phi3Model, "hf") +class HFPhi3Importer(io.ModelConnector["Phi3ForCausalLM", Phi3Model]): + # pylint: disable=C0115,C0116 + def init(self) -> Phi3Model: + return Phi3Model(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import Phi3ForCausalLM + + # Check if the source is valid model identifier or path + try: + source = Phi3ForCausalLM.from_pretrained(str(self), torch_dtype='auto') + except Exception as e: + raise ValueError(f"Failed to load the model from source '{self}': {e}") + + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Phi3 model to Nemo, model saved to {output_path} in {source.dtype}.") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + # pylint: disable=C0115,C0116 + # Define mapping for mini-4k-instruct + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.self_attn.qkv_proj.weight": "decoder.layers.*.self_attention.linear_qkv.weight", + "model.layers.*.mlp.gate_up_proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + + @property + def tokenizer(self): + # pylint: disable=C0115,C0116 + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(self.save_hf_tokenizer_assets(str(self))) + + @property + def config(self) -> Phi3Config: + # pylint: disable=C0115,C0116 + from transformers import Phi3Config as HFPhi3Config + + source = HFPhi3Config.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = Phi3Config( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + fp16=(dtype_from_hf(source) == torch.float16), + bf16=(dtype_from_hf(source) == torch.bfloat16), + params_dtype=dtype_from_hf(source), + ) + print("output:", output) + return output + + +@io.model_exporter(Phi3Model, "hf") +class HFPhi3Exporter(io.ModelConnector[Phi3Model, "Phi3ForCausalLM"]): + # pylint: disable=C0115,C0116 + def init(self) -> "Phi3ForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target.cpu().save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + # pylint: disable=C0115,C0116 + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + # Convert source weights to target dtype if needed + for name, param in source.state_dict().items(): + if param.dtype != target.state_dict()[name].dtype: + param.data = param.data.to(target.state_dict()[name].dtype) + + return io.apply_transforms(source, target, mapping=mapping) + + @property + def tokenizer(self): + # pylint: disable=C0115,C0116 + return io.load_context(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFPhi3Config": + # pylint: disable=C0115,C0116 + source: Phi3Config = io.load_context(str(self)).model.config + + from transformers import Phi3Config as HFPhi3Config + + return HFPhi3Config( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=0.02, + rms_norm_eps=1e-05, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key="model.layers.*.self_attn.qkv_proj.weight", + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, qkv_weight): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_size = megatron_config.kv_channels + + old_tensor_shape = qkv_weight.size() + new_q_tensor_shape = (head_num, head_size, old_tensor_shape[1]) + new_kv_tensor_shape = (num_query_groups, head_size, old_tensor_shape[1]) + q, k, v = qkv_weight.split( + [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0 + ) + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights = torch.empty((0, head_size, old_tensor_shape[1])).type_as(qkv_weight) + for i in range(num_query_groups): + qkv_weights = torch.cat((qkv_weights, q[i * heads_per_group : (i + 1) * heads_per_group, :, :])) + qkv_weights = torch.cat((qkv_weights, k[i : i + 1, :, :])) + qkv_weights = torch.cat((qkv_weights, v[i : i + 1, :, :])) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), # phi-3-mini-4k-instruct + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0) + + +__all__ = ["Phi3Config", "Phi3ConfigMini", "Phi3Model"] diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py index 55d865ec238b..8a3cbc925dad 100644 --- a/nemo/collections/llm/inference/base.py +++ b/nemo/collections/llm/inference/base.py @@ -16,9 +16,10 @@ from pathlib import Path from typing import Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.distributed +from lightning.pytorch.trainer.states import TrainerFn from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.engines.mcore_engine import MCoreEngine from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( @@ -31,7 +32,6 @@ SimpleTextGenerationController, ) from megatron.core.transformer.module import MegatronModule -from pytorch_lightning.trainer.states import TrainerFn import nemo.lightning as nl from nemo.collections.llm.peft import LoRA diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 2f3e0e1e986e..45f72f06741e 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -15,6 +15,7 @@ import os import shutil from dataclasses import dataclass +from pathlib import Path from typing import Optional, Union import torch @@ -75,17 +76,20 @@ class QuantizationConfig: @dataclass class ExportConfig: - """Inference configuration for the quantized TensorRT-LLM engine""" + """Inference configuration for the quantized TensorRT-LLM checkpoint.""" - path: str + path: Union[Path, str] dtype: Union[str, int] = "bf16" decoder_type: Optional[str] = None inference_tensor_parallel: int = 1 inference_pipeline_parallel: int = 1 + def __post_init__(self): + self.path = Path(self.path) + def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: - """Infers the modelopt decoder type from GPTConfig class""" + """Infers the modelopt decoder type from GPTConfig class.""" mapping = [ (llm.Baichuan2Config, "baichuan"), (llm.ChatGLMConfig, "chatglm"), @@ -109,17 +113,17 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: class Quantizer: - """Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints. + """Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints. PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving. The process consist of several steps: 1. Loading a Nemo model from disk using appropriate parallelism strategy 2. Calibrating the model to obtain appropriate algorithm-specific scaling factors - 3. Producing output directory + 3. Producing an output directory with a quantized checkpoint and a tokenizer The output directory produced is intended to be consumed by TensorRT-LLM toolbox - for efficient inference. This can be achieved using NeMo inference containers. + for efficient inference. This can be achieved using nemo.export.tensorrt_llm module. """ def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): @@ -231,6 +235,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): def create_megatron_forward_loop( self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None ): + """Create a forward loop for over a given data iterator.""" from megatron.core.pipeline_parallel.schedules import get_forward_backward_func forward_backward_func = get_forward_backward_func() @@ -262,13 +267,13 @@ def loop(model): return loop def export(self, model: llm.GPTModel, model_dir: str) -> None: + """Export model to a TensorRT-LLM checkpoint.""" assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate # TODO: Support megatron_amp_O2 export_dir = self.export_config.path - use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or ( - model.config.pipeline_model_parallel_size > 1 - ) + + use_nfs_workspace = model.config.pipeline_model_parallel_size > 1 export_tensorrt_llm_checkpoint( model=get_unwrapped_mcore_model(model), decoder_type=self._get_decoder_type(model.config), @@ -278,15 +283,17 @@ def export(self, model: llm.GPTModel, model_dir: str) -> None: inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, use_nfs_workspace=use_nfs_workspace, ) + dist.barrier() # Save the model context in order to restore its tokenizer later. The destination # path is "nemo_context" as this name is used in nemo.export to setup tokenizer. - shutil.copytree( - os.path.join(model_dir, CONTEXT_PATH), - os.path.join(export_dir, "nemo_context"), - dirs_exist_ok=True, - ) - logging.info(f"Model context saved.") + if dist.get_rank() == 0: + shutil.copytree( + os.path.join(model_dir, CONTEXT_PATH), + os.path.join(export_dir, "nemo_context"), + dirs_exist_ok=True, + ) + logging.info("Model context saved.") logging.info(f"Export succeeded, model has been exported to {export_dir}.") @@ -294,7 +301,7 @@ def export(self, model: llm.GPTModel, model_dir: str) -> None: def get_calib_data_iter( data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512 ): - """Creates a sample data iterator for calibration""" + """Creates a sample data iterator for calibration.""" if data == "wikitext": dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") text_column = "text" @@ -314,6 +321,8 @@ def get_calib_data_iter( def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): + """Create a function that provides iterator over a given dataset.""" + def _iterator(): CHARACTERS_PER_TOKEN = 4 diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index c4c533fe38d0..bdfccb208d06 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -18,6 +18,7 @@ from nemo import lightning as nl from nemo.collections import llm +from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.utils import logging @@ -42,25 +43,44 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: return model_cfg -def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: +def load_with_modelopt_layer_spec( + nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1, inference_only: bool = True +): + # TODO: setting ddp="pytorch" with manually deleting model.optim is a hackish way to disable DDP initialization. Needs a systematic solution. + if inference_only: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=calib_tp, + pipeline_model_parallel_size=calib_pp, + pipeline_dtype=torch.bfloat16, + ckpt_load_optimizer=False, + ckpt_parallel_save_optim=False, + setup_optimizers=False, + lazy_init=True, + ddp="pytorch", + ) + else: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.bfloat16 + ) + trainer = nl.Trainer( devices=calib_tp, num_nodes=calib_pp, - strategy=nl.MegatronStrategy( - tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.bfloat16 - ), - plugins=nl.MegatronMixedPrecision(precision='bf16', pipeline_dtype=torch.bfloat16, autocast_enabled=True), + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision='bf16', params_dtype=torch.bfloat16, autocast_enabled=True), ) - fabric = trainer.to_fabric() - fabric.launch() - model_path = Path(nemo_checkpoint_path) - model = nl.io.load_context(ckpt_to_context_subdir(model_path)).model + model = nl.io.load_context(path=ckpt_to_context_subdir(model_path), subpath="model") model.config = quantizable_model_config(model.config) - return fabric.load_model(nemo_checkpoint_path, model=model) + + if inference_only: + del model.optim + + _setup_trainer_and_restore_model(nemo_checkpoint_path, trainer, model) + return model -def get_unwrapped_mcore_model(model: llm.GPTModel): +def get_unwrapped_mcore_model(model): from megatron.core.models.gpt import GPTModel as MCoreGPTModel unwrapped_model = model diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 8f772e3da5b7..e76729d5e31a 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -56,6 +56,7 @@ nemotron4_15b_16k, nemotron4_15b_64k, nemotron4_340b, + phi3_mini_4k_instruct, qwen2, qwen2_1p5b, qwen2_7b, @@ -111,6 +112,7 @@ "nemotron4_15b_16k", "nemotron4_15b_64k", "nemotron4_340b", + "phi3_mini_4k_instruct", "t5_220m", "t5_3b", "t5_11b", diff --git a/nemo/collections/llm/recipes/baichuan2_7b.py b/nemo/collections/llm/recipes/baichuan2_7b.py index 20de2c73f9dd..823f6e07cd57 100644 --- a/nemo/collections/llm/recipes/baichuan2_7b.py +++ b/nemo/collections/llm/recipes/baichuan2_7b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import Baichuan2Config7B, Baichuan2Model diff --git a/nemo/collections/llm/recipes/chatglm3_6b.py b/nemo/collections/llm/recipes/chatglm3_6b.py index ef815a0851fc..b6c640372074 100644 --- a/nemo/collections/llm/recipes/chatglm3_6b.py +++ b/nemo/collections/llm/recipes/chatglm3_6b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import ChatGLM3Config6B, ChatGLMModel diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py index a060046a8bdf..f05fd7cb2d13 100644 --- a/nemo/collections/llm/recipes/finetune_default.py +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch import nemo.lightning as nl diff --git a/nemo/collections/llm/recipes/gemma2.py b/nemo/collections/llm/recipes/gemma2.py index 6fd1be83c183..2a690dc556d8 100644 --- a/nemo/collections/llm/recipes/gemma2.py +++ b/nemo/collections/llm/recipes/gemma2.py @@ -14,11 +14,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.gpt.model.gemma2 import Gemma2Config2B, Gemma2Config9B, Gemma2Config27B, Gemma2Model diff --git a/nemo/collections/llm/recipes/gemma2_27b.py b/nemo/collections/llm/recipes/gemma2_27b.py index 6f852f0fe6cf..2025bd570503 100644 --- a/nemo/collections/llm/recipes/gemma2_27b.py +++ b/nemo/collections/llm/recipes/gemma2_27b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/gemma2_2b.py b/nemo/collections/llm/recipes/gemma2_2b.py index 98c795591774..e1aa3ad4be86 100644 --- a/nemo/collections/llm/recipes/gemma2_2b.py +++ b/nemo/collections/llm/recipes/gemma2_2b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/gemma2_9b.py b/nemo/collections/llm/recipes/gemma2_9b.py index a211d8cfa838..8117102f1b75 100644 --- a/nemo/collections/llm/recipes/gemma2_9b.py +++ b/nemo/collections/llm/recipes/gemma2_9b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/gemma_2b.py b/nemo/collections/llm/recipes/gemma_2b.py index 8b2111e9f7c4..8798af436a9c 100644 --- a/nemo/collections/llm/recipes/gemma_2b.py +++ b/nemo/collections/llm/recipes/gemma_2b.py @@ -14,11 +14,11 @@ import os from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import GemmaConfig2B, GemmaModel diff --git a/nemo/collections/llm/recipes/gemma_7b.py b/nemo/collections/llm/recipes/gemma_7b.py index 44efb3fe56b8..0bfd62b33e9e 100644 --- a/nemo/collections/llm/recipes/gemma_7b.py +++ b/nemo/collections/llm/recipes/gemma_7b.py @@ -14,11 +14,11 @@ import os from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import GemmaConfig7B, GemmaModel diff --git a/nemo/collections/llm/recipes/gpt3_175b.py b/nemo/collections/llm/recipes/gpt3_175b.py index 5932ce5346b9..189f0ca6baf1 100644 --- a/nemo/collections/llm/recipes/gpt3_175b.py +++ b/nemo/collections/llm/recipes/gpt3_175b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py index f5a52cd351be..d93b167b45b6 100644 --- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py @@ -15,10 +15,10 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py index 31c83713b6e7..b1ef9719975c 100644 --- a/nemo/collections/llm/recipes/llama31_405b.py +++ b/nemo/collections/llm/recipes/llama31_405b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama31_70b.py b/nemo/collections/llm/recipes/llama31_70b.py index 91e4e10c83e6..6c0a108e10c7 100644 --- a/nemo/collections/llm/recipes/llama31_70b.py +++ b/nemo/collections/llm/recipes/llama31_70b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama31_8b.py b/nemo/collections/llm/recipes/llama31_8b.py index a4f0082e8535..d2cbd454d483 100644 --- a/nemo/collections/llm/recipes/llama31_8b.py +++ b/nemo/collections/llm/recipes/llama31_8b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index d43302a0a0ee..2b721b3f71c6 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_70b_16k.py b/nemo/collections/llm/recipes/llama3_70b_16k.py index 928f961f7cf3..0a394d386afd 100644 --- a/nemo/collections/llm/recipes/llama3_70b_16k.py +++ b/nemo/collections/llm/recipes/llama3_70b_16k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_70b_64k.py b/nemo/collections/llm/recipes/llama3_70b_64k.py index ffadf5ca8084..e035424d3506 100644 --- a/nemo/collections/llm/recipes/llama3_70b_64k.py +++ b/nemo/collections/llm/recipes/llama3_70b_64k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 4f6f6ce17443..13e2d13cc21a 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_8b_16k.py b/nemo/collections/llm/recipes/llama3_8b_16k.py index d6c1677a3b4b..b81d01c6ec9a 100644 --- a/nemo/collections/llm/recipes/llama3_8b_16k.py +++ b/nemo/collections/llm/recipes/llama3_8b_16k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_8b_64k.py b/nemo/collections/llm/recipes/llama3_8b_64k.py index 692347ea8dd0..ff176fb372bb 100644 --- a/nemo/collections/llm/recipes/llama3_8b_64k.py +++ b/nemo/collections/llm/recipes/llama3_8b_64k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/log/default.py b/nemo/collections/llm/recipes/log/default.py index d83580a1a543..023e4e459d5f 100644 --- a/nemo/collections/llm/recipes/log/default.py +++ b/nemo/collections/llm/recipes/log/default.py @@ -16,8 +16,8 @@ from datetime import timedelta from typing import Optional +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from nemo_run import Config, cli -from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from nemo import lightning as nl diff --git a/nemo/collections/llm/recipes/mamba2_130m.py b/nemo/collections/llm/recipes/mamba2_130m.py index 08640604a112..3f13f91f6609 100644 --- a/nemo/collections/llm/recipes/mamba2_130m.py +++ b/nemo/collections/llm/recipes/mamba2_130m.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/recipes/mamba2_1_3b.py b/nemo/collections/llm/recipes/mamba2_1_3b.py index 58eaf049b059..1a280b8b92a1 100644 --- a/nemo/collections/llm/recipes/mamba2_1_3b.py +++ b/nemo/collections/llm/recipes/mamba2_1_3b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/recipes/mamba2_2_7b.py b/nemo/collections/llm/recipes/mamba2_2_7b.py index 5cb37c6a02a5..0915cec748dd 100644 --- a/nemo/collections/llm/recipes/mamba2_2_7b.py +++ b/nemo/collections/llm/recipes/mamba2_2_7b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/recipes/mamba2_370m.py b/nemo/collections/llm/recipes/mamba2_370m.py index bb8bddc4045a..bb063dfcfc3f 100644 --- a/nemo/collections/llm/recipes/mamba2_370m.py +++ b/nemo/collections/llm/recipes/mamba2_370m.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/recipes/mamba2_780m.py b/nemo/collections/llm/recipes/mamba2_780m.py index 2f6ab6717ae1..e89905b2269a 100644 --- a/nemo/collections/llm/recipes/mamba2_780m.py +++ b/nemo/collections/llm/recipes/mamba2_780m.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/recipes/mamba2_8b.py b/nemo/collections/llm/recipes/mamba2_8b.py index 58883deba732..873d79fcb0f0 100644 --- a/nemo/collections/llm/recipes/mamba2_8b.py +++ b/nemo/collections/llm/recipes/mamba2_8b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/recipes/mamba2_hybrid_8b.py b/nemo/collections/llm/recipes/mamba2_hybrid_8b.py index eff37da46fca..09bb88a57089 100644 --- a/nemo/collections/llm/recipes/mamba2_hybrid_8b.py +++ b/nemo/collections/llm/recipes/mamba2_hybrid_8b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/recipes/mistral_7b.py b/nemo/collections/llm/recipes/mistral_7b.py index 7685bcd3ace6..3bc1e568185a 100644 --- a/nemo/collections/llm/recipes/mistral_7b.py +++ b/nemo/collections/llm/recipes/mistral_7b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/mistral_nemo_12b.py b/nemo/collections/llm/recipes/mistral_nemo_12b.py index e6616826d9a8..7d9fa1d792e9 100644 --- a/nemo/collections/llm/recipes/mistral_nemo_12b.py +++ b/nemo/collections/llm/recipes/mistral_nemo_12b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/mixtral_8x22b.py b/nemo/collections/llm/recipes/mixtral_8x22b.py index f768bf0499b1..16e6168e649b 100644 --- a/nemo/collections/llm/recipes/mixtral_8x22b.py +++ b/nemo/collections/llm/recipes/mixtral_8x22b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index d4286a15843f..5fbb0ac22c61 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_16k.py b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py index 7cbfaf723544..499280cc8542 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b_16k.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_64k.py b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py index 3606be5ec12b..e0702f7b2a63 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b_64k.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/nemotron.py b/nemo/collections/llm/recipes/nemotron.py index 104c3798567a..7982665eb3d5 100644 --- a/nemo/collections/llm/recipes/nemotron.py +++ b/nemo/collections/llm/recipes/nemotron.py @@ -14,11 +14,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.gpt.model.nemotron import ( diff --git a/nemo/collections/llm/recipes/nemotron3_22b.py b/nemo/collections/llm/recipes/nemotron3_22b.py index 724e21f002e3..2dd9c3ff5205 100644 --- a/nemo/collections/llm/recipes/nemotron3_22b.py +++ b/nemo/collections/llm/recipes/nemotron3_22b.py @@ -14,8 +14,8 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/nemotron3_22b_16k.py b/nemo/collections/llm/recipes/nemotron3_22b_16k.py index 81f4253ad37a..5ae58d1a757d 100644 --- a/nemo/collections/llm/recipes/nemotron3_22b_16k.py +++ b/nemo/collections/llm/recipes/nemotron3_22b_16k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron3_22b_64k.py b/nemo/collections/llm/recipes/nemotron3_22b_64k.py index 676694697e4c..22f6291cfadb 100644 --- a/nemo/collections/llm/recipes/nemotron3_22b_64k.py +++ b/nemo/collections/llm/recipes/nemotron3_22b_64k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron3_4b.py b/nemo/collections/llm/recipes/nemotron3_4b.py index e1c2ef345d7e..c208ee740265 100644 --- a/nemo/collections/llm/recipes/nemotron3_4b.py +++ b/nemo/collections/llm/recipes/nemotron3_4b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/nemotron3_8b.py b/nemo/collections/llm/recipes/nemotron3_8b.py index 202efe658d83..7799512c6260 100644 --- a/nemo/collections/llm/recipes/nemotron3_8b.py +++ b/nemo/collections/llm/recipes/nemotron3_8b.py @@ -14,8 +14,8 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl diff --git a/nemo/collections/llm/recipes/nemotron4_15b.py b/nemo/collections/llm/recipes/nemotron4_15b.py index 0f15c47c67b9..ad0f884b0d3b 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b.py +++ b/nemo/collections/llm/recipes/nemotron4_15b.py @@ -14,8 +14,8 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/nemotron4_15b_16k.py b/nemo/collections/llm/recipes/nemotron4_15b_16k.py index 75eced72761f..e16c2b03b032 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b_16k.py +++ b/nemo/collections/llm/recipes/nemotron4_15b_16k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron4_15b_64k.py b/nemo/collections/llm/recipes/nemotron4_15b_64k.py index 8286778aa7ba..2cedfbed398b 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b_64k.py +++ b/nemo/collections/llm/recipes/nemotron4_15b_64k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron4_340b.py b/nemo/collections/llm/recipes/nemotron4_340b.py index c02950109669..b22abc43d558 100644 --- a/nemo/collections/llm/recipes/nemotron4_340b.py +++ b/nemo/collections/llm/recipes/nemotron4_340b.py @@ -14,8 +14,8 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl diff --git a/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py b/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py new file mode 100644 index 000000000000..80ac60d054ce --- /dev/null +++ b/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py @@ -0,0 +1,284 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import lightning.pytorch as pl +import nemo_run as run +import torch +from lightning.pytorch.callbacks.callback import Callback +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.phi3mini import Phi3ConfigMini, Phi3Model +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "phi3_mini_4k_instruct" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Phi3 Mini 4k instruct model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Phi3 mini 4k instruct model. + + Examples: + CLI usage: + $ nemo llm pretrain model=phi3_mini_4k_instruct ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(Phi3Model, config=run.Config(Phi3ConfigMini)) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 1, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Phi3 mini 4k instruct model. + + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=phi3_mini_4k_instruct ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + For more information on distributed training strategies, refer to the + NeMo documentation on multi-GPU and multi-node training. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + tensor_parallelism: int = 1, + num_gpus_per_node: int = 1, + max_steps: int = 1168251, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for phi3_mini_4k_instruct model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory phi3_mini_4k_instruct + $ nemo llm pretrain --factory "phi3_mini_4k_instruct(num_nodes=1, name='my_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="phi3_mini_4k_instruct", num_nodes=1) + >>> print(recipe) + + Note: + For more details on pre-training LLMs with NeMo, see the pre-training + guide in the `examples/llm/pretrain/` directory. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 1, + tensor_parallelism: int = 1, + max_steps: int = 116825, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Phi3 mini-4k-instruct model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', + 'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory phi3_mini_4k_instruct + + Python API usage: + >>> recipe = finetune_recipe(name="phi3_mini_4k_instruct", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "microsoft/Phi-3-mini-4k-instruct", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 1 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.pad_to_max_length = True + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + return recipe diff --git a/nemo/collections/llm/recipes/qwen2.py b/nemo/collections/llm/recipes/qwen2.py index ff0c76a714f1..db9dcfc88865 100644 --- a/nemo/collections/llm/recipes/qwen2.py +++ b/nemo/collections/llm/recipes/qwen2.py @@ -14,10 +14,10 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.gpt.model.qwen2 import ( diff --git a/nemo/collections/llm/recipes/qwen2_1p5b.py b/nemo/collections/llm/recipes/qwen2_1p5b.py index 662f8e98899d..a3d705c4fb3a 100644 --- a/nemo/collections/llm/recipes/qwen2_1p5b.py +++ b/nemo/collections/llm/recipes/qwen2_1p5b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/qwen2_500m.py b/nemo/collections/llm/recipes/qwen2_500m.py index ac6cbfe84464..08541ca9e421 100644 --- a/nemo/collections/llm/recipes/qwen2_500m.py +++ b/nemo/collections/llm/recipes/qwen2_500m.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/qwen2_72b.py b/nemo/collections/llm/recipes/qwen2_72b.py index 0b94761e5749..c0bc9bf40611 100644 --- a/nemo/collections/llm/recipes/qwen2_72b.py +++ b/nemo/collections/llm/recipes/qwen2_72b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/qwen2_7b.py b/nemo/collections/llm/recipes/qwen2_7b.py index 10c990f15142..67bcc5e953bf 100644 --- a/nemo/collections/llm/recipes/qwen2_7b.py +++ b/nemo/collections/llm/recipes/qwen2_7b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/starcoder2.py b/nemo/collections/llm/recipes/starcoder2.py index c3a19326585c..b090ce1cf9ef 100644 --- a/nemo/collections/llm/recipes/starcoder2.py +++ b/nemo/collections/llm/recipes/starcoder2.py @@ -14,10 +14,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback + from nemo import lightning as nl from nemo.collections.llm.gpt.model.starcoder2 import ( Starcoder2Config3B, diff --git a/nemo/collections/llm/recipes/starcoder2_15b.py b/nemo/collections/llm/recipes/starcoder2_15b.py index a59ec272c865..14b53809111a 100644 --- a/nemo/collections/llm/recipes/starcoder2_15b.py +++ b/nemo/collections/llm/recipes/starcoder2_15b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/starcoder2_3b.py b/nemo/collections/llm/recipes/starcoder2_3b.py index 55884b353d8f..3ee81522ebc9 100644 --- a/nemo/collections/llm/recipes/starcoder2_3b.py +++ b/nemo/collections/llm/recipes/starcoder2_3b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/starcoder2_7b.py b/nemo/collections/llm/recipes/starcoder2_7b.py index 46e34b8b0c77..96b5ab36b876 100644 --- a/nemo/collections/llm/recipes/starcoder2_7b.py +++ b/nemo/collections/llm/recipes/starcoder2_7b.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/starcoder_15b.py b/nemo/collections/llm/recipes/starcoder_15b.py index cb0ba14df868..d87788be5613 100644 --- a/nemo/collections/llm/recipes/starcoder_15b.py +++ b/nemo/collections/llm/recipes/starcoder_15b.py @@ -14,10 +14,10 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/t5_11b.py b/nemo/collections/llm/recipes/t5_11b.py index b3806e6f2540..8baf54b4f42f 100644 --- a/nemo/collections/llm/recipes/t5_11b.py +++ b/nemo/collections/llm/recipes/t5_11b.py @@ -15,12 +15,12 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/t5_220m.py b/nemo/collections/llm/recipes/t5_220m.py index d59df213a3f4..27feb43837fb 100644 --- a/nemo/collections/llm/recipes/t5_220m.py +++ b/nemo/collections/llm/recipes/t5_220m.py @@ -15,12 +15,12 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/t5_3b.py b/nemo/collections/llm/recipes/t5_3b.py index e7f215d57635..333661d97117 100644 --- a/nemo/collections/llm/recipes/t5_3b.py +++ b/nemo/collections/llm/recipes/t5_3b.py @@ -15,12 +15,12 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/t5/data/fine_tuning.py b/nemo/collections/llm/t5/data/fine_tuning.py index 4180b4f135cb..ced4ea1a0b37 100644 --- a/nemo/collections/llm/t5/data/fine_tuning.py +++ b/nemo/collections/llm/t5/data/fine_tuning.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl from torch.utils.data import DataLoader from nemo.collections.llm.t5.data.core import create_sft_dataset diff --git a/nemo/collections/llm/t5/data/mock.py b/nemo/collections/llm/t5/data/mock.py index eaf41d290da4..31198a4446e9 100644 --- a/nemo/collections/llm/t5/data/mock.py +++ b/nemo/collections/llm/t5/data/mock.py @@ -14,10 +14,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset @@ -125,13 +125,11 @@ def __init__( self.seed = seed self.create_attention_mask = create_attention_mask - self.mask_encoder = torch.ones((self.seq_length, self.seq_length), device='cpu') - self.mask_decoder = torch.tril(torch.ones((self.seq_length_dec, self.seq_length_dec), device='cpu')) - self.mask_encoder_decoder = torch.ones((self.seq_length_dec, self.seq_length), device='cpu') + # update for T5 now use FlashFused attention (b11s) + self.mask_encoder = torch.ones(self.seq_length, device='cpu') + self.mask_decoder = torch.ones(self.seq_length_dec, device='cpu') self.mask_encoder = self.mask_encoder < 0.5 self.mask_decoder = self.mask_decoder < 0.5 - self.mask_encoder_decoder = self.mask_encoder_decoder < 0.5 - self.loss_mask = torch.ones(self.seq_length_dec, dtype=torch.float) def __len__(self) -> int: @@ -156,7 +154,6 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]: "truncated": 0, "enc_mask": self.mask_encoder, "dec_mask": self.mask_decoder, - "enc_dec_mask": self.mask_encoder_decoder, } return batch diff --git a/nemo/collections/llm/t5/data/pre_training.py b/nemo/collections/llm/t5/data/pre_training.py index 45d485ba2074..4bd6e5ed5e93 100644 --- a/nemo/collections/llm/t5/data/pre_training.py +++ b/nemo/collections/llm/t5/data/pre_training.py @@ -17,8 +17,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional -import pytorch_lightning as pl -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from nemo.lightning.data import WrappedDataLoader diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py index 0b0e37814b4e..940c0e51ee92 100644 --- a/nemo/collections/llm/t5/model/t5.py +++ b/nemo/collections/llm/t5/model/t5.py @@ -16,11 +16,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper +from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig @@ -38,8 +39,6 @@ HAVE_TE = False if TYPE_CHECKING: - from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model - from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @@ -58,22 +57,32 @@ def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: else: _batch = batch - # if Dataset object is NeMo 1.0's T5SFTDataset (e.g. when finetuning with SQUAD) - if 'enc_dec_mask' not in _batch: - encoder_attn_mask_3d = build_attention_mask_3d(_batch['enc_mask'], _batch['enc_mask'], AttnMaskType.padding) - decoder_attn_mask_3d = build_attention_mask_3d(_batch['dec_mask'], _batch['dec_mask'], AttnMaskType.causal) - enc_dec_attn_mask_3d = build_attention_mask_3d(_batch['dec_mask'], _batch['enc_mask'], AttnMaskType.padding) - _batch['enc_mask'] = encoder_attn_mask_3d - _batch['dec_mask'] = decoder_attn_mask_3d - _batch['enc_dec_mask'] = enc_dec_attn_mask_3d - - # if Dataset object is Mcore T5 dataset (e.g. pretraining) - else: - # convert attention mask values from int to True/False - _batch['enc_mask'] = _batch['enc_mask'] < 0.5 - _batch['dec_mask'] = _batch['dec_mask'] < 0.5 - _batch['enc_dec_mask'] = _batch['enc_dec_mask'] < 0.5 - + # work for both mcore's T5 pre-train dataset object, and NeMo's T5SFTDataset dataset + enc_mask = _batch['enc_mask'] < 0.5 + dec_mask = _batch['dec_mask'] < 0.5 + # process for Flash/Fused + enc_mask = enc_mask.unsqueeze(1).unsqueeze(1) + dec_mask = dec_mask.unsqueeze(1).unsqueeze(1) + enc_dec_mask = ( + dec_mask, + enc_mask, + ) + # set dec_mask to None because decoder uses AttnMaskType.causal + dec_mask = None + _batch['enc_mask'] = enc_mask + _batch['dec_mask'] = dec_mask + _batch['enc_dec_mask'] = enc_dec_mask + + # bring to device + for key in _batch.keys(): + if key == "enc_dec_mask": # because enc_dec_mask is a tuple + _batch[key] = (_batch[key][0].cuda(non_blocking=True), _batch[key][1].cuda(non_blocking=True)) + elif key == "dec_mask": # because dec_mask is a None since decoder uses AttnMaskType.causal + continue + else: + _batch[key] = _batch[key].cuda(non_blocking=True) + + # set up forward arguments for pipeline parallelism required_keys = set() required_keys.update(["enc_mask", "dec_mask", "enc_dec_mask"]) if parallel_state.is_pipeline_first_stage(): @@ -81,7 +90,7 @@ def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: if parallel_state.is_pipeline_last_stage(): required_keys.update(("labels", "loss_mask")) - output = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()} + output = {key: val if key in required_keys else None for key, val in _batch.items()} return output @@ -139,6 +148,7 @@ class T5Config(TransformerConfig, io.IOMixin): share_embeddings_and_output_weights: bool = True make_vocab_size_divisible_by: int = 128 position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute" + apply_rope_fusion: bool = True max_position_embeddings: int = 512 rotary_percent: float = 1.0 seq_len_interpolation_factor: Optional[float] = None @@ -170,7 +180,6 @@ def configure_model(self, tokenizer) -> "MCoreT5Model": ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages." from megatron.core import parallel_state - from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model encoder_config = copy.deepcopy(self) encoder_config.num_layers = self.encoder_num_layers diff --git a/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py b/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py index 1c39b1a72216..baead0c47962 100644 --- a/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py +++ b/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py @@ -15,8 +15,8 @@ from pathlib import Path import torch +from lightning.pytorch.utilities import rank_zero_only from PIL import Image -from pytorch_lightning.utilities import rank_zero_only from torch.utils.data import Dataset from tqdm import tqdm diff --git a/nemo/collections/multimodal/data/energon/base.py b/nemo/collections/multimodal/data/energon/base.py index 0a99b1a1baad..4e90dce55c7a 100644 --- a/nemo/collections/multimodal/data/energon/base.py +++ b/nemo/collections/multimodal/data/energon/base.py @@ -16,10 +16,10 @@ from typing import Any, Dict, Literal, Optional import fiddle as fdl -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from megatron.core import parallel_state from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader from typing_extensions import Self diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 5291497f92c3..5d19b8544305 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -22,8 +22,8 @@ import torch import torch.nn.functional as F from einops import rearrange, reduce, repeat +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPVisionModel, SiglipVisionModel from nemo.collections.common.parts.utils import extend_instance diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py index 158fa7595782..981600fcc3a1 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -18,9 +18,9 @@ import torch import torch.nn as nn from einops import rearrange, repeat +from lightning.pytorch import Trainer +from lightning.pytorch.utilities.rank_zero import rank_zero_only from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.rank_zero import rank_zero_only from torch._inductor import config as inductor_config from nemo.collections.multimodal.data.controlnet.controlnet_dataset import build_train_valid_datasets diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/util.py b/nemo/collections/multimodal/models/text_to_image/controlnet/util.py index 3d9a7d16b1c3..f890426c98f4 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/util.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/util.py @@ -17,9 +17,9 @@ import numpy as np import torch import torchvision +from lightning.pytorch import Callback +from lightning.pytorch.utilities.rank_zero import rank_zero_only from PIL import Image -from pytorch_lightning import Callback -from pytorch_lightning.utilities.rank_zero import rank_zero_only class ImageLogger(Callback): diff --git a/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py index 47548b02961d..8906263faeba 100644 --- a/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py +++ b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py @@ -15,8 +15,8 @@ from typing import Any, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch._inductor import config as inductor_config from nemo.collections.multimodal.data.dreambooth.dreambooth_dataset import DreamBoothDataset diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py index ed9be58178c4..1772e465f604 100644 --- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py +++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py @@ -17,8 +17,8 @@ from typing import Any import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer from nemo.collections.multimodal.data.imagen.imagen_dataset import build_train_valid_datasets from nemo.collections.multimodal.models.text_to_image.imagen.precond import ContinousDDPMPrecond, EDMPrecond diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py index 43660c9000a1..63963321fcf7 100644 --- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py +++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py @@ -17,8 +17,8 @@ from typing import Callable, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf -from pytorch_lightning import Trainer from torch.cuda.amp import autocast from nemo.collections.multimodal.models.text_to_image.imagen.imagen import Imagen, MegatronImagen @@ -73,7 +73,9 @@ def _load_model(model_ckpt: str, model_cfg: str, eval_mode: bool = True, trainer model_cfg.micro_batch_size = 1 model_cfg.global_batch_size = 1 model = MegatronImagen.restore_from( - restore_path=model_ckpt, override_config_path=model_cfg, trainer=trainer, + restore_path=model_ckpt, + override_config_path=model_cfg, + trainer=trainer, ) elif model_ckpt.endswith('.ckpt'): model_cfg = OmegaConf.load(model_cfg) @@ -128,7 +130,9 @@ def model_cfg_modifier(model_cfg): models = [] print('Load base model.') model = ImagenPipeline._load_model( - model_ckpt=customized_models.base_ckpt, model_cfg=customized_models.base_cfg, trainer=trainer, + model_ckpt=customized_models.base_ckpt, + model_cfg=customized_models.base_cfg, + trainer=trainer, ) models.append(model) diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py index 8b18fe2b25fe..c7e8795a749c 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py @@ -17,14 +17,14 @@ from typing import Any, Dict, List, Tuple, Union import hydra -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch._dynamo import torch.nn as nn from einops import rearrange +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.utilities import rank_zero_only from safetensors.torch import load_file as load_safetensors from torch._dynamo import optimize from torch.optim.lr_scheduler import LambdaLR diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py index d79d85c2e026..311ebc0f06f5 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import contextmanager -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn.functional as F from nemo.utils import logging diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py index 744dc6945394..163b2fb27e0f 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -17,18 +17,18 @@ from functools import partial from typing import Any, Dict, List, Optional, Union +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch import torch.nn as nn from einops import rearrange, repeat -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch import Trainer +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml +from lightning.pytorch.utilities.migration import pl_legacy_patch +from lightning.pytorch.utilities.rank_zero import rank_zero_only from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml -from pytorch_lightning.utilities.migration import pl_legacy_patch -from pytorch_lightning.utilities.rank_zero import rank_zero_only from torch._inductor import config as inductor_config from torchvision.utils import make_grid from tqdm import tqdm diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py index a9e51610bedd..84718f99262f 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py @@ -23,10 +23,10 @@ import numpy as np import torch import torch.nn.functional as F +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from tqdm import tqdm from nemo.collections.multimodal.data.clip.clip_dataset import ( diff --git a/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py index 79c0f3910be0..37e33f892890 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py @@ -19,11 +19,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.multimodal.data.clip.clip_dataset import tokenize from nemo.collections.multimodal.data.nsfw.nsfw_dataset import build_dataset @@ -38,7 +38,6 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging - try: from megatron.core.num_microbatches_calculator import get_num_microbatches diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 6ba2e8ca91f9..8773b47025bc 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -17,10 +17,10 @@ import numpy as np import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import DictConfig, OmegaConf, open_dict from PIL import Image -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from transformers import CLIPImageProcessor, SiglipImageProcessor from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform diff --git a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py index 50b4d29c05a4..53ae4a2dfb65 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py @@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset diff --git a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py index 106fbc432926..8249e5d8a7f8 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.rnnt import RNNTLoss @@ -90,7 +90,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup decoding object self.decoding = RNNTBPEDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) # Setup wer object @@ -282,7 +285,10 @@ def change_vocabulary( decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( @@ -388,7 +394,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( diff --git a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py index 1b30263985da..158bfaddcc96 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py @@ -19,8 +19,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.losses.ctc import CTCLoss diff --git a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py index eeffb906981a..11e9d43e1737 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.losses.rnnt import RNNTLoss from nemo.collections.asr.metrics.wer import WER @@ -68,7 +68,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup decoding object self.decoding = RNNTBPEDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) # Setup wer object @@ -165,7 +168,10 @@ def change_vocabulary( decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( @@ -214,7 +220,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( diff --git a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py index 5a86eed93019..75202238d2d0 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py @@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Tuple, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 46b2ca3e26fd..aab27cf2d908 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -21,11 +21,11 @@ import sacrebleu import torch from hydra.utils import get_class +from lightning.pytorch.trainer.trainer import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import ListConfig from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities import rank_zero_only from nemo.collections.asr.models import ASRModel, EncDecSpeakerLabelModel from nemo.collections.asr.parts.utils.eval_utils import remove_punctuations diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 79fc0468e819..a99f5c346831 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -21,10 +21,10 @@ import sacrebleu import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import ListConfig from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.asr.models import ASRModel, SpeechEncDecSelfSupervisedModel from nemo.collections.common.data.utils import move_data_to_device diff --git a/nemo/collections/multimodal_autoregressive/data/README.md b/nemo/collections/multimodal_autoregressive/data/README.md index c4814ad267f8..3f6d5a6c6a81 100644 --- a/nemo/collections/multimodal_autoregressive/data/README.md +++ b/nemo/collections/multimodal_autoregressive/data/README.md @@ -8,27 +8,7 @@ This is an example of how to do autoregressive generation for multiple modalitie ### 1. Vision Understanding using EMU3 Tokenizer #### Download and Extract data -We will be working with coyo dataset which has 700 million images. - -First create credentials for rclone . Create this file at `~/.config/rclone/rclone.conf` -``` -[pbss-team-vfm-share-ro-s3] -type = s3 -env_auth = true -access_key_id = -secret_access_key = -region = us-east-1 -endpoint = https://pdx.s8k.io -``` -To download the images -``` -rclone copy pbss-team-vfm-share-ro-s3:webdataset_images/webdataset_edify_image_v3/coyo_700m/resolution_lt_720/aspect_ratio_16_9/images images --transfers=16 --multi-thread-streams=16 --checkers=8 -P --stats 5s -``` - -To download the captions -``` -rclone copy pbss-team-vfm-share-ro-s3:webdataset_images/webdataset_edify_image_v3/coyo_700m/resolution_lt_720/aspect_ratio_16_9/captions_ai_v3p1 captions_ai_v3p1 --transfers=16 --multi-thread-streams=16 --checkers=8 -P --stats 5s -``` +Download the [COYO700M dataset](https://github.com/kakaobrain/coyo-dataset) Once downloaded extract the data using tar utilities. @@ -70,13 +50,13 @@ Follow usual nemo instructions to train any autoregressive model. ``` #### Inference -To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference.yaml) +To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml) *NOTE* Make sure you have a .nemo file (checkpoint). If you just have a regular megatron checkpoint you have to do a conversion as shown in [this doc](https://docs.nvidia.com/nemo-framework/user-guide/latest/llms/gpt/checkpointconversion.html?highlight=convert) Run inference as follows ``` -torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval.py +torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py ``` @@ -116,13 +96,11 @@ Follow usual nemo instructions to train any autoregressive model. ``` #### Inference -To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference.yaml) +To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml) *NOTE* Make sure you have a .nemo file (checkpoint). If you just have a regular megatron checkpoint you have to do a conversion as shown in [this doc](https://docs.nvidia.com/nemo-framework/user-guide/latest/llms/gpt/checkpointconversion.html?highlight=convert) Run inference as follows ``` -torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval.py -``` - -TODO : Instructions to convert visual tokens to images coming soon. \ No newline at end of file +torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py +``` \ No newline at end of file diff --git a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py index bbd14f47a651..ea5f8c5a930b 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import omegaconf import torch from nemo.collections.nlp.modules.common import VirtualPromptSource @@ -70,8 +71,55 @@ def __init__( # Datasets are a list of file path strings to .json or .jsonl files elif isinstance(datasets[0], str): for path in datasets: - dataset = open(path, 'r', encoding='utf-8') - self.load_data(dataset) + with open(path, 'r', encoding='utf-8') as dataset: + dataset_examples = self.load_data(dataset) + self.examples.extend(dataset_examples) + elif isinstance(datasets[0], omegaconf.ListConfig) or isinstance(datasets[0], list): + # Dataset is a list of tuples with the first element being the probability of sampling from the dataset + # This code repeates the smaller datasets to approximately match the target probabilities + total_examples = 0 + dataset_lengths = [] + target_probs = [] + datasets_examples_list = [] + for prob_and_path in datasets: + prob = prob_and_path[0] + path = prob_and_path[1] + with open(path, 'r', encoding='utf-8') as dataset: + dataset_examples = self.load_data(dataset) + datasets_examples_list.append(dataset_examples) + dataset_lengths.append(len(dataset_examples)) + total_examples += len(dataset_examples) + target_probs.append(prob) + + # Normalize the target probs + target_probs = [prob / sum(target_probs) for prob in target_probs] + current_probs = [dataset_lengths[i] / total_examples for i in range(len(dataset_lengths))] + + # Increase number of examples needed without reducing the larger datasets with low target probs + new_total_examples = total_examples + for dataset_idx in range(len(datasets)): + if target_probs[dataset_idx] < current_probs[dataset_idx]: + target_total_examples = int(dataset_lengths[dataset_idx] / target_probs[dataset_idx]) + new_total_examples = max(new_total_examples, target_total_examples) + + final_total_examples = 0 + final_dataset_lengths = [] + for dataset_idx in range(len(datasets)): + num_samples_required = int(new_total_examples * target_probs[dataset_idx]) + num_repeat = max( + int(round(num_samples_required // dataset_lengths[dataset_idx])), 1 + ) # At least 1 repeat + logging.info("dataset idx {}, num_repeat {}".format(dataset_idx, num_repeat)) + dataset_examples_repeated = datasets_examples_list[dataset_idx] * num_repeat + final_dataset_lengths.append(len(dataset_examples_repeated)) + final_total_examples += len(dataset_examples_repeated) + self.examples.extend(dataset_examples_repeated) + + final_probs = [final_dataset_lengths[i] / final_total_examples for i in range(len(final_dataset_lengths))] + logging.info("Target probs: {}".format(target_probs)) + logging.info("Final probs: {}".format(final_probs)) + logging.info("Initial total examples: {}".format(total_examples)) + logging.info("Final total examples: {}".format(final_total_examples)) else: raise ValueError("Datasets must be a list of dicts or a list of filepath strings") diff --git a/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py b/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py index b95993ded69e..59181d8cb89f 100644 --- a/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py +++ b/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py @@ -21,8 +21,8 @@ import tempfile from joblib import Parallel, delayed +from lightning.pytorch import Trainer from omegaconf import ListConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model from nemo.collections.nlp.data.language_modeling.sentence_dataset import SentenceDataset @@ -33,23 +33,23 @@ class MTDataPreproc: - """ Automatically trains tokenizers and preprocesses machine translation data based on the MTEncDecModelConfig. - For training NMT models with datasets larger than 5M sentence pairs, - it can be inefficient to train them without first creating a tarred dataset. - If the user wants to change the tokenizer, vocab size, or batch size, for example, - they must reprocess the data with the correct configuration. - With MTDataPreproc users can sweep through data configurations and the tarred dataset will - be automatically created according to the model configuration. - To train tokenizer model and create tarred dataset specify in configuration: - model.preproc_out_dir=/path/to/preproc_out - model.encoder_tokenizer.vocab_size=32000 - model.decoder_tokenizer.vocab_size=32000 - model.train_ds.use_tarred_dataset=True - model.train_ds.src_file_name=/path/to/src.txt - model.train_ds.tgt_file_name=/path/to/tgt.txt - model.train_ds.tokens_in_batch=16000 - Once a dataset has been constructed based on this configuration, MTDataPreproc will not process it again. - If a previously trained tokenizer model or tarred dataset is found, MTDataPreproc will not preprocess the data. + """Automatically trains tokenizers and preprocesses machine translation data based on the MTEncDecModelConfig. + For training NMT models with datasets larger than 5M sentence pairs, + it can be inefficient to train them without first creating a tarred dataset. + If the user wants to change the tokenizer, vocab size, or batch size, for example, + they must reprocess the data with the correct configuration. + With MTDataPreproc users can sweep through data configurations and the tarred dataset will + be automatically created according to the model configuration. + To train tokenizer model and create tarred dataset specify in configuration: + model.preproc_out_dir=/path/to/preproc_out + model.encoder_tokenizer.vocab_size=32000 + model.decoder_tokenizer.vocab_size=32000 + model.train_ds.use_tarred_dataset=True + model.train_ds.src_file_name=/path/to/src.txt + model.train_ds.tgt_file_name=/path/to/tgt.txt + model.train_ds.tokens_in_batch=16000 + Once a dataset has been constructed based on this configuration, MTDataPreproc will not process it again. + If a previously trained tokenizer model or tarred dataset is found, MTDataPreproc will not preprocess the data. """ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None) -> None: @@ -147,12 +147,16 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None) -> None: global_rank=self.global_rank, encoder_training_sample_size=cfg.encoder_tokenizer.get('training_sample_size', -1), decoder_training_sample_size=cfg.decoder_tokenizer.get('training_sample_size', -1), - encoder_special_tokens=OmegaConf.to_container(cfg.encoder_tokenizer.special_tokens) - if cfg.encoder_tokenizer.special_tokens - else None, - decoder_special_tokens=OmegaConf.to_container(cfg.decoder_tokenizer.special_tokens) - if cfg.decoder_tokenizer.special_tokens - else None, + encoder_special_tokens=( + OmegaConf.to_container(cfg.encoder_tokenizer.special_tokens) + if cfg.encoder_tokenizer.special_tokens + else None + ), + decoder_special_tokens=( + OmegaConf.to_container(cfg.decoder_tokenizer.special_tokens) + if cfg.decoder_tokenizer.special_tokens + else None + ), spt_symbols=spt_symbols, ) # update config @@ -280,10 +284,10 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None) -> None: ) def tar_files_to_string(self, tar_files): - """ Tar files are generated in the following format: basename.number.tar + """Tar files are generated in the following format: basename.number.tar Where number is an integer from 1 to the number of tar files. We convert this list to a string that can be used in the model config to specify - tarred datasets: basename_OP_1..num_tar_files_CL_.tar + tarred datasets: basename_OP_1..num_tar_files_CL_.tar Args: tar_files (List[str]): List of tar files generated by preprocess_parallel_dataset @@ -337,7 +341,9 @@ def get_enc_dec_tokenizers( @staticmethod def get_monolingual_tokenizer( - tokenizer_name=None, tokenizer_model=None, bpe_dropout=0.0, + tokenizer_name=None, + tokenizer_model=None, + bpe_dropout=0.0, ): if tokenizer_name == 'sentencepiece': tokenizer = SentencePieceTokenizer(model_path=tokenizer_model) @@ -385,14 +391,14 @@ def preprocess_parallel_dataset( src_fname (str): path to source text data tgt_fname (str): path to target text data out_dir (str): path to write tarred dataset - encoder_tokenizer (Any): tokenizer for encoder + encoder_tokenizer (Any): tokenizer for encoder decoder_tokenizer (Any): tokenizer for decoder - max_seq_length (int): maximum sequence length - min_seq_length (int): minimum sequence length - tokens_in_batch (int): tokens per batch per GPU, effectively batch size + max_seq_length (int): maximum sequence length + min_seq_length (int): minimum sequence length + tokens_in_batch (int): tokens per batch per GPU, effectively batch size lines_per_dataset_fragment (int): number of lines to consider for bucketing and padding num_batches_per_tarfile (int): number of batches (pickle files) within each tarfile - tar_file_prefix (str) : add string prefix to tar files + tar_file_prefix (str) : add string prefix to tar files n_jobs (int): number of processes to use for data processing (-2 to use all but 2) """ @@ -471,7 +477,10 @@ def preprocess_parallel_dataset( out_dir, f'remainder-batches.tokens.{tokens_in_batch}.tar_file_{remainder_tar_file_ctr}.tar', ) - remainder_tar_file_ptr = tarfile.open(remainder_tar_file_path, 'w',) + remainder_tar_file_ptr = tarfile.open( + remainder_tar_file_path, + 'w', + ) batch_in_tar_ctr = 0 tar_file_ptr.close() os.remove(tar_file_path) @@ -631,9 +640,9 @@ def preprocess_monolingual_dataset( fname (str): Path to source text data out_dir (str): Path to write tarred dataset tokenizer (Any): Path to tokenizer model - max_seq_length (int): maximum sequence length - min_seq_length (int): minimum sequence length - tokens_in_batch (int): tokens per batch per GPU, effectively batch size + max_seq_length (int): maximum sequence length + min_seq_length (int): minimum sequence length + tokens_in_batch (int): tokens per batch per GPU, effectively batch size lines_per_dataset_fragment (int): number of lines to consider for bucketing and padding num_batches_per_tarfile (int): number of batches (pickle files) within each tarfile global_rank (int): if set to zero, data will be processed on this node @@ -808,7 +817,8 @@ def train_tokenizers( split_by_whitespace=split_by_whitespace, ) os.rename( - os.path.join(out_dir, 'tokenizer.model'), encoder_tokenizer_model, + os.path.join(out_dir, 'tokenizer.model'), + encoder_tokenizer_model, ) else: if encoder_tokenizer_name in supported_train_tokenizers: @@ -1007,7 +1017,10 @@ def write_parallel_batches_to_tarfiles( tar_file_path = os.path.join( out_dir, 'fragment-%s-batches.tokens.%d.%d.tar' % (fragment_index, num_tokens, tar_file_ctr) ) - tar_file_ptr = tarfile.open(tar_file_path, 'w',) + tar_file_ptr = tarfile.open( + tar_file_path, + 'w', + ) batch_ctr = 0 # return tar files paths that have batches remaining diff --git a/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py b/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py index 07ca790866c7..6c7472b95c42 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py @@ -21,8 +21,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from transformers import AutoModelWithLMHead diff --git a/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py b/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py index 116605b65d52..7fb0ba770189 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from transformers import AutoModelWithLMHead diff --git a/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py b/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py index 29e2627fa038..9bf7ae2a9116 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py @@ -19,8 +19,8 @@ import numpy as np import torch import torch.nn.functional as F +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModel from nemo.collections.nlp.data.dialogue import DialogueSGDDataProcessor diff --git a/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py b/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py index 48f3e5127a88..3f0d09d7dc66 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py @@ -18,8 +18,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from transformers import AutoModelForSeq2SeqLM diff --git a/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py b/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py index 5298c060df08..1df19cf8a556 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModelForSequenceClassification, AutoTokenizer from nemo.collections.nlp.data.dialogue import DialogueSGDDataProcessor diff --git a/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py b/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py index 777d468084e2..09a81b33c973 100644 --- a/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py +++ b/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py @@ -16,8 +16,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss diff --git a/nemo/collections/nlp/models/dialogue/sgdqa_model.py b/nemo/collections/nlp/models/dialogue/sgdqa_model.py index 3b30dfccd9ce..6cd2243423a4 100644 --- a/nemo/collections/nlp/models/dialogue/sgdqa_model.py +++ b/nemo/collections/nlp/models/dialogue/sgdqa_model.py @@ -22,8 +22,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.nlp.data.dialogue import DialogueSGDBERTDataset, DialogueSGDDataProcessor diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py index 7d4cac46cc28..253962e55621 100644 --- a/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py @@ -19,8 +19,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq import nemo.collections.nlp.data.text_normalization.constants as constants @@ -307,7 +307,7 @@ def _infer( span_ends: List[List[int]], inst_directions: List[str], ): - """ Main function for Inference + """Main function for Inference Args: sents: A list of inputs tokenized by a basic tokenizer. nb_spans: A list of ints where each int indicates the number of semiotic spans in each input. @@ -521,9 +521,9 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str): tokenizer_name=self.transformer_name, mode=self.mode, max_len=self.max_sequence_len, - decoder_data_augmentation=cfg.get('decoder_data_augmentation', False) - if data_split == "train" - else False, + decoder_data_augmentation=( + cfg.get('decoder_data_augmentation', False) if data_split == "train" else False + ), lang=self.lang, use_cache=cfg.get('use_cache', False), max_insts=cfg.get('max_insts', -1), diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py index feeda99bdbe5..1ce005403999 100644 --- a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py @@ -16,8 +16,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch import nn from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification from transformers.tokenization_utils_base import BatchEncoding @@ -151,7 +151,7 @@ def on_test_epoch_end(self): # Functions for inference @torch.no_grad() def _infer(self, sents: List[List[str]], inst_directions: List[str]): - """ Main function for Inference + """Main function for Inference Args: sents: A list of inputs tokenized by a basic tokenizer. @@ -248,7 +248,7 @@ def _infer(self, sents: List[List[str]], inst_directions: List[str]): return all_tag_preds, nb_spans, span_starts, span_ends def _postprocess_tag_preds(self, words: List[str], inst_dir: str, preds: List[str]): - """ Function for postprocessing the raw tag predictions of the model. It + """Function for postprocessing the raw tag predictions of the model. It corrects obvious mistakes in the tag predictions such as a TRANSFORM span starts with I_TRANSFORM_TAG (instead of B_TRANSFORM_TAG). @@ -280,7 +280,7 @@ def _postprocess_tag_preds(self, words: List[str], inst_dir: str, preds: List[st return final_preds def decode_tag_preds(self, tag_preds: List[List[str]]): - """ Decoding the raw tag predictions to locate the semiotic spans in the + """Decoding the raw tag predictions to locate the semiotic spans in the input texts. Args: diff --git a/nemo/collections/nlp/models/enc_dec_nlp_model.py b/nemo/collections/nlp/models/enc_dec_nlp_model.py index d9aa3c017bae..60c6b616c20a 100644 --- a/nemo/collections/nlp/models/enc_dec_nlp_model.py +++ b/nemo/collections/nlp/models/enc_dec_nlp_model.py @@ -15,8 +15,8 @@ from dataclasses import dataclass from typing import Any +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import MISSING -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.nlp_model import NLPModel from nemo.collections.nlp.modules.common.decoder_module import DecoderModule @@ -35,8 +35,7 @@ class EncDecNLPModelConfig(ModelConfig): class EncDecNLPModel(NLPModel): - """Base class for encoder-decoder NLP models. - """ + """Base class for encoder-decoder NLP models.""" def __init__(self, cfg: EncDecNLPModelConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) diff --git a/nemo/collections/nlp/models/entity_linking/entity_linking_model.py b/nemo/collections/nlp/models/entity_linking/entity_linking_model.py index 4afae81e3893..640520cdaaa7 100644 --- a/nemo/collections/nlp/models/entity_linking/entity_linking_model.py +++ b/nemo/collections/nlp/models/entity_linking/entity_linking_model.py @@ -15,8 +15,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoTokenizer from nemo.collections.common.losses import MultiSimilarityLoss diff --git a/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py b/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py index 4447ebb89386..e90cf9d88c30 100644 --- a/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py +++ b/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss, MSELoss from nemo.collections.nlp.data.glue_benchmark.glue_benchmark_dataset import GLUE_TASKS_NUM_LABELS, GLUEDataset diff --git a/nemo/collections/nlp/models/information_retrieval/base_ir_model.py b/nemo/collections/nlp/models/information_retrieval/base_ir_model.py index 67424320d185..91d86fef1851 100644 --- a/nemo/collections/nlp/models/information_retrieval/base_ir_model.py +++ b/nemo/collections/nlp/models/information_retrieval/base_ir_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data import BertInformationRetrievalDataset from nemo.collections.nlp.models.nlp_model import NLPModel diff --git a/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py b/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py index 03b62d91170c..bfbec123d13e 100644 --- a/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py +++ b/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py @@ -15,8 +15,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import SmoothedCrossEntropyLoss from nemo.collections.nlp.data import BertInformationRetrievalDataset @@ -63,29 +63,50 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): @typecheck() def forward( - self, q_input_ids, q_token_type_ids, q_attention_mask, p_input_ids, p_token_type_ids, p_attention_mask, + self, + q_input_ids, + q_token_type_ids, + q_attention_mask, + p_input_ids, + p_token_type_ids, + p_attention_mask, ): q_vectors = self.q_encoder( - input_ids=q_input_ids, token_type_ids=q_token_type_ids, attention_mask=q_attention_mask, + input_ids=q_input_ids, + token_type_ids=q_token_type_ids, + attention_mask=q_attention_mask, ) q_vectors = q_vectors[:, 0] batch_size, hidden_size = q_vectors.size() p_vectors = self.p_encoder( - input_ids=p_input_ids, token_type_ids=p_token_type_ids, attention_mask=p_attention_mask, + input_ids=p_input_ids, + token_type_ids=p_token_type_ids, + attention_mask=p_attention_mask, ) num_passages = p_vectors.shape[0] // batch_size p_vectors = p_vectors[:, 0].view(-1, num_passages, hidden_size) p_positives, p_negatives = p_vectors[:, 0], p_vectors[:, 1:] scores = torch.cat( - (torch.matmul(q_vectors, p_positives.T), torch.einsum("ij,ipj->ip", q_vectors, p_negatives),), dim=1, + ( + torch.matmul(q_vectors, p_positives.T), + torch.einsum("ij,ipj->ip", q_vectors, p_negatives), + ), + dim=1, ) return scores def compute_scores_and_loss(self, inputs): - (q_input_ids, q_input_mask, q_input_type_ids, p_input_ids, p_input_mask, p_input_type_ids,) = inputs + ( + q_input_ids, + q_input_mask, + q_input_type_ids, + p_input_ids, + p_input_mask, + p_input_type_ids, + ) = inputs batch_size, num_passages, p_seq_length = p_input_ids.size() q_seq_length = q_input_ids.size()[-1] @@ -100,10 +121,17 @@ def compute_scores_and_loss(self, inputs): normalized_scores = torch.log_softmax(scores, dim=-1) labels = torch.arange(batch_size)[:, None].long().to(normalized_scores.device) - loss = self.loss(log_probs=normalized_scores, labels=labels, output_mask=torch.ones_like(labels),) + loss = self.loss( + log_probs=normalized_scores, + labels=labels, + output_mask=torch.ones_like(labels), + ) scores = scores[:, 0] - scores = torch.cat((torch.diag(scores)[:, None], scores[:, batch_size:]), dim=1,) + scores = torch.cat( + (torch.diag(scores)[:, None], scores[:, batch_size:]), + dim=1, + ) return scores, loss diff --git a/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py b/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py index a4dc4356342a..33885e6b50c6 100644 --- a/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py +++ b/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py @@ -15,8 +15,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import SmoothedCrossEntropyLoss from nemo.collections.nlp.models.information_retrieval.base_ir_model import BaseIRModel @@ -53,7 +53,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.bert_model = self.get_lm_model_with_padded_embedding(cfg) hidden_size = self.bert_model.config.hidden_size self.sim_score_regressor = SequenceRegression( - hidden_size=hidden_size, num_layers=1, dropout=cfg.language_model.sim_score_dropout, + hidden_size=hidden_size, + num_layers=1, + dropout=cfg.language_model.sim_score_dropout, ) self.loss = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id) @@ -61,7 +63,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def forward(self, input_ids, attention_mask, token_type_ids): hidden_states = self.bert_model( - input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, ) if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py index 5e38b61938c9..a5b71d5bcb69 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py @@ -16,13 +16,11 @@ import os import numpy as np - - import torch +from lightning.pytorch.trainer.trainer import Trainer from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from torch.distributed import all_gather as all_gather_no_backprop from torch.distributed.nn.functional import all_gather as all_gather_with_backprop @@ -46,7 +44,6 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import logging - try: from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index c7565f45358e..b5240ec2e170 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.information_retrieval.gpt_embedding_dataset import GPTEmbeddingDataset from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py index e316871fe607..fa593adf5c8f 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.information_retrieval.gpt_embedding_dataset import GPTRerankerDataset from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( diff --git a/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py b/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py index 0cd1d07af5dd..a49bc699ab24 100644 --- a/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py +++ b/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss @@ -38,8 +38,7 @@ class IntentSlotClassificationModel(NLPModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): - """ Initializes BERT Joint Intent and Slot model. - """ + """Initializes BERT Joint Intent and Slot model.""" self.max_seq_length = cfg.language_model.max_seq_length # init superclass # Check the presence of data_dir. @@ -75,7 +74,7 @@ def _set_defaults_data_desc(self, cfg): OmegaConf.set_struct(cfg, True) def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds): - """ Method creates IntentSlotDataDesc and copies generated values to cfg.data_desc. """ + """Method creates IntentSlotDataDesc and copies generated values to cfg.data_desc.""" # Save data from data desc to config - so it can be reused later, e.g. in inference. data_desc = IntentSlotDataDesc(data_dir=data_dir, modes=[train_ds.prefix, validation_ds.prefix]) OmegaConf.set_struct(cfg, False) @@ -109,7 +108,7 @@ def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds): OmegaConf.set_struct(cfg, True) def _save_label_ids(self, label_ids: Dict[str, int], filename: str) -> None: - """ Saves label ids map to a file """ + """Saves label ids map to a file""" with open(filename, 'w') as out: labels, _ = zip(*sorted(label_ids.items(), key=lambda x: x[1])) out.write('\n'.join(labels)) @@ -117,7 +116,7 @@ def _save_label_ids(self, label_ids: Dict[str, int], filename: str) -> None: logging.info(f'Labels mapping saved to : {out.name}') def _reconfigure_classifier(self): - """ Method reconfigures the classifier depending on the settings of model cfg.data_desc """ + """Method reconfigures the classifier depending on the settings of model cfg.data_desc""" self.classifier = SequenceTokenClassifier( hidden_size=self.hidden_size, diff --git a/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py b/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py index c689b97ab0a5..7a2bec1f2cc0 100644 --- a/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py +++ b/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py @@ -18,8 +18,8 @@ import numpy as np import numpy.typing as npt import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from sklearn.metrics import f1_score, precision_score, recall_score from torch.utils.data import DataLoader @@ -38,10 +38,10 @@ class MultiLabelIntentSlotClassificationModel(IntentSlotClassificationModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): - """ + """ Initializes BERT Joint Intent and Slot model. - Args: + Args: cfg: configuration object trainer: trainer for Pytorch Lightning """ @@ -69,12 +69,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def _set_data_desc_to_cfg( self, cfg: DictConfig, data_dir: str, train_ds: DictConfig, validation_ds: DictConfig ) -> None: - """ - Creates MultiLabelIntentSlotDataDesc and copies generated values to Configuration object's data descriptor. - - Args: + """ + Creates MultiLabelIntentSlotDataDesc and copies generated values to Configuration object's data descriptor. + + Args: cfg: configuration object - data_dir: data directory + data_dir: data directory train_ds: training dataset file name validation_ds: validation dataset file name @@ -101,7 +101,10 @@ def _set_data_desc_to_cfg( if not hasattr(cfg, "class_labels") or cfg.class_labels is None: cfg.class_labels = {} cfg.class_labels = OmegaConf.create( - {"intent_labels_file": "intent_labels.csv", "slot_labels_file": "slot_labels.csv",} + { + "intent_labels_file": "intent_labels.csv", + "slot_labels_file": "slot_labels.csv", + } ) slot_labels_file = os.path.join(data_dir, cfg.class_labels.slot_labels_file) @@ -114,7 +117,7 @@ def _set_data_desc_to_cfg( OmegaConf.set_struct(cfg, True) def _reconfigure_classifier(self) -> None: - """ Method reconfigures the classifier depending on the settings of model cfg.data_desc """ + """Method reconfigures the classifier depending on the settings of model cfg.data_desc""" self.classifier = SequenceTokenClassifier( hidden_size=self.bert_model.config.hidden_size, @@ -135,7 +138,8 @@ def _reconfigure_classifier(self) -> None: self.slot_loss = CrossEntropyLoss(logits_ndim=3) self.total_loss = AggregatorLoss( - num_inputs=2, weights=[self.cfg.intent_loss_weight, 1.0 - self.cfg.intent_loss_weight], + num_inputs=2, + weights=[self.cfg.intent_loss_weight, 1.0 - self.cfg.intent_loss_weight], ) # setup to track metrics @@ -161,12 +165,22 @@ def validation_step(self, batch, batch_idx) -> None: batch: batches of data from DataLoader batch_idx: batch idx from DataLoader - Returns: + Returns: None """ - (input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels,) = batch + ( + input_ids, + input_type_ids, + input_mask, + loss_mask, + subtokens_mask, + intent_labels, + slot_labels, + ) = batch intent_logits, slot_logits = self( - input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask, + input_ids=input_ids, + token_type_ids=input_type_ids, + attention_mask=input_mask, ) # calculate combined loss for intents and slots @@ -201,7 +215,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig) -> DataLoader: Args: cfg: configuration object - + Returns: DataLoader for model's data """ @@ -289,8 +303,8 @@ def prediction_probabilities(self, queries: List[str], test_ds: DictConfig) -> n def optimize_threshold(self, test_ds: DictConfig, file_name: str) -> None: """ - Set the optimal threshold of the model from performance on validation set. This threshold is used to round the - logits to 0 or 1. + Set the optimal threshold of the model from performance on validation set. This threshold is used to round the + logits to 0 or 1. Args: test_ds: location of test dataset @@ -361,16 +375,16 @@ def predict_from_examples( queries: text sequences test_ds: Dataset configuration section. threshold: Threshold for rounding prediction logits - + Returns: predicted_intents: model intent predictions with their probabilities - Example: [[('flight', 0.84)], [('airfare', 0.54), + Example: [[('flight', 0.84)], [('airfare', 0.54), ('flight', 0.73), ('meal', 0.24)]] predicted_slots: model slot predictions Example: ['O B-depart_date.month_name B-depart_date.day_number', 'O O B-flight_stop O O O'] - predicted_vector: model intent predictions for each individual query. Binary values within each list + predicted_vector: model intent predictions for each individual query. Binary values within each list indicate whether a class is prediced for the given query (1 for True, 0 for False) Example: [[1,0,0,0,0,0], [0,0,1,0,0,0]] """ diff --git a/nemo/collections/nlp/models/language_modeling/bert_lm_model.py b/nemo/collections/nlp/models/language_modeling/bert_lm_model.py index 6b03d86982b0..dc7103b67aa6 100644 --- a/nemo/collections/nlp/models/language_modeling/bert_lm_model.py +++ b/nemo/collections/nlp/models/language_modeling/bert_lm_model.py @@ -16,8 +16,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss, SmoothedCrossEntropyLoss @@ -75,7 +75,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): config_file = self.register_artifact('language_model.config_file', cfg.language_model.config_file) self.bert_model = get_lm_model( - config_file=config_file, config_dict=config_dict, vocab_file=vocab_file, trainer=trainer, cfg=cfg, + config_file=config_file, + config_dict=config_dict, + vocab_file=vocab_file, + trainer=trainer, + cfg=cfg, ) self.hidden_size = self.bert_model.config.hidden_size @@ -127,7 +131,9 @@ def forward(self, input_ids, attention_mask, token_type_ids): in the `nn.Module` in vanilla PyTorch. """ hidden_states = self.bert_model( - input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, ) if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] @@ -225,7 +231,9 @@ def _setup_preprocessed_dataloader(self, cfg: Optional[DictConfig]): files = [dataset] files.sort() dl = BertPretrainingPreprocessedDataloader( - data_files=files, max_predictions_per_seq=max_predictions_per_seq, batch_size=batch_size, + data_files=files, + max_predictions_per_seq=max_predictions_per_seq, + batch_size=batch_size, ) return dl diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py index 0d75ab7cc706..c629db5af3c3 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py @@ -208,6 +208,7 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, ): diff --git a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py index 131f154d6709..7c3f3c194f14 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py @@ -108,6 +108,7 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, ): diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py index d1945139dee9..1def214113ee 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py @@ -252,6 +252,7 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, # TODO: handle this ): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py index 1c768829e3e2..4a53edacb566 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model @@ -48,7 +48,9 @@ def _validate_cfg(self): @property def _build_train_valid_test_datasets_kwargs(self): """allows child classes to add kwargs to dataset building""" - return dict(delete_mask_prob=self._cfg.data.get('delete_mask_prob', 0.0),) + return dict( + delete_mask_prob=self._cfg.data.get('delete_mask_prob', 0.0), + ) def list_available_models(self): pass diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index d2a21e50e486..dee36b255297 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -23,12 +23,12 @@ import omegaconf import torch import torch.nn as nn +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator +from lightning.pytorch.trainer.trainer import Trainer +from lightning.pytorch.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator -from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException from nemo.collections.nlp.models.nlp_model import NLPModel from nemo.collections.nlp.modules.common.megatron.attention import HAVE_FLASH_ATTENTION diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index 2a356012c728..b00b6fcf0302 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -18,9 +18,9 @@ from typing import Any, Optional import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch import Tensor from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index 0eb5ea1c0048..e6945d1ada56 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -18,8 +18,8 @@ import torch import torch.nn.functional as F +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron import dataset_utils from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py index c0a4b6351530..d3829c3e8de1 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.glue_benchmark.glue_benchmark_dataset import ( TextToTextGLUEDataset, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py index c6b4d055ef6e..44860c3178f6 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py @@ -21,8 +21,8 @@ import os import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -162,7 +162,7 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Loads a state_dict expecting the state_dict to contain key,values + Loads a state_dict expecting the state_dict to contain key,values only for the adapter parameters. """ for name, module in self.frozen_model.named_modules(): @@ -176,13 +176,13 @@ def load_state_dict(self, state_dict, strict: bool = True): def setup_optimizer_param_groups(self): """ - ModelPT override. Optimizer will get self._optimizer_param_groups. + ModelPT override. Optimizer will get self._optimizer_param_groups. Makes two optimizer param groups, one for the frozen model params - and one for the prompt-table/prompt-encoder params. The learning + and one for the prompt-table/prompt-encoder params. The learning rate for the frozen model's params will always be zero effectively freezing the model's params but still allowing for the needed gradients - to be passed around in pipeline parallel models. The prompt-encoder - and/or prompt table will use the learning rate set by the user. + to be passed around in pipeline parallel models. The prompt-encoder + and/or prompt table will use the learning rate set by the user. """ self.frozen_model.freeze() # Freeze the entire model opt_params = [] @@ -246,8 +246,8 @@ class MegatronGPTAdapterLearningModel(MegatronGPTBaseAdapterModel): Two adapter's are inserted into each Transformer layer in the base GPT Model. It is assumed that these set of adapters will then be trained for a specific task. - Once trained, the adapter weights will be saved and can be re-loaded - and infused into the same GPT Model for inference. + Once trained, the adapter weights will be saved and can be re-loaded + and infused into the same GPT Model for inference. """ def __init__(self, cfg: DictConfig, trainer: Trainer): @@ -295,7 +295,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): for adapter_key in self.adapter_name_keys: if model_utils.import_class_by_path(adapter_cfg._target_) in module.get_accepted_adapter_types(): module.add_adapter( - name=adapter_key, cfg=adapter_cfg, + name=adapter_key, + cfg=adapter_cfg, ) logging.info(f'After adding adapters:\n{self.frozen_model.summarize()}') @@ -313,8 +314,8 @@ class MegatronGPTInfusedAdapterModel(MegatronGPTBaseAdapterModel): Three adapter's are inserted into each Transformer layer in the base GPT Model. Each adapter is basically a vector that simply scales the key, value or ffn hidden representations. It is assumed that these set of adapters will then be trained for a specific task. - Once trained, the adapter weights will be saved and can be re-loaded - and infused into the same GPT Model for inference. + Once trained, the adapter weights will be saved and can be re-loaded + and infused into the same GPT Model for inference. """ def __init__(self, cfg: DictConfig, trainer: Trainer): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 8f541e5703e6..0e1bd83dfe1c 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -24,11 +24,11 @@ import packaging import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.parts.utils import apply_rope_scaling, extend_instance from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index 78f671142c1b..7d39459ae654 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -18,9 +18,9 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 08bc5501363c..2d3f43b2f2a8 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -17,9 +17,9 @@ from typing import Any, Optional import torch +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.metrics import MetricStringToTorchMetric from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py index 1e5a2f0c15c0..40e147b90903 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py @@ -13,8 +13,8 @@ # limitations under the License. import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_model import GriffinModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py index c53d231b2719..584a4b0572f7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel @@ -48,8 +48,8 @@ def _reset_activation_checkpointing_args(self): def on_validation_model_zero_grad(self) -> None: """ - Skip gradient zeroing at the beginning of validation routine. - This is needed when overlapping the AllGather of the updated parameters with the following valdation step. - """ + Skip gradient zeroing at the beginning of validation routine. + This is needed when overlapping the AllGather of the updated parameters with the following valdation step. + """ if not self.validation_param_sync_overlap: MegatronBaseModel.on_validation_model_zero_grad(self) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index 7b92b9e25d69..e530a40d8aaa 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -18,11 +18,11 @@ from typing import Any, Dict, List, Optional import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( MegatronPretrainingRandomSampler, @@ -32,12 +32,10 @@ from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import ( - AttnMaskType, MegatronTokenLevelEncoderDecoderModule, ) from nemo.collections.nlp.modules.common.megatron.utils import ( average_losses_across_data_parallel_group, - build_attention_mask_3d, get_params_for_weight_decay_optimization, ) from nemo.collections.nlp.modules.common.text_generation_utils import ( @@ -683,14 +681,13 @@ def fwd_output_and_loss_func(dataloader_iter, model): if self.mcore_t5: # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM - encoder_attn_mask_3d = build_attention_mask_3d( - encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding - ) - decoder_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, decoder_attn_mask, AttnMaskType.causal - ) - enc_dec_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + encoder_attn_mask = encoder_attn_mask < 0.5 + decoder_attn_mask = decoder_attn_mask < 0.5 + encoder_attn_mask_3d = encoder_attn_mask.unsqueeze(1).unsqueeze(1) + decoder_attn_mask_3d = decoder_attn_mask.unsqueeze(1).unsqueeze(1) + enc_dec_attn_mask_3d = ( + decoder_attn_mask_3d, + encoder_attn_mask_3d, ) output = model( # model is MCoreT5Model @@ -816,10 +813,8 @@ def fwd_output_only_func(dataloader_iter, model): encoder_attn_mask, ) = batch - # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM - encoder_attn_mask_3d = build_attention_mask_3d( - encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding - ) + encoder_attn_mask = encoder_attn_mask < 0.5 + encoder_attn_mask_3d = encoder_attn_mask.unsqueeze(1).unsqueeze(1) output = model( encoder_input_ids=encoder_input_ids, @@ -841,15 +836,13 @@ def fwd_output_only_func(dataloader_iter, model): decoder_attn_mask, ) = batch - # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM - encoder_attn_mask_3d = build_attention_mask_3d( - encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding - ) - decoder_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, decoder_attn_mask, AttnMaskType.causal - ) - enc_dec_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + encoder_attn_mask = encoder_attn_mask < 0.5 + decoder_attn_mask = decoder_attn_mask < 0.5 + encoder_attn_mask_3d = encoder_attn_mask.unsqueeze(1).unsqueeze(1) + decoder_attn_mask_3d = decoder_attn_mask.unsqueeze(1).unsqueeze(1) + enc_dec_attn_mask_3d = ( + decoder_attn_mask_3d, + encoder_attn_mask_3d, ) # re-transpose encoder_hidden_states from [batch, seq_len, hidden] to [seq_len, batch, hidden] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py index 4f0000dafaa2..ad92421ee607 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py @@ -13,8 +13,8 @@ # limitations under the License. import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.utils import logging diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py index ebcc47004711..cacdb1c190e7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel - __all__ = ['MegatronMambaSFTModel'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py index 42323e503f7d..147c832f4b9a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py @@ -16,8 +16,8 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( MegatronPretrainingRandomSampler, @@ -294,7 +294,10 @@ def training_step(self, batch, batch_idx): self.log('lr', lr, batch_size=1) self.log('global_step', self.trainer.global_step, prog_bar=True, batch_size=1) self.log( - 'consumed_samples', self._compute_consumed_samples_after_training_step(), prog_bar=True, batch_size=1, + 'consumed_samples', + self._compute_consumed_samples_after_training_step(), + prog_bar=True, + batch_size=1, ) self._reduced_loss_buffer = [] return lm_loss @@ -427,7 +430,10 @@ def build_pretraining_data_loader(self, dataset, consumed_samples): # Torch dataloader. return torch.utils.data.DataLoader( - dataset, batch_sampler=batch_sampler, num_workers=self.cfg.data.num_workers, pin_memory=True, + dataset, + batch_sampler=batch_sampler, + num_workers=self.cfg.data.num_workers, + pin_memory=True, ) def setup(self, stage=None): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py index 1eaec4238648..924da5825024 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py @@ -15,8 +15,8 @@ from functools import partial import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.data import ConcatMapDataset from nemo.collections.common.metrics import MetricStringToTorchMetric @@ -50,11 +50,13 @@ def build_all_datasets( - cfg, tokenizer, train_valid_test_num_samples, + cfg, + tokenizer, + train_valid_test_num_samples, ): """Build train, valid, and test RETRO datasets. - There is one to one mapping between data_prefix and knn_map_path. - Currently only supports one retrieval dataset. + There is one to one mapping between data_prefix and knn_map_path. + Currently only supports one retrieval dataset. """ train_dataset = RetroQAFineTuneDataset( cfg.train_ds.get('file_name'), @@ -97,7 +99,7 @@ def build_all_datasets( class MegatronRetroFinetuneModel(MegatronRetrievalModel): - """Finetune RETRO Model """ + """Finetune RETRO Model""" def build_train_valid_test_datasets(self): logging.info('Building RETRO datasets.') @@ -114,7 +116,9 @@ def build_train_valid_test_datasets(self): ] self._train_ds, self._validation_ds, self._test_ds = build_all_datasets( - cfg=self.cfg.data, tokenizer=self.tokenizer, train_valid_test_num_samples=train_valid_test_num_samples, + cfg=self.cfg.data, + tokenizer=self.tokenizer, + train_valid_test_num_samples=train_valid_test_num_samples, ) if self._train_ds is not None: logging.info(f'Length of train dataset: {len(self._train_ds)}') @@ -143,5 +147,9 @@ def build_pretraining_data_loader(self, dataset, consumed_samples): drop_last=True, ) return torch.utils.data.DataLoader( - dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0, pin_memory=True, + dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=0, + pin_memory=True, ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py index a6bf75fb9444..493d512fd30e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py @@ -23,10 +23,10 @@ from typing import Any, Dict, Iterator, List, Optional, Union import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( MegatronPretrainingRandomSampler, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py index cee1b11a160b..92827b31a259 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( get_datasets_weights_and_num_samples, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py index 31eb4519ded2..a6e6afc8b7eb 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py @@ -21,9 +21,9 @@ from typing import Any import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model @@ -60,7 +60,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.adapter_name_keys = [] def forward( - self, input_ids, dec_input, enc_mask, dec_mask, position_ids, taskname_ids, labels=None, inference=False, + self, + input_ids, + dec_input, + enc_mask, + dec_mask, + position_ids, + taskname_ids, + labels=None, + inference=False, ): # Call forward on T5 model with preprocessed embeddings if self.autocast_dtype == torch.float32: @@ -195,13 +203,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def setup_optimizer_param_groups(self): """ - ModelPT override. Optimizer will get self._optimizer_param_groups. + ModelPT override. Optimizer will get self._optimizer_param_groups. Makes two optimizer param groups, one for the frozen model params - and one for the prompt-table/prompt-encoder params. The learning + and one for the prompt-table/prompt-encoder params. The learning rate for the frozen model's params will always be zero effectively freezing the model's params but still allowing for the needed gradients - to be passed around in pipeline parallel models. The prompt-encoder - and/or prompt table will use the learning rate set by the user. + to be passed around in pipeline parallel models. The prompt-encoder + and/or prompt table will use the learning rate set by the user. """ self.frozen_model.freeze() # Freeze the entire model opt_params = [] @@ -266,7 +274,7 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Loads a state_dict expecting the state_dict to contain key,values + Loads a state_dict expecting the state_dict to contain key,values only for the adapter parameters. """ for name, module in self.frozen_model.named_modules(): @@ -319,7 +327,7 @@ def on_validation_epoch_end(self): gather_results_dedup = list(set(itertools.chain(*gather_results))) correct = 0 - for (input, pred, label) in gather_results_dedup: + for input, pred, label in gather_results_dedup: if pred == label: correct += 1 @@ -559,8 +567,8 @@ class MegatronT5InfusedAdapterModel(MegatronT5BaseAdapterModel): Three adapter's are inserted into each Transformer layer in the base GPT Model. Each adapter is basically a vector that simply scales the key, value or ffn hidden representations. It is assumed that these set of adapters will then be trained for a specific task. - Once trained, the adapter weights will be saved and can be re-loaded - and infused into the same GPT Model for inference. + Once trained, the adapter weights will be saved and can be re-loaded + and infused into the same GPT Model for inference. """ def __init__(self, cfg: DictConfig, trainer: Trainer): @@ -670,7 +678,7 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Loads a state_dict expecting the state_dict to contain key,values + Loads a state_dict expecting the state_dict to contain key,values only for the adapter parameters. """ encoder = self.frozen_model.enc_dec_model.enc_dec_model.encoder diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py index 0f5022795446..1df10403a9e7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py @@ -15,8 +15,8 @@ import enum import math +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import build_train_valid_test_datasets from nemo.collections.nlp.models.language_modeling.megatron_lm_encoder_decoder_model import ( @@ -79,7 +79,9 @@ def _validate_cfg(self): @property def _build_train_valid_test_datasets_kwargs(self): """allows child classes to add kwargs to dataset building""" - return dict(max_seq_length_dec=self._cfg.data.seq_length_dec,) + return dict( + max_seq_length_dec=self._cfg.data.seq_length_dec, + ) def _build_vocab(self): self.num_sentinel_tokens = self._cfg.tokenizer.num_sentinel_tokens @@ -210,9 +212,9 @@ def build_train_valid_test_datasets(self): ] if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): - train_valid_test_num_samples[ - 1 - ] = 1 # This is to make sure we only have one epoch on every validation iteration + train_valid_test_num_samples[1] = ( + 1 # This is to make sure we only have one epoch on every validation iteration + ) self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets( cfg=self._cfg, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py index 1f54cb87428e..187f24c884b7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py @@ -16,10 +16,10 @@ from typing import Any, List import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.t5_prompt_learning_dataset import T5PromptLearningDataset from nemo.collections.nlp.models.language_modeling.megatron_base_prompt_learning_model import ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py index c70f44925d33..6f9a69f27529 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py @@ -16,9 +16,9 @@ from typing import Dict, List import torch +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.data import ConcatMapDataset from nemo.collections.common.metrics import MetricStringToTorchMetric diff --git a/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py b/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py index 69db0d46e75e..3b8e1f819ea1 100644 --- a/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py +++ b/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py @@ -19,8 +19,8 @@ import numpy as np import torch import torch.utils.data as pt_data +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.common.losses import SmoothedCrossEntropyLoss from nemo.collections.common.metrics import GlobalAverageLossMetric @@ -59,9 +59,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): tokenizer_model=cfg.tokenizer.get("tokenizer_model", None), vocab_file=cfg.tokenizer.get("vocab_file", None), bpe_dropout=cfg.tokenizer.get("bpe_dropout", 0.0), - special_tokens=OmegaConf.to_container(cfg.tokenizer.special_tokens) - if cfg.tokenizer.get("special_tokens", None) - else None, + special_tokens=( + OmegaConf.to_container(cfg.tokenizer.special_tokens) + if cfg.tokenizer.get("special_tokens", None) + else None + ), ) # init superclass @@ -99,7 +101,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # tie weights of embedding and softmax matrices self.log_softmax.mlp.layer0.weight = self.encoder.embedding.token_embedding.weight - std_init_range = 1 / self.encoder.hidden_size ** 0.5 + std_init_range = 1 / self.encoder.hidden_size**0.5 # initialize weights if not using pretrained encoder if not self._cfg.encoder.get('pretrained', False): @@ -199,7 +201,12 @@ def on_test_epoch_end(self): self.test_step_outputs.clear() # free memory def setup_tokenizer( - self, tokenizer_name=None, tokenizer_model=None, vocab_file=None, bpe_dropout=0.0, special_tokens=None, + self, + tokenizer_name=None, + tokenizer_model=None, + vocab_file=None, + bpe_dropout=0.0, + special_tokens=None, ): supported_tokenizers = ['huggingface', 'sentencepiece', 'word'] diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 4461b417f311..b5f228f21e1a 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -19,10 +19,10 @@ import numpy as np import torch +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.listconfig import ListConfig -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from sacrebleu import corpus_bleu from nemo.collections.nlp.data.common.sequence_to_sequence_dataset import ( diff --git a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py index 41c6125ba05f..96077c4da82e 100644 --- a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py +++ b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py @@ -16,7 +16,7 @@ import numpy as np import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.common.losses import NLLLoss from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTBottleneckModelConfig @@ -184,7 +184,11 @@ def loss( output_mask = (tgt_labels != self.decoder_tokenizer.pad_id).type_as(tgt_log_probs) log_p_x_given_z_per_token = ( - -recon_loss_fn(log_probs=tgt_log_probs, labels=tgt_labels,).view(tgt_log_probs.shape[:2]) * output_mask + -recon_loss_fn( + log_probs=tgt_log_probs, + labels=tgt_labels, + ).view(tgt_log_probs.shape[:2]) + * output_mask ) # probability per sample @@ -216,7 +220,10 @@ def loss( if self.model_type in ["mim", "vae"]: # tokens = tgt_mask.sum() - q_z_given_x = torch.distributions.Normal(loc=z_mean, scale=torch.exp(0.5 * z_logv),) + q_z_given_x = torch.distributions.Normal( + loc=z_mean, + scale=torch.exp(0.5 * z_logv), + ) # average latent distribution to match averaging of observations if self.recon_per_token: # average latent per dimension - to heuristically match per-token reconstruction @@ -225,7 +232,10 @@ def loss( log_q_z_given_x = q_z_given_x.log_prob(z).sum(-1).sum(-1).mean() # build prior distribution - p_z = torch.distributions.Normal(loc=torch.zeros_like(z), scale=torch.ones_like(z),) + p_z = torch.distributions.Normal( + loc=torch.zeros_like(z), + scale=torch.ones_like(z), + ) if self.recon_per_token: # average latent distribution similar to averaging of observations log_p_z = p_z.log_prob(z).mean(-1).mean(-1).mean() @@ -267,7 +277,11 @@ def forward(self, src, src_mask, tgt, tgt_mask, timer=None): if timer is not None: timer.start("encoder") - enc_hiddens, enc_mask = self.encoder(input_ids=src, encoder_mask=src_mask, return_mask=True,) + enc_hiddens, enc_mask = self.encoder( + input_ids=src, + encoder_mask=src_mask, + return_mask=True, + ) # build posterior distribution q(x|z) z, z_mean, z_logv = self.encode_latent(hidden=enc_hiddens) @@ -283,7 +297,10 @@ def forward(self, src, src_mask, tgt, tgt_mask, timer=None): context_hiddens = self.latent2hidden(z) tgt_hiddens = self.decoder( - input_ids=tgt, decoder_mask=tgt_mask, encoder_embeddings=context_hiddens, encoder_mask=enc_mask, + input_ids=tgt, + decoder_mask=tgt_mask, + encoder_embeddings=context_hiddens, + encoder_mask=enc_mask, ) # build decoding distribution @@ -426,18 +443,25 @@ def eval_step(self, batch, batch_idx, mode, dataloader_idx=0): return_info=True, ) # pass cache to sampler in order to reuse encoder's output - cache = dict(z=z, z_mean=z_mean, z_mask=z_mask, timer=timer,) + cache = dict( + z=z, + z_mean=z_mean, + z_mask=z_mask, + timer=timer, + ) inputs, translations = self.batch_translate(src=src_ids, src_mask=src_mask, cache=cache) num_measurements = labels.shape[0] * labels.shape[1] if dataloader_idx == 0: getattr(self, f'{mode}_loss')( - loss=eval_loss, num_measurements=num_measurements, + loss=eval_loss, + num_measurements=num_measurements, ) else: getattr(self, f'{mode}_loss_{dataloader_idx}')( - loss=eval_loss, num_measurements=num_measurements, + loss=eval_loss, + num_measurements=num_measurements, ) np_tgt = tgt_ids.detach().cpu().numpy() ground_truths = [self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt] diff --git a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py index 708d4236be7f..78b701699259 100644 --- a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py +++ b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py @@ -25,9 +25,9 @@ import torch import torch.distributed as dist import torch.utils.data as pt_data +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.utilities import rank_zero_only from sacrebleu import corpus_bleu from nemo.collections.common.data import ConcatDataset @@ -120,17 +120,21 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): encoder_tokenizer, decoder_tokenizer = MTEncDecModel.setup_enc_dec_tokenizers( encoder_tokenizer_library=self.encoder_tokenizer_library, encoder_tokenizer_model=encoder_tokenizer_model, - encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0) - if cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None - else 0.0, + encoder_bpe_dropout=( + cfg.encoder_tokenizer.get('bpe_dropout', 0.0) + if cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None + else 0.0 + ), encoder_model_name=cfg.encoder.get('model_name') if hasattr(cfg.encoder, 'model_name') else None, encoder_r2l=cfg.encoder_tokenizer.get('r2l', False), decoder_tokenizer_library=self.decoder_tokenizer_library, encoder_tokenizer_vocab_file=encoder_vocab_file, decoder_tokenizer_model=decoder_tokenizer_model, - decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0) - if cfg.decoder_tokenizer.get('bpe_dropout', 0.0) is not None - else 0.0, + decoder_bpe_dropout=( + cfg.decoder_tokenizer.get('bpe_dropout', 0.0) + if cfg.decoder_tokenizer.get('bpe_dropout', 0.0) is not None + else 0.0 + ), decoder_model_name=cfg.decoder.get('model_name') if hasattr(cfg.decoder, 'model_name') else None, decoder_r2l=cfg.decoder_tokenizer.get('r2l', False), special_tokens=self.special_tokens, @@ -254,7 +258,7 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight # TODO: encoder and decoder with different hidden size? - std_init_range = 1 / self.encoder.hidden_size ** 0.5 + std_init_range = 1 / self.encoder.hidden_size**0.5 # initialize weights if not using pretrained encoder/decoder if not self._cfg.encoder.get('pretrained', False): @@ -341,7 +345,10 @@ def filter_predicted_ids(cls, ids, decoder_tokenizer): return ids def test_encoder_ids(self, ids, raise_error=False): - invalid_ids = torch.logical_or((ids >= self.encoder_tokenizer.vocab_size).any(), (ids < 0).any(),) + invalid_ids = torch.logical_or( + (ids >= self.encoder_tokenizer.vocab_size).any(), + (ids < 0).any(), + ) if raise_error and invalid_ids: raise ValueError("Encoder ids are out of range (tip: check encoder tokenizer)") @@ -349,7 +356,10 @@ def test_encoder_ids(self, ids, raise_error=False): return not invalid_ids def test_decoder_ids(self, ids, raise_error=False): - invalid_ids = torch.logical_or((ids >= self.decoder_tokenizer.vocab_size).any(), (ids < 0).any(),) + invalid_ids = torch.logical_or( + (ids >= self.decoder_tokenizer.vocab_size).any(), + (ids < 0).any(), + ) if raise_error and invalid_ids: raise ValueError("Decoder ids are out of range (tip: check decoder tokenizer)") @@ -655,7 +665,10 @@ def setup_training_data(self, train_data_config: Optional[DictConfig]): multilingual=self.multilingual, multilingual_ids=self.multilingual_ids, ) - self._train_dl = MTEncDecModel._setup_dataloader_from_config(cfg=train_data_config, dataset=self._train_ds,) + self._train_dl = MTEncDecModel._setup_dataloader_from_config( + cfg=train_data_config, + dataset=self._train_ds, + ) # Need to set this because if using an IterableDataset, the length of the dataloader is the total number # of samples rather than the number of batches, and this messes up the tqdm progress bar. @@ -714,7 +727,9 @@ def setup_validation_data(self, val_data_config: Optional[DictConfig]): for dataloader_idx in range(len(self._validation_dl)): if dataloader_idx == 0: setattr( - self, f'val_loss', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), + self, + f'val_loss', + GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) else: setattr( @@ -737,7 +752,9 @@ def setup_test_data(self, test_data_config: Optional[DictConfig]): for dataloader_idx in range(len(self._test_dl)): if dataloader_idx == 0: setattr( - self, f'test_loss', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), + self, + f'test_loss', + GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) else: setattr( @@ -886,13 +903,15 @@ def _setup_dataloader_from_config(cls, cfg, dataset): return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, - sampler=None - if ( - cfg.get("use_tarred_dataset", False) - or cfg.get("dataset_type", "") == "tarred" - or isinstance(dataset, ConcatDataset) - ) - else sampler, + sampler=( + None + if ( + cfg.get("use_tarred_dataset", False) + or cfg.get("dataset_type", "") == "tarred" + or isinstance(dataset, ConcatDataset) + ) + else sampler + ), num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), @@ -983,9 +1002,11 @@ def _setup_eval_dataloader_from_config(cls, cfg, datasets): torch.utils.data.DataLoader( dataset=dataset, batch_size=1, - sampler=None - if (cfg.get("use_tarred_dataset", False) or isinstance(datasets[0], ConcatDataset)) - else sampler, + sampler=( + None + if (cfg.get("use_tarred_dataset", False) or isinstance(datasets[0], ConcatDataset)) + else sampler + ), num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), @@ -1188,7 +1209,10 @@ def translate( ) if return_beam_scores: _, all_translations, scores, best_translations = self.batch_translate( - src, src_mask, return_beam_scores=True, cache=cache, + src, + src_mask, + return_beam_scores=True, + cache=cache, ) return_val = all_translations, scores, best_translations else: diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index b27c00c5d7c3..0c61b085bc7f 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -19,13 +19,13 @@ from typing import Any, Mapping, Optional, Union import torch -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch import Trainer +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml +from lightning.pytorch.utilities import rank_zero_only +from lightning.pytorch.utilities.migration import pl_legacy_patch from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.migration import pl_legacy_patch from transformers import TRANSFORMERS_CACHE from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer diff --git a/nemo/collections/nlp/models/question_answering/qa_base_model.py b/nemo/collections/nlp/models/question_answering/qa_base_model.py index 7ca78f2e136e..cb07e43c3dc1 100644 --- a/nemo/collections/nlp/models/question_answering/qa_base_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_base_model.py @@ -15,8 +15,8 @@ from typing import Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data.question_answering.data_processor.qa_processing import ( EVALUATION_MODE, diff --git a/nemo/collections/nlp/models/question_answering/qa_bert_model.py b/nemo/collections/nlp/models/question_answering/qa_bert_model.py index d4bdef6d871d..4036b23999d8 100644 --- a/nemo/collections/nlp/models/question_answering/qa_bert_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_bert_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers.models.bert.tokenization_bert import BasicTokenizer from nemo.collections.common.losses import SpanningLoss diff --git a/nemo/collections/nlp/models/question_answering/qa_gpt_model.py b/nemo/collections/nlp/models/question_answering/qa_gpt_model.py index 059cf5625f15..f8c883643fe0 100644 --- a/nemo/collections/nlp/models/question_answering/qa_gpt_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_gpt_model.py @@ -16,8 +16,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModelForCausalLM from nemo.collections.nlp.data.question_answering.data_processor.qa_processing import QAProcessor diff --git a/nemo/collections/nlp/models/question_answering/qa_model.py b/nemo/collections/nlp/models/question_answering/qa_model.py index 2147d7d6a5bf..01b07bb8b3b0 100644 --- a/nemo/collections/nlp/models/question_answering/qa_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_model.py @@ -16,8 +16,8 @@ from typing import Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.cuda.amp import autocast from nemo.collections.common.losses import SpanningLoss diff --git a/nemo/collections/nlp/models/question_answering/qa_s2s_model.py b/nemo/collections/nlp/models/question_answering/qa_s2s_model.py index 5ad959fd1b6f..a703e23bc837 100644 --- a/nemo/collections/nlp/models/question_answering/qa_s2s_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_s2s_model.py @@ -16,8 +16,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.cuda.amp import autocast from transformers import AutoModelForSeq2SeqLM diff --git a/nemo/collections/nlp/models/rag/custom_bert_embedder.py b/nemo/collections/nlp/models/rag/custom_bert_embedder.py index d27ee98a14ef..84361e2728b5 100644 --- a/nemo/collections/nlp/models/rag/custom_bert_embedder.py +++ b/nemo/collections/nlp/models/rag/custom_bert_embedder.py @@ -15,10 +15,10 @@ from typing import Any, List import torch +from lightning.pytorch.trainer.trainer import Trainer from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.embeddings import BaseEmbedding from omegaconf import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.information_retrieval.megatron_bert_embedding_model import MegatronBertEmbeddingModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy diff --git a/nemo/collections/nlp/models/rag/custom_gpt_llm.py b/nemo/collections/nlp/models/rag/custom_gpt_llm.py index f26a86cfaaf7..1bbeed38991b 100644 --- a/nemo/collections/nlp/models/rag/custom_gpt_llm.py +++ b/nemo/collections/nlp/models/rag/custom_gpt_llm.py @@ -14,10 +14,10 @@ from typing import Any +from lightning.pytorch.trainer.trainer import Trainer from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.llms import CompletionResponse, CompletionResponseGen, CustomLLM, LLMMetadata from llama_index.core.llms.callbacks import llm_completion_callback -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam diff --git a/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py index d9e08f6764fc..6d4974993bcb 100644 --- a/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py +++ b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py @@ -17,8 +17,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.data.spellchecking_asr_customization import ( diff --git a/nemo/collections/nlp/models/text2sparql/text2sparql_model.py b/nemo/collections/nlp/models/text2sparql/text2sparql_model.py index 6503364fc07e..df7eefa310bb 100644 --- a/nemo/collections/nlp/models/text2sparql/text2sparql_model.py +++ b/nemo/collections/nlp/models/text2sparql/text2sparql_model.py @@ -19,8 +19,8 @@ from typing import Dict, List, Optional, Tuple import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from transformers import AutoModel, BartForConditionalGeneration, EncoderDecoderModel from nemo.collections.common.metrics import Perplexity @@ -145,7 +145,10 @@ def training_step(self, batch: Tuple, batch_idx: int) -> Dict: """ input_ids, input_mask, decoder_input_ids, labels = batch loss = self.forward( - input_ids=input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids, labels=labels, + input_ids=input_ids, + attention_mask=input_mask, + decoder_input_ids=decoder_input_ids, + labels=labels, )[0] tensorboard_logs = {"train_loss": loss, "lr": self._optimizer.param_groups[0]["lr"]} @@ -159,7 +162,10 @@ def validation_step(self, batch: Tuple, batch_idx: int) -> Dict: """ input_ids, input_mask, decoder_input_ids, labels = batch loss, logits = self.forward( - input_ids=input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids, labels=labels, + input_ids=input_ids, + attention_mask=input_mask, + decoder_input_ids=decoder_input_ids, + labels=labels, )[:2] self.validation_perplexity(logits=logits) diff --git a/nemo/collections/nlp/models/text_classification/text_classification_model.py b/nemo/collections/nlp/models/text_classification/text_classification_model.py index 033447304bbf..b2da2fe21701 100644 --- a/nemo/collections/nlp/models/text_classification/text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/text_classification_model.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.data.text_classification import TextClassificationDataset, calc_class_weights diff --git a/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py b/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py index 4c11dc157b2b..ddcb3a774055 100644 --- a/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py +++ b/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.data.text_normalization_as_tagging import ( @@ -289,7 +289,7 @@ def on_test_epoch_end(self): # Functions for inference @torch.no_grad() def _infer(self, sents: List[str]) -> List[List[int]]: - """ Main function for Inference + """Main function for Inference Args: sents: A list of input sentences (lowercase spoken-domain words separated by space). diff --git a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py index 69df9b6ac009..bd42517a5720 100644 --- a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py +++ b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer from torch.nn import Linear from tqdm import tqdm @@ -53,27 +53,27 @@ def update_model_config_to_support_adapter(model_cfg): class PunctuationCapitalizationLexicalAudioModel(PunctuationCapitalizationModel): """ - A model for restoring punctuation and capitalization in text using lexical and audio features. - - The model consists of a language model and two multilayer perceptrons (MLP) on top the fusion of LM and AM. The first - MLP serves for punctuation prediction and the second is for capitalization prediction. You can use only BERT-like - HuggingFace language models (model ``forward`` method accepts ``input_ids``, ``token_types_ids``, - ``attention_mask`` arguments). See more about model config options :ref:`here`. - And any :class:`~nemo.collections.asr.models.EncDecCTCModel` which has encoder module which is used as an AM. - - For training and testing use dataset - :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_dataset.BertPunctuationCapitalizationDataset` with parameter ``use_audio`` set to ``True``, - for training on huge amounts of data which cannot be loaded into memory simultaneously use - :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset.BertPunctuationCapitalizationTarredDataset` with parameter ``use_audio`` set to ``True``. - - Args: - cfg: a model configuration. It should follow dataclass - :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` - See an example of full config in - `nemo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml - `_ - trainer: an instance of a PyTorch Lightning trainer - """ + A model for restoring punctuation and capitalization in text using lexical and audio features. + + The model consists of a language model and two multilayer perceptrons (MLP) on top the fusion of LM and AM. The first + MLP serves for punctuation prediction and the second is for capitalization prediction. You can use only BERT-like + HuggingFace language models (model ``forward`` method accepts ``input_ids``, ``token_types_ids``, + ``attention_mask`` arguments). See more about model config options :ref:`here`. + And any :class:`~nemo.collections.asr.models.EncDecCTCModel` which has encoder module which is used as an AM. + + For training and testing use dataset + :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_dataset.BertPunctuationCapitalizationDataset` with parameter ``use_audio`` set to ``True``, + for training on huge amounts of data which cannot be loaded into memory simultaneously use + :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset.BertPunctuationCapitalizationTarredDataset` with parameter ``use_audio`` set to ``True``. + + Args: + cfg: a model configuration. It should follow dataclass + :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` + See an example of full config in + `nemo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml + `_ + trainer: an instance of a PyTorch Lightning trainer + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None: super().__init__(cfg, trainer) @@ -199,31 +199,31 @@ def forward( features_length: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Executes a forward pass through the model. For more details see ``forward`` method of :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` - and ``forward`` method of :class:'~nemo.collections.asr.models.EncDecCTCModel' - - Args: - input_ids (:obj:`torch.Tensor`): an integer torch tensor of shape ``[Batch, Time]``. Contains encoded - source tokens. - attention_mask (:obj:`torch.Tensor`): a boolean torch tensor of shape ``[Batch, Time]``. Contains an - attention mask for excluding paddings. - token_type_ids (:obj:`torch.Tensor`): an integer torch Tensor of shape ``[Batch, Time]``. Contains an index - of segment to which a token belongs. If ``token_type_ids`` is not ``None``, then it should be a zeros - tensor. - features (:obj:`torch.Tensor`): tensor that represents a batch of raw audio signals, - of shape [B, T]. T here represents timesteps, with 1 second of audio represented as - sample_rate number of floating point values. - features_length (:obj:`torch.Tensor`): Vector of length B, that contains the individual lengths of the audio - sequences. - - Returns: - :obj:`Tuple[torch.Tensor, torch.Tensor]`: a tuple containing - - - ``punct_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape - ``[Batch, Time, NumPunctuationLabels]`` containing punctuation logits - - ``capit_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape - ``[Batch, Time, NumCapitalizationLabels]`` containing capitalization logits - """ + Executes a forward pass through the model. For more details see ``forward`` method of :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` + and ``forward`` method of :class:'~nemo.collections.asr.models.EncDecCTCModel' + + Args: + input_ids (:obj:`torch.Tensor`): an integer torch tensor of shape ``[Batch, Time]``. Contains encoded + source tokens. + attention_mask (:obj:`torch.Tensor`): a boolean torch tensor of shape ``[Batch, Time]``. Contains an + attention mask for excluding paddings. + token_type_ids (:obj:`torch.Tensor`): an integer torch Tensor of shape ``[Batch, Time]``. Contains an index + of segment to which a token belongs. If ``token_type_ids`` is not ``None``, then it should be a zeros + tensor. + features (:obj:`torch.Tensor`): tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + sample_rate number of floating point values. + features_length (:obj:`torch.Tensor`): Vector of length B, that contains the individual lengths of the audio + sequences. + + Returns: + :obj:`Tuple[torch.Tensor, torch.Tensor]`: a tuple containing + + - ``punct_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape + ``[Batch, Time, NumPunctuationLabels]`` containing punctuation logits + - ``capit_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape + ``[Batch, Time, NumCapitalizationLabels]`` containing capitalization logits + """ self.update_max_seq_length(seq_length=features.size(1), device=features.device) lexical_hidden_states = self.bert_model( input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask @@ -232,7 +232,8 @@ def forward( lexical_hidden_states = lexical_hidden_states[0] processed_signal, processed_signal_length = self.audio_encoder.preprocessor( - input_signal=features, length=features_length, + input_signal=features, + length=features_length, ) if self.audio_encoder.spec_augmentation is not None and self.training: @@ -301,49 +302,49 @@ def add_punctuation_capitalization( target_sr: Optional[int] = None, ) -> List[str]: """ - Adds punctuation and capitalization to the queries. Use this method for inference. - - Parameters ``max_seq_length``, ``step``, ``margin`` are for controlling the way queries are split into segments - which are processed by the model. Parameter ``max_seq_length`` is a length of a segment after tokenization - including special tokens [CLS] in the beginning and [SEP] in the end of a segment. Parameter ``step`` is a - shift between consequent segments. Parameter ``margin`` is used to exclude negative effect of subtokens near - borders of segments which have only one side context. - - If segments overlap, probabilities of overlapping predictions are multiplied and then the label with - corresponding to the maximum probability is selected. - - Args: - queries (:obj:`List[str]`): lower cased text without punctuation. - batch_size (:obj:`List[str]`, `optional`): batch size to use during inference. If ``batch_size`` parameter - is not provided, then it will be equal to length of ``queries`` list. - max_seq_length (:obj:`int`, `optional`, defaults to :obj:`64`): maximum sequence length of a segment after - tokenization including :code:`[CLS]` and :code:`[SEP]` tokens. - step (:obj:`int`, `optional`, defaults to :obj:`8`): relative shift of consequent segments into which long - queries are split. Long queries are split into segments which can overlap. Parameter ``step`` controls - such overlapping. Imagine that queries are tokenized into characters, ``max_seq_length=5``, and - ``step=2``. In such case, query ``"hello"`` is tokenized into segments - ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. - margin (:obj:`int`, `optional`, defaults to :obj:`16`): number of subtokens in the beginning and the end of - segments which are not used for prediction computation. The first segment does not have left margin and - the last segment does not have right margin. For example, if an input sequence is tokenized into - characters, ``max_seq_length=5``, ``step=1``, and ``margin=1``, then query ``"hello"`` will be - tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'], - ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final predictions - computation, margins are removed. In the next list, subtokens which logits are not used for final - predictions computation are marked with asterisk: ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*], - ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``. - return_labels (:obj:`bool`, `optional`, defaults to :obj:`False`): whether to return labels in NeMo format - (see :ref:`nlp/punctuation_and_capitalization/NeMo Data Format`) instead of queries with restored - punctuation and capitalization. - dataloader_kwargs (:obj:`Dict[str, Any]`, `optional`): an optional dictionary with parameters of PyTorch - data loader. May include keys: ``'num_workers'``, ``'pin_memory'``, ``'worker_init_fn'``, - ``'prefetch_factor'``, ``'persistent_workers'``. - audio_queries (:obj:`List[str]`, `optional`): paths to audio files. - target_sr (:obj:`int`, `optional`): target sample rate for audios. - Returns: - :obj:`List[str]`: a list of queries with restored capitalization and punctuation if - ``return_labels=False``, else a list of punctuation and capitalization labels strings for all queries - """ + Adds punctuation and capitalization to the queries. Use this method for inference. + + Parameters ``max_seq_length``, ``step``, ``margin`` are for controlling the way queries are split into segments + which are processed by the model. Parameter ``max_seq_length`` is a length of a segment after tokenization + including special tokens [CLS] in the beginning and [SEP] in the end of a segment. Parameter ``step`` is a + shift between consequent segments. Parameter ``margin`` is used to exclude negative effect of subtokens near + borders of segments which have only one side context. + + If segments overlap, probabilities of overlapping predictions are multiplied and then the label with + corresponding to the maximum probability is selected. + + Args: + queries (:obj:`List[str]`): lower cased text without punctuation. + batch_size (:obj:`List[str]`, `optional`): batch size to use during inference. If ``batch_size`` parameter + is not provided, then it will be equal to length of ``queries`` list. + max_seq_length (:obj:`int`, `optional`, defaults to :obj:`64`): maximum sequence length of a segment after + tokenization including :code:`[CLS]` and :code:`[SEP]` tokens. + step (:obj:`int`, `optional`, defaults to :obj:`8`): relative shift of consequent segments into which long + queries are split. Long queries are split into segments which can overlap. Parameter ``step`` controls + such overlapping. Imagine that queries are tokenized into characters, ``max_seq_length=5``, and + ``step=2``. In such case, query ``"hello"`` is tokenized into segments + ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. + margin (:obj:`int`, `optional`, defaults to :obj:`16`): number of subtokens in the beginning and the end of + segments which are not used for prediction computation. The first segment does not have left margin and + the last segment does not have right margin. For example, if an input sequence is tokenized into + characters, ``max_seq_length=5``, ``step=1``, and ``margin=1``, then query ``"hello"`` will be + tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'], + ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final predictions + computation, margins are removed. In the next list, subtokens which logits are not used for final + predictions computation are marked with asterisk: ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*], + ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``. + return_labels (:obj:`bool`, `optional`, defaults to :obj:`False`): whether to return labels in NeMo format + (see :ref:`nlp/punctuation_and_capitalization/NeMo Data Format`) instead of queries with restored + punctuation and capitalization. + dataloader_kwargs (:obj:`Dict[str, Any]`, `optional`): an optional dictionary with parameters of PyTorch + data loader. May include keys: ``'num_workers'``, ``'pin_memory'``, ``'worker_init_fn'``, + ``'prefetch_factor'``, ``'persistent_workers'``. + audio_queries (:obj:`List[str]`, `optional`): paths to audio files. + target_sr (:obj:`int`, `optional`): target sample rate for audios. + Returns: + :obj:`List[str]`: a list of queries with restored capitalization and punctuation if + ``return_labels=False``, else a list of punctuation and capitalization labels strings for all queries + """ if len(queries) == 0: return [] @@ -408,7 +409,9 @@ def add_punctuation_capitalization( acc_probs[q_i] = b_probs_i else: all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds( - all_preds[q_i], acc_probs[q_i], start_word_id - len(all_preds[q_i]), + all_preds[q_i], + acc_probs[q_i], + start_word_id - len(all_preds[q_i]), ) acc_probs[q_i] = self._update_accumulated_probabilities(acc_probs[q_i], b_probs_i) for all_preds, acc_probs in [(all_punct_preds, acc_punct_probs), (all_capit_preds, acc_capit_probs)]: diff --git a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py index 6e2d1f5762ec..8cf153dfdf76 100644 --- a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py +++ b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py @@ -20,8 +20,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from tqdm import tqdm from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss @@ -812,7 +812,13 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u raise ValueError( f"If `use_tarred_dataset` is `False`, then you need to provide `tokens_in_batch` parameter." ) - text_file, labels_file, = Path(cfg.ds_item) / cfg.text_file, Path(cfg.ds_item) / cfg.labels_file + ( + text_file, + labels_file, + ) = ( + Path(cfg.ds_item) / cfg.text_file, + Path(cfg.ds_item) / cfg.labels_file, + ) if cfg.audio_file: audio_file = Path(cfg.ds_item) / cfg.audio_file if self.label_ids_are_set: @@ -1010,7 +1016,8 @@ def _transform_logit_to_prob_and_remove_margins_and_extract_word_probs( stm = self._remove_margins(stm, margin, keep_left=first, keep_right=last) for b_probs, logits in [(b_punct_probs, pl), (b_capit_probs, cl)]: p = torch.nn.functional.softmax( - self._remove_margins(logits, margin, keep_left=first, keep_right=last)[stm], dim=-1, + self._remove_margins(logits, margin, keep_left=first, keep_right=last)[stm], + dim=-1, ) b_probs.append(p.detach().cpu().numpy()) return b_punct_probs, b_capit_probs, new_start_word_ids @@ -1191,7 +1198,9 @@ def add_punctuation_capitalization( ): inp_ids, inp_type_ids, inp_mask, subtokens_mask, start_word_ids, query_ids, is_first, is_last = batch punct_logits, capit_logits = self.forward( - input_ids=inp_ids.to(d), token_type_ids=inp_type_ids.to(d), attention_mask=inp_mask.to(d), + input_ids=inp_ids.to(d), + token_type_ids=inp_type_ids.to(d), + attention_mask=inp_mask.to(d), ) _res = self._transform_logit_to_prob_and_remove_margins_and_extract_word_probs( punct_logits, capit_logits, subtokens_mask, start_word_ids, margin, is_first, is_last @@ -1208,7 +1217,9 @@ def add_punctuation_capitalization( acc_probs[q_i] = b_probs_i else: all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds( - all_preds[q_i], acc_probs[q_i], start_word_id - len(all_preds[q_i]), + all_preds[q_i], + acc_probs[q_i], + start_word_id - len(all_preds[q_i]), ) acc_probs[q_i] = self._update_accumulated_probabilities(acc_probs[q_i], b_probs_i) for all_preds, acc_probs in [(all_punct_preds, acc_punct_probs), (all_capit_preds, acc_capit_probs)]: diff --git a/nemo/collections/nlp/models/token_classification/token_classification_model.py b/nemo/collections/nlp/models/token_classification/token_classification_model.py index 0b465bae663c..99bb2328b956 100644 --- a/nemo/collections/nlp/models/token_classification/token_classification_model.py +++ b/nemo/collections/nlp/models/token_classification/token_classification_model.py @@ -16,8 +16,8 @@ from typing import List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import CrossEntropyLoss diff --git a/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py b/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py index e65f3d7749eb..07e0826c712c 100644 --- a/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py +++ b/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py @@ -18,8 +18,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.data.zero_shot_intent_recognition.zero_shot_intent_dataset import ( ZeroShotIntentDataset, @@ -155,7 +155,6 @@ def predict( entailment_idx=1, contradiction_idx=0, ) -> List[Dict]: - """ Given a list of queries and a list of candidate labels, return a ranked list of labels and scores for each query. diff --git a/nemo/collections/nlp/modules/common/lm_utils.py b/nemo/collections/nlp/modules/common/lm_utils.py index af6fc9ecb0a7..86792059b28f 100644 --- a/nemo/collections/nlp/modules/common/lm_utils.py +++ b/nemo/collections/nlp/modules/common/lm_utils.py @@ -17,8 +17,8 @@ from typing import List, Optional, Union from attr import asdict +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.modules.common.bert_module import BertModule from nemo.collections.nlp.modules.common.decoder_module import DecoderModule diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index da9c98fd94ea..e306a0a9b6b7 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -82,19 +82,20 @@ def forward( rotary_pos_emb: Tensor = None, rotary_pos_cos: Tensor = None, rotary_pos_sin: Tensor = None, + attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, ): hidden_states = super().forward( - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - inference_params, - packed_seq_params, + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + inference_params=inference_params, + packed_seq_params=packed_seq_params, ) mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER) @@ -232,6 +233,7 @@ def forward( packed_seq_params=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, ): # hidden_states: [sq, b, h] diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index c1b4e3023e42..d5784081f6f0 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -380,6 +380,7 @@ def forward( rotary_pos_emb=None, # rotary positional embedding relative_position_bias=None, checkpoint_core_attention=False, + return_scores=False, ): # hidden_states: [sq, b, h] @@ -398,7 +399,9 @@ def forward( # Some consistency check. if inference_max_sequence_len: - assert self.inference_current_sequence_len < self.inference_key_memory.size(0) + # Added equals to as inference key_memory size refers to cross-attention key size + # which is already equal to the current "sequence length" + assert self.inference_current_sequence_len <= self.inference_key_memory.size(0) assert inference_max_sequence_len == self.inference_key_memory.size(0) # This is added for safety. In case inference_max_sequence_len # is not provided, make sure there is no potential memory left @@ -433,28 +436,40 @@ def forward( (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( mixed_x_layer, 3, contiguous_split_chunks=True ) - else: + else: # Else in cross_attention # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv_layer, _ = self.key_value(encoder_output) - if self.is_adapter_available(): - lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER) - if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']: - lora_mixed_kv_layer = lora_kv_adapter(encoder_output) - mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - if self.megatron_legacy: - mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True) - mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + if ( + inference_max_sequence_len is None + ) or self.inference_current_sequence_len < inference_max_sequence_len: + # If we are in traning and inference_max_sequence_len is None + # Or we haven't cached the key and value part of cross attention in the decoder on step 0, + # Do the caching + mixed_kv_layer, _ = self.key_value(encoder_output) + if self.is_adapter_available(): + lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER) + if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']: + lora_mixed_kv_layer = lora_kv_adapter(encoder_output) + mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + if self.megatron_legacy: + mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( - mixed_kv_layer, 2, contiguous_split_chunks=True - ) + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( + mixed_kv_layer, 2, contiguous_split_chunks=True + ) + else: + # else if we are in inference and have already cached key, value, can just read cache + key_layer = self.inference_key_memory[: self.inference_current_sequence_len, ...] + value_layer = self.inference_value_memory[: self.inference_current_sequence_len, ...] + if attention_mask is not None: + attention_mask = attention_mask[..., -1, :].unsqueeze(-2) # Attention head [sq, b, h] --> [sq, b, hp] query_layer, _ = self.query(hidden_states) @@ -490,7 +505,9 @@ def forward( if rotary_pos_emb is not None: rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, tuple) else ((rotary_pos_emb,) * 2) - if inference_max_sequence_len: + # If we are in cross attention (inference_current_sequence_len == inference_max_sequence_len == inference_key_memory.size(0)) + # We only need to cache this once + if inference_max_sequence_len and self.inference_current_sequence_len < inference_max_sequence_len: # Adjust the range variables. start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) @@ -501,7 +518,7 @@ def forward( key_layer = self.inference_key_memory[:end, ...] value_layer = self.inference_value_memory[:end, ...] # Adjust attention mask - if attention_mask is not None: + if attention_mask is not None and self.attention_type == AttnType.self_attn: attention_mask = attention_mask[..., start:end, :end] # adjust the key rotary positional embedding if rotary_pos_emb is not None: @@ -569,7 +586,10 @@ def forward( relative_position_bias=relative_position_bias, headscale_tensor=self.head_scale_tensor if self.headscale else None, inference_mode=inference_max_sequence_len is not None and query_layer.shape[0] == 1, + return_scores=return_scores, ) + if return_scores: + context_layer, attention_probs = context_layer # ================= # Output. [sq, b, h] @@ -585,6 +605,9 @@ def forward( if get_key_value: output = [output, present] + if return_scores: + output = [output, attention_probs] + return output, bias @@ -857,6 +880,7 @@ def forward( relative_position_bias=None, headscale_tensor=None, inference_mode=None, + return_scores=None, ): b, np, sq, sk, hn = ( query_layer.size(1), @@ -914,9 +938,27 @@ def forward( # relative_position_bias [b, np, sq, sk] # context_layer [b, np, sq, hn] # ================================================== - context_layer = self.attn_fn( - query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode - ) + if not return_scores: + context_layer = self.attn_fn( + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, + ) + else: + # SpeechLLM TTS modifications + context_layer = self.torch_attention_with_prior( + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, + return_scores=return_scores, + ) + context_layer, attention_probs = context_layer if headscale_tensor is not None: context_layer = context_layer * headscale_tensor @@ -928,7 +970,10 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) - return context_layer + if return_scores: + return context_layer, attention_probs + else: + return context_layer def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode): sq, b, np, hn = query_layer.shape @@ -986,6 +1031,69 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, a return context_layer + def torch_attention_with_prior( + self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode, return_scores=False + ): + sq, b, np, hn = query_layer.shape + sk = key_layer.shape[0] + + if self.multi_query_attention: + query_layer = rearrange(query_layer, 'sq b np hn -> b (np sq) hn') + key_layer = rearrange(key_layer, 'sk b 1 hn -> b hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + else: + query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn') + key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + + matmul_input_buffer = torch.empty( + query_layer.shape[0], + query_layer.shape[1], + key_layer.shape[2], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, + key_layer, + beta=0.0, + alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(b, np, sq, sk) + + if attention_bias is not None: + # attention_bias is not None only for cross attention layers right now in T5 + attention_scores = torch.log_softmax(attention_scores, dim=-1) + attention_bias + + _attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with tensor_parallel.random.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(_attention_probs) + else: + attention_probs = self.attention_dropout(_attention_probs) + + # change view [b * np, sq, sk] + attention_probs = rearrange(attention_probs, 'b np sq sk -> (b np) sq sk') + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = rearrange(context_layer, '(b np) sq hn -> b np sq hn', np=np) + + if return_scores: + # return context_layer, _attention_probs + return context_layer, attention_scores + else: + return context_layer + def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode): query_layer = rearrange(query_layer, 'sq b np hn -> b sq np hn') key_layer = rearrange(key_layer, 'sk b np hn -> b sk np hn') diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py index 712ce10b81b5..d2945a061584 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py @@ -13,7 +13,7 @@ # limitations under the License. """Transformer based language model.""" -from ast import Mod +from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_transformer_decoder import MegatronTransformerDecoderModule from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( MegatronRetrievalTransformerDecoderModule, @@ -87,7 +87,7 @@ def get_decoder_model( transformer_block_type="pre_ln", hidden_steps=-1, parent_model_type=ModelType.encoder_or_decoder, - layer_type=None, + layer_type=LayerType.decoder, chunk_size=64, layer_number_offset=0, # this is use only for attention norm_factor scaling megatron_legacy=False, @@ -158,6 +158,7 @@ def get_decoder_model( moe_dropout=moe_dropout, position_embedding_type=position_embedding_type, use_flash_attention=use_flash_attention, + layer_type=layer_type, ) elif arch == "retro": decoder = MegatronRetrievalTransformerDecoderModule( diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index c4192dacb45a..744a6e18c8b1 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -13,7 +13,6 @@ # limitations under the License. """Transformer based language model.""" -from ast import Mod import torch @@ -46,8 +45,7 @@ class MegatronTransformerEncoderDecoderModule(MegatronModule): - """Transformer encoder-decoder model. - """ + """Transformer encoder-decoder model.""" def __init__( self, @@ -85,6 +83,8 @@ def __init__( encoder_attn_mask_type = AttnMaskType.padding elif hasattr(encoder.model, 'self_attn_mask_type'): encoder_attn_mask_type = encoder.model.self_attn_mask_type + elif isinstance(encoder.model, torch.nn.ModuleList) and hasattr(encoder.model[0], 'self_attn_mask_type'): + encoder_attn_mask_type = encoder.model[0].self_attn_mask_type else: raise AttributeError( "Could not find an attribute for encoder self_attn_mask_type, make sure it is set when instatiating the encoder or pass it to the constructor of this class." @@ -142,7 +142,11 @@ def encode( # apply hidden transformations if needed if self.hiddens_module is not None: enc_output = self.hiddens_module.apply_hidden_transforms( - {"hiddens": enc_output, "hiddens_mask": self.get_hiddens_mask(enc_attn_mask),}, batch_data=batch_data, + { + "hiddens": enc_output, + "hiddens_mask": self.get_hiddens_mask(enc_attn_mask), + }, + batch_data=batch_data, ) return enc_output @@ -157,6 +161,11 @@ def decode( dec_get_key_value=False, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): if self.decoder is None: raise ValueError(f"Cannot call .decode(...) when self.decoder is None.") @@ -170,6 +179,11 @@ def decode( enc_attn_mask=enc_attn_mask, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers, ) return dec_output @@ -191,6 +205,11 @@ def forward( dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, batch_data=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): # encoder if enc_output is None: @@ -207,7 +226,10 @@ def forward( assert self.encoder_hidden_state is not None enc_output = self.encoder_hidden_state else: - enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask) + if isinstance(enc_output_attn_mask, list): + enc_attn_mask = [mask.to(enc_attn_mask[midx]) for midx, mask in enumerate(enc_output_attn_mask)] + else: + enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask) if self.decoder is None or output_enc_hidden_only: return enc_output @@ -216,15 +238,22 @@ def forward( dec_output = self.decode( dec_input=dec_input, dec_attn_mask=dec_attn_mask, - enc_output=enc_output["enc_output"] # enc_output is a dict if we used hidden transformations - if self.hiddens_module is not None - else enc_output, + enc_output=( + enc_output["enc_output"] # enc_output is a dict if we used hidden transformations + if self.hiddens_module is not None + else enc_output + ), # Adjust encoder attention mask if encoder is a perceiver. enc_attn_mask=self.get_hiddens_mask(enc_attn_mask), dec_layer_past=dec_layer_past, dec_get_key_value=dec_get_key_value, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers, ) # if self.hiddens_module is not None enc_output is a dict, else it is a torch.tensor @@ -246,7 +275,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= def load_state_dict(self, state_dict, strict=True): """Customized load.""" - self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) - self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) - if self.hiddens_module is not None: - self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) + try: + self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) + self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) + if self.hiddens_module is not None: + self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) + except KeyError as e: + super().load_state_dict(state_dict, strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py index 601eb320e8fc..3d2b2c1ecc13 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py @@ -14,7 +14,10 @@ """Transformer based language model.""" from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule -from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import MegatronTransformerEncoderModule +from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import ( + MegatronTransformerEncoderModule, + MultiMegatronTransformerEncoderModule, +) from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( MegatronRetrievalTransformerEncoderModule, ) @@ -108,6 +111,7 @@ def get_encoder_model( version=1, # model version position_embedding_type='learned_absolute', use_flash_attention=False, + n_transformers=1, ): """Build language model and return along with the key to save.""" @@ -167,6 +171,51 @@ def get_encoder_model( position_embedding_type=position_embedding_type, use_flash_attention=use_flash_attention, ) + elif arch == "multi_transformer": + encoder = MultiMegatronTransformerEncoderModule( + config=config, + n_transformers=n_transformers, + init_method=init_method, + output_layer_init_method=scaled_init_method, + hidden_size=hidden_size, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + encoder_attn_mask_type=encoder_attn_mask_type, + pre_process=pre_process, + post_process=post_process, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + ffn_dropout=ffn_dropout, + precision=precision, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + activations_checkpoint_granularity=activations_checkpoint_granularity, + layernorm_epsilon=layernorm_epsilon, + bias_activation_fusion=bias_activation_fusion, + bias_dropout_add_fusion=bias_dropout_add_fusion, + masked_softmax_fusion=masked_softmax_fusion, + persist_layer_norm=persist_layer_norm, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + activation=activation, + bias=bias, + normalization=normalization, + transformer_block_type=transformer_block_type, + headscale=headscale, + parent_model_type=parent_model_type, + megatron_legacy=megatron_legacy, + normalize_attention_scores=normalize_attention_scores, + num_moe_experts=num_moe_experts, + moe_frequency=moe_frequency, + moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, + ) + elif arch == "retro": encoder = MegatronRetrievalTransformerEncoderModule( config=config, diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py index 4a05a08820e7..14677552492b 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py @@ -52,8 +52,7 @@ class MegatronTransformerDecoderModule(MegatronModule, Exportable, MegatronDecoderModule): - """Transformer decoder model. - """ + """Transformer decoder model.""" def __init__( self, @@ -97,6 +96,7 @@ def __init__( moe_dropout=0.0, position_embedding_type='learned_absolute', use_flash_attention=False, + layer_type=LayerType.decoder, ): super(MegatronTransformerDecoderModule, self).__init__(config=config) @@ -121,7 +121,7 @@ def __init__( # Transformer. self.model = ParallelTransformer( config=config, - layer_type=LayerType.decoder, + layer_type=layer_type, init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, num_layers=self.num_layers, @@ -165,7 +165,7 @@ def __init__( self._model_key = 'model' def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" self.model.set_input_tensor(input_tensor) def forward( @@ -178,15 +178,41 @@ def forward( get_key_value=False, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): # convert to Megatron mask dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=dec_attn_mask, attn_mask_type=self.model_attn_mask_type, - ) - enc_dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=enc_attn_mask, attn_mask_type=AttnMaskType.padding, + source_mask=dec_attn_mask, + target_mask=dec_attn_mask, + attn_mask_type=self.model_attn_mask_type, ) + if isinstance(enc_output, list): + assert len(enc_output) == len(enc_attn_mask) + enc_dec_attn_mask_3d = [] + for i in range(len(enc_output)): + enc_dec_attn_mask_3d.append( + attn_mask_postprocess( + build_attention_mask_3d( + source_mask=dec_attn_mask, + target_mask=enc_attn_mask[i], + attn_mask_type=AttnMaskType.padding, + ) + ) + ) + else: + enc_dec_attn_mask_3d = attn_mask_postprocess( + build_attention_mask_3d( + source_mask=dec_attn_mask, + target_mask=enc_attn_mask, + attn_mask_type=AttnMaskType.padding, + ) + ) + # transformer decoder dec_output = self.model( dec_input, @@ -194,9 +220,14 @@ def forward( layer_past=layer_past, get_key_value=get_key_value, encoder_output=enc_output, - enc_dec_attn_mask=attn_mask_postprocess(enc_dec_attn_mask_3d), + enc_dec_attn_mask=enc_dec_attn_mask_3d, self_attention_relative_position_bias=dec_self_attention_relative_position_bias, cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers, ) return dec_output diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py index 7a41e1300066..a9b80868558f 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py @@ -13,6 +13,8 @@ # limitations under the License. """Transformer based language model.""" +import torch + from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_encoder_module import MegatronEncoderModule from nemo.collections.nlp.modules.common.megatron.module import MegatronModule @@ -163,7 +165,7 @@ def __init__( self._model_key = 'model' def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" self.model.set_input_tensor(input_tensor) def forward( @@ -173,6 +175,7 @@ def forward( layer_past=None, get_key_value=False, enc_self_attention_relative_position_bias=None, + set_inference_key_value_memory=False, ): # convert to Megatron mask if self.use_flash_attention: @@ -180,7 +183,9 @@ def forward( else: enc_attn_mask_3d = attn_mask_postprocess( build_attention_mask_3d( - source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type, + source_mask=enc_attn_mask, + target_mask=enc_attn_mask, + attn_mask_type=self.model_attn_mask_type, ) ) @@ -192,6 +197,7 @@ def forward( get_key_value=get_key_value, self_attention_relative_position_bias=enc_self_attention_relative_position_bias, cross_attention_relative_position_bias=None, + set_inference_key_value_memory=set_inference_key_value_memory, ) return enc_output @@ -231,3 +237,214 @@ def load_state_dict(self, state_dict, strict=True): state_dict_ = state_dict_self_attention self.model.load_state_dict(state_dict_, strict=strict) + + +class MultiMegatronTransformerEncoderModule(MegatronModule, Exportable, MegatronEncoderModule): + """Transformer encoder model.""" + + def __init__( + self, + config: ModelParallelConfig, + n_transformers, + init_method, + output_layer_init_method, + hidden_size, + ffn_hidden_size, + num_layers, + num_attention_heads, + apply_query_key_layer_scaling=True, + kv_channels=None, + pre_process=True, + post_process=True, + encoder_attn_mask_type=AttnMaskType.padding, + hidden_dropout=0.1, + attention_dropout=0.1, + ffn_dropout=0.0, + precision=16, + fp32_residual_connection=False, + activations_checkpoint_method=None, + activations_checkpoint_num_layers=1, + activations_checkpoint_granularity=None, + layernorm_epsilon=1e-5, + bias_activation_fusion=True, + bias_dropout_add_fusion=True, + masked_softmax_fusion=True, + persist_layer_norm=False, + openai_gelu=False, + onnx_safe=False, + activation='gelu', + bias=True, + normalization='layernorm', + transformer_block_type='pre_ln', + headscale=False, + parent_model_type=ModelType.encoder_or_decoder, + megatron_legacy=False, + normalize_attention_scores=True, + num_moe_experts=1, + moe_frequency=1, + moe_dropout=0.0, + position_embedding_type='learned_absolute', + use_flash_attention=False, + ): + super(MultiMegatronTransformerEncoderModule, self).__init__(config=config) + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = hidden_size + self.num_layers = num_layers + self.init_method = init_method + self.model_attn_mask_type = encoder_attn_mask_type + self.hidden_dropout = hidden_dropout + self.output_layer_init_method = output_layer_init_method + self.parent_model_type = parent_model_type + self.normalization = normalization + self.transformer_block_type = transformer_block_type + self.use_flash_attention = use_flash_attention + + if kv_channels is None: + + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + # Transformer List + self.model = [] + for i in range(n_transformers): + transformer = ParallelTransformer( + config=config, + layer_type=LayerType.encoder, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + num_layers=self.num_layers, + hidden_size=self.hidden_size, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + self_attn_mask_type=self.model_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + precision=precision, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + activations_checkpoint_granularity=activations_checkpoint_granularity, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + ffn_dropout=ffn_dropout, + bias_activation_fusion=bias_activation_fusion, + bias_dropout_add_fusion=bias_dropout_add_fusion, + masked_softmax_fusion=masked_softmax_fusion, + persist_layer_norm=persist_layer_norm, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + activation=activation, + bias=bias, + normalization=normalization, + transformer_block_type=transformer_block_type, + headscale=headscale, + model_type=parent_model_type, + megatron_legacy=megatron_legacy, + normalize_attention_scores=normalize_attention_scores, + num_moe_experts=num_moe_experts, + moe_frequency=moe_frequency, + moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, + ) + self.model.append(transformer) + + self.model = torch.nn.ModuleList(self.model) + + self._model_key = 'model' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + for mi in range(len(self.model)): + self.model[mi].set_input_tensor(input_tensor) + + def forward( + self, + enc_input, + enc_attn_mask, + layer_past=None, + get_key_value=False, + enc_self_attention_relative_position_bias=None, + set_inference_key_value_memory=False, + ): + + assert isinstance(enc_input, list) + assert len(enc_input) == len(self.model) + assert isinstance(enc_attn_mask, list) + assert len(enc_attn_mask) == len(self.model) + assert isinstance(enc_self_attention_relative_position_bias, list) + # convert to Megatron mask + enc_outputs = [] + for encoder_number in range(len(self.model)): + enc_input_ = enc_input[encoder_number] + enc_attn_mask_ = enc_attn_mask[encoder_number] + enc_self_attention_relative_position_bias_ = enc_self_attention_relative_position_bias[encoder_number] + + if self.use_flash_attention: + enc_attn_mask_3d = enc_attn_mask_ < 0.5 + else: + enc_attn_mask_3d = attn_mask_postprocess( + build_attention_mask_3d( + source_mask=enc_attn_mask_, + target_mask=enc_attn_mask_, + attn_mask_type=self.model_attn_mask_type, + ) + ) + + # transformer encoder + enc_output = self.model[encoder_number]( + enc_input_, + enc_attn_mask_3d, + layer_past=layer_past, + get_key_value=get_key_value, + self_attention_relative_position_bias=enc_self_attention_relative_position_bias_, + cross_attention_relative_position_bias=None, + set_inference_key_value_memory=set_inference_key_value_memory, + ) + + enc_outputs.append(enc_output) + + return enc_outputs + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + + state_dict_[self._model_key] = self.model.state_dict_for_save_checkpoint(destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Encoder. + if self._model_key in state_dict: + state_dict_ = state_dict[self._model_key] + # for backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # for backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.model.load_state_dict(state_dict_, strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index ccd485427c3c..a4efb2992166 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -113,7 +113,7 @@ def decoder_cross_attention_relative_position_embeddings_weight(self): def initialize_word_embeddings(self, init_method, vocab_size, hidden_size): if not self.share_token_embeddings: - raise Exception('initialize_word_embeddings() was called but ' 'share_token_embeddings is false') + raise Exception('initialize_word_embeddings() was called but share_token_embeddings is false') # This function just initializes the word embeddings in the final stage # when we are using pipeline parallelism. If we aren't using pipeline @@ -140,7 +140,10 @@ def initialize_word_embeddings(self, init_method, vocab_size, hidden_size): # set word_embeddings weights to 0 here, then copy first # stage's weights using all_reduce below. self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - vocab_size, hidden_size, init_method=init_method, config=self.config, + vocab_size, + hidden_size, + init_method=init_method, + config=self.config, ) self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index b7b377940eb4..e68113949aa7 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -42,6 +42,7 @@ ) from nemo.collections.nlp.modules.common.megatron.vocab_parallel_cross_entropy import vocab_parallel_cross_entropy from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import logging try: from apex.transformer.enums import AttnMaskType, ModelType @@ -67,7 +68,11 @@ HAVE_MEGATRON_CORE = False -__all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule"] +__all__ = [ + "MegatronTokenLevelHead", + "MegatronTokenLevelEncoderDecoderModule", + "MegatronTokenLevelEncoderDecoderSpeechLLMModule", +] class MegatronTokenLevelHead(MegatronModule): @@ -252,6 +257,7 @@ def __init__( moe_dropout=encoder_cfg.get('moe_dropout', 0.0), position_embedding_type=encoder_cfg.get('position_embedding_type', 'learned_absolute'), use_flash_attention=encoder_cfg.get('use_flash_attention', False), + n_transformers=encoder_cfg.get('n_transformers', 1), ) if add_decoder: @@ -388,6 +394,7 @@ def __init__( moe_dropout=decoder_cfg.get('moe_dropout', 0.0), position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute'), use_flash_attention=decoder_cfg.get('use_flash_attention', False), + layer_type=decoder_cfg.get('layer_type', LayerType.decoder), ) hiddens_module = get_hiddens_module(hiddens_cfg, model_parallel_cfg=config) @@ -410,6 +417,7 @@ def __init__( if add_decoder and post_process: if share_decoder_tokens_head_embeddings: + # parallel_output is True if TP > 1 (3b model) self.tokens_head = MegatronTokenLevelHead( self.word_embeddings_weight().size(0), parallel_output, bias=tokens_head_bias ) @@ -469,7 +477,7 @@ def _validate_config(self): return encoder_kv_channels, decoder_kv_channels def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" # This is usually handled in schedules.py but some inference code still # gives us non-lists or None @@ -566,7 +574,8 @@ def forward( if self.add_encoder and self.encoder_relative_position_embedding is not None: encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( - query_seq_length=enc_seq_length, key_seq_length=enc_seq_length, + query_seq_length=enc_seq_length, + key_seq_length=enc_seq_length, ) if output_enc_hidden_only: @@ -604,8 +613,11 @@ def forward( query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1) ) if not self.decoder_cfg.relative_position_bias_self_attention_only: - decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding( - query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length, + decoder_cross_attention_relative_position_bias = ( + self.decoder_cross_attention_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), + key_seq_length=enc_seq_length, + ) ) else: decoder_cross_attention_relative_position_bias = None @@ -656,7 +668,8 @@ def forward( # check if hiddens is used if self.hiddens_cfg is not None: loss_dict = self.enc_dec_model.hiddens_module.apply_loss_transforms( - outputs=enc_output, batch_data=batch_data, + outputs=enc_output, + batch_data=batch_data, ) loss_dict["tokens_loss"] = tokens_loss # We need to store default output in a known key, so that we can mimic default behaviour @@ -708,8 +721,437 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= def load_state_dict(self, state_dict, strict=True): """Customized load.""" - - self.encoder_embedding.encoder_embeddingload_state_dict(state_dict[self._encoder_embedding_key], strict=strict) + self.encoder_embedding.load_state_dict(state_dict[self._encoder_embedding_key], strict=strict) self.decoder_embedding.load_state_dict(state_dict[self._decoder_embedding_key], strict=strict) self.enc_dec_model.load_state_dict(state_dict[self._enc_dec_model_key], strict=strict) self.tokens_head.load_state_dict(state_dict[self._tokens_head_key], strict=strict) + + +class MegatronTokenLevelEncoderDecoderSpeechLLMModule(MegatronTokenLevelEncoderDecoderModule): + def __init__(self, *args, **kwargs): + super(MegatronTokenLevelEncoderDecoderSpeechLLMModule, self).__init__(*args, **kwargs) + # Overridden in MegatronT5SpeechLMModel constructor + self.seq_pattern = "parallel" + self.speech_head_type = "token_level" + self.attn_prior_scaledown_start_step = 10000 + self.attn_prior_end_step = 11000 + self.use_alignment_loss = False + self.return_all_crossattention_probs = False + self.logging_step = False + self.num_cross_attention_heads = 12 # 12 for 220m T5, 16 for 11b T5 + self.enc_output_to_layers = None + + def get_decoder_embeddings(self, dec_input_ids, dec_position_ids, token_type_ids): + if dec_input_ids.dim() <= 2: + dec_input = self.decoder_embedding(dec_input_ids, dec_position_ids, token_type_ids=token_type_ids) + else: + dec_input = None + for i in range(dec_input_ids.size()[1]): + if i == 0: + # For the first channel (text + first layer of speech), use the decoder embedding layer + dec_input = self.decoder_embedding( + dec_input_ids[:, i, :], dec_position_ids, token_type_ids=token_type_ids + ) + else: + # For the rest of the channels (speech), use the speech embedding layer. No need for position, since already added in first layer. + current = self.speech_tokens_embeddings[i - 1](dec_input_ids[:, i, :]).permute(1, 0, 2) + # @pneekhara - Commenting the below because we always want to include all channels for speech. + # @pneekhara - include_channel_flag can become 0 when doing autoregressive inference and the first timestep is zeros + # For text inputs, only include 1st channel embeddings. Zero-out others. + # include_channel_flag = (torch.sum(dec_input_ids[:, i, :], dim=1) > 0).float() # [B] + # current = current * include_channel_flag.unsqueeze(0).unsqueeze(2) + dec_input = dec_input + current + + return dec_input + + def forward( + self, + enc_input_ids=None, + enc_attn_mask=None, + dec_input_ids=None, + dec_attn_mask=None, + token_type_ids=None, + labels=None, + batch_data=None, # additional data to be passed to hiddens module + enc_output=None, # Result of running the entire encoder + enc_output_attn_mask=None, + enc_input=None, # Result of running encoder embedding only + output_enc_hidden_only=False, + speech_mask=None, + cross_attention_prior=None, + text_limits=None, + global_step=None, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + ): + """ + Return value is per token / per dimension (i.e., non collapsed loss value) + """ + ( + encoder_self_attention_relative_position_bias, + decoder_self_attention_relative_position_bias, + decoder_cross_attention_relative_position_bias, + ) = (None, None, None) + + if enc_input is not None and enc_output is not None: + raise ValueError( + """Both enc_input and enc_output are not None. + You should only be passing one of them. + enc_input is the result of the encoder embedding layer + enc_output is the result of running the entire transformer encoder.""" + ) + + # In order of precedence, we use enc_output, enc_input, and then enc_input_ids to determine the encoder sequence length. + if enc_output is not None: + # If enc_output is provided in `batch_for_pipeline`, we need to transpose it from [B x S x H] -> [S x B x H]. + if isinstance(enc_output, list): + encoder_self_attention_relative_position_bias = [None for _ in enc_output] + enc_output = [x.transpose(0, 1) for x in enc_output] + enc_seq_length = [x.size(0) for x in enc_output] + else: + enc_output = enc_output.transpose(0, 1) + enc_seq_length = enc_output.size(0) + elif enc_input is not None: + # If enc_input is provided, we need to transpose it from [B x S x H] -> [S x B x H]. + if isinstance(enc_input, list): + encoder_self_attention_relative_position_bias = [None for _ in enc_input] + enc_input = [x.transpose(0, 1) for x in enc_input] + enc_seq_length = [x.size(0) for x in enc_input] + else: + enc_input = enc_input.transpose(0, 1) + enc_seq_length = enc_input.size(0) + # Only need to run encoder embedding and position ids if enc_input or enc_output is not provided. + elif enc_input_ids is not None: + enc_seq_length = enc_input_ids.size(1) + if self.pre_process and self.add_encoder: + # We don't need position ids for RPE, because the embedding layer does not have position embeddings. + if self.encoder_relative_position_embedding is None: + enc_input_ids_p = enc_input_ids[:, 0, :] if enc_input_ids.dim() == 3 else enc_input_ids + enc_position_ids = build_position_ids(enc_input_ids_p) + else: + enc_position_ids = None + enc_input = self.encoder_embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids) + if self.is_adapter_available(): + _sq, _bs, _hs = enc_input.size() + ptuning_adapter = self.get_adapter_module(AdapterName.PTUNING_ADAPTER) + v = ptuning_adapter.virtual_tokens + if ( + ptuning_adapter and _sq >= v + ): # The sequence should be longer the v to insert virtual embeddings. + virtual_embeddings = ptuning_adapter(_bs) + enc_input = enc_input[ + v:, :, : + ] # the first v tokens are pads so that they can be swapped out with virtual embeddings. + enc_input = torch.concat([virtual_embeddings, enc_input], dim=0) + else: + enc_input = None + else: + # This should only happen with PP > 1 for enc-dec prompt learning models + enc_seq_length = enc_attn_mask.size(1) + + if self.add_encoder and self.encoder_relative_position_embedding is not None: + encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( + query_seq_length=enc_seq_length, + key_seq_length=enc_seq_length, + ) + + if output_enc_hidden_only: + # When pipeline parallel > 1 we need to make sure encoder exist (will be missing in decoder) + # SpeechT5 should not go here for inference + if enc_output is None and self.enc_dec_model.encoder is not None: + enc_output = self.enc_dec_model.encode( + enc_input=enc_input, + enc_attn_mask=enc_attn_mask, + enc_layer_past=None, + enc_get_key_value=False, + enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + batch_data=batch_data, + ) + else: + enc_output = self.enc_dec_model.encoder_hidden_state + + return enc_output + else: + if enc_output_attn_mask is None: + enc_output_attn_mask = enc_attn_mask + + if self.pre_process and self.add_decoder: + # We don't need position ids for RPE, because the embedding layer does not have position embeddings. + if self.decoder_relative_position_embedding is None: + dec_input_ids_p = dec_input_ids[:, 0, :] if dec_input_ids.dim() == 3 else dec_input_ids + dec_position_ids = build_position_ids(dec_input_ids_p) + else: + dec_position_ids = None + dec_input = self.get_decoder_embeddings(dec_input_ids, dec_position_ids, token_type_ids) + if not set_inference_key_value_memory and (decoder_max_sequence_len or encoder_max_sequence_len): + # In inference + # On step 0 when set_inference_key_value_memory is True, we need all inputs in case + # we are using decoder context + # Else on step >= 1, only need last input + logging.debug("Clipping dec_input and only keep the last input.") + dec_input = dec_input[-1, :, :].unsqueeze(0) # shape (b, embed_dim) + else: + # Note: This is when the decoder itself is split across PP ranks. + dec_input = None + + if self.add_decoder and self.decoder_relative_position_embedding is not None: + decoder_self_attention_relative_position_bias = self.decoder_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1) + ) + if not self.decoder_cfg.relative_position_bias_self_attention_only: + decoder_cross_attention_relative_position_bias = ( + self.decoder_cross_attention_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), + key_seq_length=enc_seq_length, + ) + ) + else: + decoder_cross_attention_relative_position_bias = None + + return_all_crossattention_probs = self.return_all_crossattention_probs + single_encoder = False + if not isinstance(cross_attention_prior, list): + single_encoder = True + cross_attention_prior = [cross_attention_prior] + + decoder_cross_attention_relative_position_bias = [] + for _cross_attention_prior in cross_attention_prior: + _decoder_cross_attention_relative_position_bias = None + if _cross_attention_prior is not None: + # cross_attention_prior shape [B, dec_len, enc_len] + # Repeat it to make it [B, 12, dec_len, enc_len] + attn_prior_end_step = self.attn_prior_end_step + attn_prior_scaledown_start_step = self.attn_prior_scaledown_start_step + num_attention_heads = self.num_cross_attention_heads + assert attn_prior_scaledown_start_step <= attn_prior_end_step + logging.debug( + f"attn_prior_scaledown_start_step: {attn_prior_scaledown_start_step}, attn_prior_scaledown_start_step: {attn_prior_end_step}" + ) + if global_step >= attn_prior_end_step: + _decoder_cross_attention_relative_position_bias = None + elif global_step > attn_prior_scaledown_start_step and global_step < attn_prior_end_step: + total_annealing_steps = attn_prior_end_step - attn_prior_scaledown_start_step + curr_annealing_step = global_step - attn_prior_scaledown_start_step + curr_cross_attention_prior = _cross_attention_prior + ( + (1.0 - _cross_attention_prior) * curr_annealing_step / total_annealing_steps + ) + _decoder_cross_attention_relative_position_bias = curr_cross_attention_prior.unsqueeze( + 1 + ).repeat(1, num_attention_heads, 1, 1) + _decoder_cross_attention_relative_position_bias = torch.log( + _decoder_cross_attention_relative_position_bias + 1e-8 + ) + else: + _decoder_cross_attention_relative_position_bias = _cross_attention_prior.unsqueeze(1).repeat( + 1, num_attention_heads, 1, 1 + ) + _decoder_cross_attention_relative_position_bias = torch.log( + _decoder_cross_attention_relative_position_bias + 1e-8 + ) + decoder_cross_attention_relative_position_bias.append(_decoder_cross_attention_relative_position_bias) + + return_all_crossattention_probs = return_all_crossattention_probs or self.logging_step + + if single_encoder: + decoder_cross_attention_relative_position_bias = decoder_cross_attention_relative_position_bias[0] + + output = self.enc_dec_model( + enc_input=enc_input, + enc_attn_mask=enc_attn_mask, + dec_input=dec_input, + dec_attn_mask=dec_attn_mask, + enc_layer_past=None, + enc_get_key_value=False, + enc_output=enc_output, + enc_output_attn_mask=enc_output_attn_mask, + dec_layer_past=None, + dec_get_key_value=False, + enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + dec_self_attention_relative_position_bias=decoder_self_attention_relative_position_bias, + dec_cross_attention_relative_position_bias=decoder_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + batch_data=batch_data, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=self.enc_output_to_layers, + ) + + alignment_loss = None + if self.post_process and self.add_decoder: + dec_output, enc_output = output # [s, b, h] + if return_all_crossattention_probs: + dec_output, attention_scores = dec_output + attention_probs = [ + torch.softmax(attention_score, dim=-1) + for lidx, attention_score in enumerate(attention_scores) + if lidx in self.alignment_decoder_layerids + ] + + if text_limits is not None and self.use_alignment_loss and hasattr(self, "forward_sum_loss"): + attention_scores_filtered = [ + attention_scores[lidx] for lidx in self.alignment_decoder_layerids + ] + attention_scores_combined = torch.cat(attention_scores_filtered, dim=1) + text_start_idx = text_limits[0, 0].item() + assert torch.all( + text_limits[:, 0] == text_start_idx + ) # all texts should start at the same index + end_offset = self.alignment_text_end_offset + # align_every_n_head: eg if set to 2, will skip every other head + # if set to 12, will select 1 head from every layer + align_every_n_head = self.align_every_n_head + dec_start_idx = self.decoder_context_len + 1 # +1 to remove bos + attention_scores_sliced = attention_scores_combined[ + :, ::align_every_n_head, dec_start_idx:, text_start_idx : -(2 + end_offset) + ] # -2 to remove eos and pad + attention_logprobs = ( + attention_scores_sliced # not taking log_softmax, since we will do that in loss function + ) + attention_logprobs = torch.mean(attention_logprobs, dim=1, keepdim=True) + dec_len = torch.sum(dec_attn_mask, dim=1) - dec_start_idx + enc_len = text_limits[:, 1] - text_limits[:, 0] - end_offset + alignment_loss = self.forward_sum_loss( + attn_logprob=attention_logprobs, in_lens=enc_len, out_lens=dec_len + ) + else: + attention_probs = None + # project decoder output to vocabulary-size dimensions + if self.share_decoder_tokens_head_embeddings: + first_layer_vocabsize = ( + self.speech_offset + self.speech_codebook_size + ) # variables set in __init__ of speechlm model + token_logits = self.tokens_head(dec_output, self.word_embeddings_weight()) # s, b, vocab + if self.seq_pattern in ["parallel", "delay_parallel"]: + # For flat seq_pattern we need all the logits + token_logits = token_logits[:, :, :first_layer_vocabsize] + speech_layers = self.num_speech_codebooks - 1 + + # speech_logits_list will be used in loss calculation (parallel output) + speech_logits_list = [] + if self.seq_pattern in ["parallel", "delay_parallel"] and torch.count_nonzero(speech_mask) > 0: + for i in range(speech_layers): + last_layer_logits = self.speech_tokens_heads[i](dec_output)[0] # T, B, 1024 + speech_logits_list.append(last_layer_logits) # T, B, 1024 + else: + token_logits = self.tokens_head(dec_output)[0] # T, B, WordEmbSize + + if labels is not None: + if labels.dim() == 2: + # [b, s] -> [s, b] + labels = labels.transpose(0, 1).contiguous() + elif labels.dim() == 3: + # [b, c, s] -> [c, s, b] + labels = labels.permute(1, 2, 0).contiguous() + + # Set label smoothing to 0 if in eval mode. + label_smoothing = self.label_smoothing if self.training else 0.0 + + # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and return log p(x_i|z) per token i + if self.fp16_cross_entropy: + assert token_logits.dtype == torch.half + if labels.dim() == 3: + raise NotImplementedError("fp16_cross_entropy is not support for labels of dimension 3") + tokens_loss = vocab_parallel_cross_entropy(token_logits, labels, label_smoothing) + else: + if labels.dim() == 2: + tokens_loss = vocab_parallel_cross_entropy(token_logits.float(), labels, label_smoothing) + elif labels.dim() == 3: + if token_logits.size()[0] != labels[0, :, :].size()[0]: + raise Exception("TODO: add a permute") + tokens_loss = vocab_parallel_cross_entropy( + token_logits.float(), labels[0, :, :], label_smoothing + ) + logging.debug(f"token_loss: {tokens_loss}") + logging.debug(f"token_loss: {torch.all(torch.isfinite(tokens_loss))}") + if ( + self.seq_pattern in ["parallel", "delay_parallel"] + and torch.count_nonzero(speech_mask) > 0 + ): + for i in range(speech_layers): + if speech_logits_list[i].size()[0] != labels[i + 1, :, :].size()[0]: + raise Exception("TODO: add a permute") + curr_codebook_loss = ( + vocab_parallel_cross_entropy( + speech_logits_list[i].float(), labels[i + 1, :, :], label_smoothing + ) + * speech_mask.T + ) + tokens_loss += curr_codebook_loss + logging.debug(f"token_loss_{i}: {tokens_loss}") + logging.debug(f"token_loss_{i}: {torch.all(torch.isfinite(tokens_loss))}") + + # [s, b] -> [b, s] + tokens_loss = tokens_loss.transpose(0, 1).contiguous() + + # check if hiddens is used + if self.hiddens_cfg is not None: + raise NotImplementedError("Not currently implemented for speechllm") + else: + return tokens_loss, [token_logits, speech_logits_list, attention_probs, alignment_loss] + else: + # else return token logits (and hiddens if needed) + # [s, b, h] -> [b, s, h] + # If labels is None then we are in inference mode and we return the gathered logits + if self.parallel_output: + # Gather logits from tensor parallel if in parallel_output mode + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region( + token_logits + ) # T, B, 30208 + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) # T, B, 1024 + + token_logits = token_logits.transpose(0, 1).contiguous() # (B, T, 30208) + speech_logits = torch.stack(speech_logits_list, dim=-1) # T, B, 1024, 7 + speech_logits = speech_logits.transpose(0, 1).contiguous() # (B, T, 1024, 7) + + _si = self.speech_offset + _ei = _si + self.speech_codebook_size + first_layer_speech_logits = token_logits[:, :, _si:_ei].unsqueeze(-1) # (b, s, 1023, 1) + + all_speech_logits = torch.cat( + [first_layer_speech_logits, speech_logits], dim=-1 + ) # (b, s, 1024, 8) + + if self.hiddens_cfg is not None: + raise NotImplementedError("Not currently implemented for speechllm") + else: + # all_speech_logits: tensor, (b, s, 1024, 8), all layers of speech. + # token_logits: tensor, (b, s, vocab_size), text token logits. + # speech_logits: tensor, (b, s, 1024, 7), 1-7 layers of speech. + # attention_probs: tensor or None, (b, s, ) + # enc_output: tensor, (virtual_token_len+context_token_len+question_token_len+extra_id_0+[SEP], b, ) + return all_speech_logits, [token_logits, speech_logits, attention_probs, enc_output] + + elif self.add_decoder and not self.add_encoder: + decoder_output, _ = output + return decoder_output + else: + encoder_output = output + return encoder_output + + def state_dict(self): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._encoder_embedding_key] = self.encoder_embedding.state_dict() + state_dict_[self._decoder_embedding_key] = self.decoder_embedding.state_dict() + state_dict_[self._enc_dec_model_key] = self.enc_dec_model.state_dict() + state_dict_[self._tokens_head_key] = self.tokens_head.state_dict() + if hasattr(self, "speech_tokens_heads"): + state_dict_["speech_tokens_heads"] = self.speech_tokens_heads.state_dict() + if hasattr(self, "speech_tokens_embeddings"): + state_dict_["speech_tokens_embeddings"] = self.speech_tokens_embeddings.state_dict() + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + super().load_state_dict(state_dict, strict=strict) + if hasattr(self, "speech_tokens_heads"): + self.speech_tokens_heads.load_state_dict(state_dict["speech_tokens_heads"], strict=strict) + if hasattr(self, "speech_tokens_embeddings"): + self.speech_tokens_embeddings.load_state_dict(state_dict["speech_tokens_embeddings"], strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index ab10b0d0e8b3..c5108d8e3801 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -22,6 +22,7 @@ import torch import torch.nn as nn from einops import rearrange +from omegaconf.listconfig import ListConfig from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( @@ -479,6 +480,10 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + return_crossattention_scores=False, + return_selfattention_scores=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): # Self attention. if rotary_pos_emb is not None: @@ -489,6 +494,12 @@ def forward( self_attention_pos_emb = None cross_attention_pos_emb = None + if return_crossattention_scores and return_selfattention_scores: + raise NotImplementedError( + "We can only return 1 of cross attention scores or self attention scores. Not both yet." + ) + attention_probs = None + if self.layer_type != LayerType.retrieval_decoder_after_self_attn: # hidden_states: [b, s, h] @@ -507,12 +518,16 @@ def forward( layer_past=layer_past, get_key_value=get_key_value, set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, + inference_max_sequence_len=inference_max_sequence_len or decoder_max_sequence_len, rotary_pos_emb=self_attention_pos_emb, relative_position_bias=self_attention_relative_position_bias, checkpoint_core_attention=checkpoint_core_attention, + return_scores=return_selfattention_scores, ) + if return_selfattention_scores: + attention_output, attention_probs = attention_output + if get_key_value: attention_output, presents = attention_output @@ -526,7 +541,7 @@ def forward( attention_bias = None # jit scripting for a nn.module (with dropout) is not - # trigerring the fusion kernel. For now, we use two + # triggering the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. @@ -553,6 +568,9 @@ def forward( elif self.transformer_block_type in ['pre_ln', 'normformer']: # Layer norm post the self attention. normalization_output = self.post_attention_layernorm(layernorm_input) + else: + normalization_output = None + logging.warning(f"This is a rare case since `normalization_output=None`") else: layernorm_input, normalization_output = hidden_states @@ -579,7 +597,7 @@ def forward( checkpoint_core_attention=checkpoint_core_attention, ) else: - + # Return Scores is being passed only for inter_attention and not self attention attention_output, attention_bias = self.inter_attention( normalization_output, enc_dec_attn_mask, @@ -587,7 +605,12 @@ def forward( rotary_pos_emb=cross_attention_pos_emb, relative_position_bias=cross_attention_relative_position_bias, checkpoint_core_attention=checkpoint_core_attention, + return_scores=return_crossattention_scores, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=encoder_max_sequence_len, ) + if return_crossattention_scores: + attention_output, attention_probs = attention_output # If normformer, apply norm on the output of the self attention. if self.transformer_block_type == 'normformer': @@ -632,6 +655,9 @@ def forward( if get_key_value: output = [output, presents] + if attention_probs is not None: + output = [output, attention_probs] + return output @@ -735,6 +761,10 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + return_crossattention_scores=False, + return_selfattention_scores=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): if self.dtype == torch.float32: return super().forward( @@ -750,6 +780,10 @@ def forward( self_attention_relative_position_bias, cross_attention_relative_position_bias, checkpoint_core_attention, + return_crossattention_scores=return_crossattention_scores, + return_selfattention_scores=return_selfattention_scores, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, ) with torch.autocast(device_type="cuda", dtype=self.dtype): return super().forward( @@ -765,6 +799,10 @@ def forward( self_attention_relative_position_bias, cross_attention_relative_position_bias, checkpoint_core_attention, + return_crossattention_scores=return_crossattention_scores, + return_selfattention_scores=return_selfattention_scores, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, ) @@ -1072,10 +1110,12 @@ def __init__( # Transformer layers. def build_layer(layer_number): - if isinstance(layer_type, list): + if isinstance(layer_type, (list, ListConfig)): lt = layer_type[layer_number - 1] else: lt = layer_type + if isinstance(lt, int): + lt = LayerType(lt) if self.transformer_engine: transformer_layer_args = { @@ -1493,7 +1533,16 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_activations_all_layers=None, + return_all_crossattention_probs=False, + return_all_selfattention_probs=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): + if return_all_crossattention_probs and return_all_selfattention_probs: + raise NotImplementedError( + "We can only return 1 of cross attention probs or self attention probs. Not both yet." + ) # Checks. if inference_max_sequence_len: assert self.activations_checkpoint_method is None, 'inference does not work with activation checkpointing' @@ -1580,6 +1629,7 @@ def forward( if self.inference_params != None: self.inference_params.sequence_len_offset = self.inference_current_sequence_len + attention_probs_list = [] if self.return_select_layer < 0: assert ( parallel_state.get_pipeline_model_parallel_world_size() == 1 @@ -1588,10 +1638,32 @@ def forward( logging.warning("Returning embeddings states only!") return hidden_states + layer_to_encoder_num_mapping = {} + if enc_output_to_layers is not None: + assert len(enc_output_to_layers) == len(encoder_output) + for encoder_idx in range(len(encoder_output)): + for layer_idx in enc_output_to_layers[encoder_idx]: + layer_to_encoder_num_mapping[layer_idx] = encoder_idx + for index in range(self.num_layers): layer = self._get_layer(index) past = None + _encoder_output = encoder_output + _enc_dec_attn_mask = enc_dec_attn_mask + _cross_attention_relative_position_bias = cross_attention_relative_position_bias + _encoder_max_sequence_len = encoder_max_sequence_len + if index in layer_to_encoder_num_mapping: + _encoder_output = encoder_output[layer_to_encoder_num_mapping[index]] + _enc_dec_attn_mask = enc_dec_attn_mask[layer_to_encoder_num_mapping[index]] + _cross_attention_relative_position_bias = cross_attention_relative_position_bias[ + layer_to_encoder_num_mapping[index] + ] + if encoder_max_sequence_len is not None: + _encoder_max_sequence_len = encoder_max_sequence_len[ + layer_to_encoder_num_mapping[index] + ] + if layer_past is not None: past = layer_past[index] @@ -1625,27 +1697,65 @@ def forward( hidden_states = layer( hidden_states, attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, inference_params=self.inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, ) else: - hidden_states = layer( - hidden_states, - attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - layer_past=past, - get_key_value=get_key_value, - set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, - rotary_pos_emb=rotary_pos_emb, - self_attention_relative_position_bias=self_attention_relative_position_bias, - cross_attention_relative_position_bias=cross_attention_relative_position_bias, - checkpoint_core_attention=checkpoint_core_attention, - ) + if layer.layer_type == LayerType.decoder and return_all_crossattention_probs: + hidden_states, attention_probs = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + return_crossattention_scores=return_all_crossattention_probs, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=_encoder_max_sequence_len, + ) + attention_probs_list.append(attention_probs) + elif layer.layer_type == LayerType.encoder and return_all_selfattention_probs: + hidden_states, attention_probs = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + get_key_value=get_key_value, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + return_selfattention_scores=return_all_selfattention_probs, + ) + attention_probs_list.append(attention_probs) + else: + hidden_states = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + get_key_value=get_key_value, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=_encoder_max_sequence_len, + ) if self.return_select_layer < 0: assert ( @@ -1679,4 +1789,7 @@ def forward( if get_key_value: output = [output, presents] + if return_all_crossattention_probs or return_all_selfattention_probs: + output = [output, attention_probs_list] + return output diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 601cb7a4d7e8..b0a6f755a9cc 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -18,7 +18,6 @@ from typing import Dict, Iterator, List, Optional, Tuple, Union import torch -import torch.nn as nn from torch import Tensor from nemo.utils import logging, logging_mode @@ -474,9 +473,25 @@ def get_iterator_k_split( else: # Split a list of torch tensors assert batch[0].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" - split_batch = [ - torch.tensor_split(item, num_microbatches, dim=0) if torch.is_tensor(item) else item for item in batch - ] + split_batch = [] + for item in batch: + if torch.is_tensor(item): + split_batch.append(torch.tensor_split(item, num_microbatches, dim=0)) + elif isinstance(item, list): + if isinstance(item[0], torch.Tensor): + split_tensors = [torch.tensor_split(elem, num_microbatches, dim=0) for elem in item] + split_tuple = [] + for mbi in range(num_microbatches): + split_tuple.append([split_tensors[i][mbi] for i in range(len(split_tensors))]) + split_tuple = tuple(split_tuple) + split_batch.append(split_tuple) + else: + split_batch.append(split_list(item, num_microbatches)) + elif item is None: + split_batch.append(item) + else: + raise ValueError(f"Unsupported item type: {type(item)}") + microbatches = [ [elem[i] if elem is not None else elem for elem in split_batch] for i in range(num_microbatches) ] diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index eeaaea26beac..4743c3216e6a 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -24,7 +24,7 @@ import numpy as np import torch import torch.nn.functional as F -from lightning_fabric.utilities.seed import seed_everything +from lightning.fabric.utilities.seed import seed_everything from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer from nemo.collections.multimodal.data.neva.conversation import ( diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index 7c7360ba3400..11f79baa819a 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -15,11 +15,11 @@ import sys from typing import Optional, Union -from lightning_fabric.utilities.exceptions import MisconfigurationException +from lightning.fabric.utilities.exceptions import MisconfigurationException +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback from nemo.collections.nlp.parts.nlp_overrides import ( diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 2100e9c1ba8f..c16116145a12 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -23,24 +23,24 @@ from pathlib import Path from typing import Any, Callable, Dict, Generator, Iterator, List, Literal, Mapping, Optional, Sized, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import TorchCheckpointIO -from lightning_fabric.utilities.cloud_io import get_filesystem -from lightning_fabric.utilities.optimizer import _optimizer_to_device +from lightning.fabric.plugins import TorchCheckpointIO +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.optimizer import _optimizer_to_device +from lightning.pytorch.callbacks.progress import TQDMProgressBar +from lightning.pytorch.callbacks.progress.tqdm_progress import _update_n +from lightning.pytorch.core.optimizer import LightningOptimizer +from lightning.pytorch.loops.fetchers import _DataFetcher +from lightning.pytorch.plugins import ClusterEnvironment +from lightning.pytorch.plugins.io.checkpoint_plugin import CheckpointIO +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.callbacks.progress import TQDMProgressBar -from pytorch_lightning.callbacks.progress.tqdm_progress import _update_n -from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.loops.fetchers import _DataFetcher -from pytorch_lightning.plugins import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.plugins.precision.fsdp import FSDPPrecision -from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.trainer.trainer import Trainer from torch._C._distributed_c10d import ReduceOp from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.distributed.fsdp import BackwardPrefetch, FullStateDictConfig @@ -107,6 +107,7 @@ from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate from megatron.core.transformer.module import Float16Module as MCoreFloat16Module from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer + from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO HAVE_MEGATRON_CORE = True diff --git a/nemo/collections/nlp/parts/utils_funcs.py b/nemo/collections/nlp/parts/utils_funcs.py index a989ff3f606c..87fc1aa6f73c 100644 --- a/nemo/collections/nlp/parts/utils_funcs.py +++ b/nemo/collections/nlp/parts/utils_funcs.py @@ -28,9 +28,9 @@ import numpy as np import torch import torch.nn.functional as F +from lightning.pytorch.trainer.trainer import Trainer from matplotlib import pyplot as plt from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from sklearn.metrics import classification_report, confusion_matrix from torch import Tensor diff --git a/nemo/collections/tts/data/speechllm/__init__.py b/nemo/collections/tts/data/speechllm/__init__.py new file mode 100644 index 000000000000..9df65818d226 --- /dev/null +++ b/nemo/collections/tts/data/speechllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py new file mode 100644 index 000000000000..32f0a14f5e65 --- /dev/null +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -0,0 +1,1355 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import json +import random +from dataclasses import dataclass +from pathlib import Path +from typing import ClassVar, List, Optional, Union + +import numpy as np +import torch +from hydra.utils import instantiate +from omegaconf import OmegaConf +from tqdm.auto import tqdm + +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import get_ipa_punctuation_list +from nemo.collections.common.tokenizers.text_to_speech.tokenizer_utils import any_locale_text_preprocessing +from nemo.collections.nlp.data.language_modeling.megatron.base_prompt_learning_dataset import BasePromptLearningDataset +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import T5Sentinel +from nemo.collections.nlp.modules.common import VirtualPromptSource +from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + BetaBinomialInterpolator, + beta_binomial_prior_distribution, + general_padding, + get_base_dir, +) +from nemo.utils import logging + +__all__ = ['T5SpeechLMDataset', "Lang"] + + +def get_full_list_puncts(): + punct_set = set() + for locale_id in ["en-US", "de-DE", "fr-FR"]: + punct_list = get_ipa_punctuation_list(locale=locale_id) + punct_set.update(punct_list) + return sorted(punct_set) + + +@dataclass +class G2PConfig: + _target_: str = "nemo.collections.tts.g2p.models.en_us_arpabet.EnglishG2p" + phoneme_dict: str = "scripts/tts_dataset_files/cmudict-0.7b_nv22.10" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + + +@dataclass +class EnglishIpaG2pConfig: + _target_: str = "nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p" + phoneme_dict: str = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + locale: str = "en-US" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + grapheme_case: str = "upper" + use_stresses: bool = True + use_chars: bool = True + ignore_ambiguous_words: bool = False + + +@dataclass +class TextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer" + punct: bool = True + stresses: bool = True + chars: bool = True + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: G2PConfig = G2PConfig() + + +@dataclass +class EnglishIpaTextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer" + locale: str = "en-US" + punct: bool = True + # Define non_default_punct_list as a ClassVar to explicitly mark it as a class variable + non_default_punct_list: ClassVar[List[str]] = get_full_list_puncts() + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: EnglishIpaG2pConfig = EnglishIpaG2pConfig() + + +@dataclass +class TextTokenizerConfig: + text_tokenizer: TextTokenizer = TextTokenizer() + + +@dataclass +class EnglishIpaTextTokenizerConfig: + text_tokenizer: EnglishIpaTextTokenizer = EnglishIpaTextTokenizer() + + +def _get_default_text_tokenizer_conf(phoneme_probability: float = 0.5, use_ipa: bool = False): + if use_ipa: + g2p = EnglishIpaG2pConfig(phoneme_probability=phoneme_probability) + _text_tokenizer = EnglishIpaTextTokenizer(g2p=g2p) + text_tokenizer: EnglishIpaTextTokenizerConfig = EnglishIpaTextTokenizerConfig(text_tokenizer=_text_tokenizer) + else: + g2p = G2PConfig(phoneme_probability=phoneme_probability) + _text_tokenizer = TextTokenizer(g2p=g2p) + text_tokenizer: TextTokenizerConfig = TextTokenizerConfig(text_tokenizer=_text_tokenizer) + return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) + + +def pad_text_to_speech_dims(text_tensor, pad_id, pad_size=7): + token_len = text_tensor.shape[0] + empty_padding = torch.ones((pad_size, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id + return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0) + + +class Lang(enum.Enum): + en = 1 + es = 2 + fr = 3 + zh = 4 + de = 4 + + +class T5SpeechLMDataset(BasePromptLearningDataset): + """ + The dataset class for prompt-tuning or p-tuning pretrained T5 SpeechLM models. + """ + + def __init__( + self, + datasets, + tokenizer, + virtual_prompt_source: VirtualPromptSource, + task_templates: dict, + pseudo_tokens, + pad_token_id: str, + max_seq_length: int, + sample_rate: int, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + for_train: bool = True, + decoder_starts_with_pad: bool = False, + add_eos_to_decoder_output: bool = True, + add_sentinel_to_input: bool = True, + ul2_prompt_token: str = None, + segment_max_duration: Optional[int] = None, + trim: bool = False, + trim_ref: Optional[float] = None, + trim_top_db: Optional[int] = None, + trim_frame_length: Optional[int] = None, + trim_hop_length: Optional[int] = None, + pad_multiple: int = 1, + pitch_augment: bool = False, + sup_data_path: Optional[Union[Path, str]] = None, + speech_offset: Optional[int] = None, + train_task: Optional[str] = None, + seq_pattern: Optional[str] = "parallel", + use_attention_prior: Optional[bool] = False, + attention_prior_scaling_factor: Optional[float] = 1.0, + spec_aug=False, + spec_aug_time_width=0.2, + spec_aug_time_masks=2, + cross_attention_epsilon: Optional[float] = 0.0, + lm_vocab_size: Optional[int] = None, + num_speech_codebooks: Optional[int] = 8, + codebook_fps: Optional[int] = 86, + add_special_tokens_to_only_first_codebook: Optional[bool] = False, + context_pattern: Optional[str] = "parallel", + context_duration_min: Optional[float] = 3.0, + context_duration_max: Optional[float] = 5.0, + skip_datasets: Optional[List[str]] = [], # substrings of dataset names to skip + english_only_model: Optional[bool] = False, + context_conditioning: Optional[str] = "decoder", # encoder or decoder + use_beta_binomial_interpolator: Optional[str] = False, # encoder or decoder + context_slice_method: Optional[str] = "random", # random or fixed + phoneme_probability: Optional[float] = 0.5, + encoder_type: Optional[str] = "single_transformer", + use_ipa: bool = False, + **kwargs, + ): + """ + Only speech parameters are explained here. + segment_max_duration: Optional[int] = None, - Speech max segment duration + trim: bool = False, - speech parameter + trim_ref: Optional[float] = None, - speech parameter + trim_top_db: Optional[int] = None, - speech parameter + trim_frame_length: Optional[int] = None, - speech parameter + trim_hop_length: Optional[int] = None, - speech parameter + pad_multiple: int = 1, - speech parameter + pitch_augment: bool = False, - speech parameter + sup_data_path: Optional[Union[Path, str]] = None, - Supplementary folder path where codecs are stored. + speech_offset: Optional[int] = None, - if speech tokens then add this offset to the token indices to distinguish between text and speech tokens. + lm_vocab_size: Optional[int] = None, - vocab size of the original language model (phoneme tokens start from this index) + english_only_model: Optional[bool] = False, specify if monolingual or multi-lingual modeling. + use_ipa: bool = False, specify if using IPA tokens or default ARPABET tokens. Either choice still mixes chars. + **kwargs, + """ + # These two variables need to be set before calling super().__init__() because the parent class calls `load_data()` which requires these attributes. + self._rng = random.Random() + self.spec_aug = spec_aug if for_train else False + self.time_width = spec_aug_time_width + self.time_masks = spec_aug_time_masks + self.decoder_starts_with_pad = decoder_starts_with_pad + self.add_eos_to_decoder_output = add_eos_to_decoder_output + self.add_sentinel_to_input = add_sentinel_to_input + self.ul2_prompt_token = ul2_prompt_token + # Speech related variables + self.base_data_dir = None + self.segment_max_duration = segment_max_duration + self.sample_rate = sample_rate + self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate) + self.pad_multiple = pad_multiple + self.pitch_augment = pitch_augment + self.trim = trim + self.trim_ref = trim_ref if trim_ref is not None else np.max + self.trim_top_db = trim_top_db if trim_top_db is not None else 60 + self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048 + self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512 + self.speech_offset = speech_offset if speech_offset is not None else 3 + self.seq_pattern = seq_pattern + self.use_attention_prior = use_attention_prior + self.attention_prior_scaling_factor = attention_prior_scaling_factor + self.cross_attention_epsilon = cross_attention_epsilon # value of prior for context tokens (b/w 0 and 1) + assert self.cross_attention_epsilon >= 0.0 and self.cross_attention_epsilon <= 1.0 + self.lm_vocab_size = tokenizer.vocab_size if lm_vocab_size is None else lm_vocab_size + self.num_speech_codebooks = num_speech_codebooks + self.codebook_fps = codebook_fps + self.add_special_tokens_to_only_first_codebook = add_special_tokens_to_only_first_codebook + # context_pattern and duration arguments are supported only if context_type is REFSPEAKERCODEC in the manifest + self.context_pattern = context_pattern + self.context_duration_min = context_duration_min + self.context_duration_max = context_duration_max + self.english_only_model = english_only_model + self.phoneme_tokenizer = None + if english_only_model: + self.phoneme_tokenizer = instantiate( + _get_default_text_tokenizer_conf(phoneme_probability=phoneme_probability, use_ipa=use_ipa) + ).text_tokenizer + else: + self.g2p = {"fr": lambda x: x} + if kwargs.get("g2p", None): + if "english" in kwargs["g2p"]: + english_g2p = instantiate(kwargs["g2p"]["english"]) + self.g2p["en"] = lambda x: english_g2p(x) + if "spanish" in kwargs["g2p"]: + spanish_g2p = instantiate(kwargs["g2p"]["spanish"]) + self.g2p["es"] = lambda x: spanish_g2p(x) + if "mandarin" in kwargs["g2p"]: + mandarin_g2p = instantiate(kwargs["g2p"]["mandarin"]) + self.g2p["zh"] = lambda x: mandarin_g2p(x) + if "german" in kwargs["g2p"]: + german_g2p = instantiate(kwargs["g2p"]["german"]) + self.g2p["de"] = lambda x: german_g2p(x) + + self.context_conditioning = context_conditioning + if self.context_conditioning == "decoder": + assert ( + self.context_duration_min == self.context_duration_max + ), "For decoder conditioning, context_duration_min and context_duration_max should be same" + self.decoder_context_len = int( + self.context_duration_min * self.codebook_fps + ) # TODO: Just take from model var? + + # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type\ + self.sup_data_path = None + if sup_data_path is not None: + Path(sup_data_path).mkdir(parents=True, exist_ok=True) + self.sup_data_path = sup_data_path + + self.codec_folder = kwargs.pop('codec_folder', None) + self.train_task = train_task + if self.codec_folder is None and self.sup_data_path is not None: + self.codec_folder = Path(self.sup_data_path) / "codec" + elif isinstance(self.codec_folder, str): + self.codec_folder = Path(self.codec_folder) + + self.codec_folder.mkdir(exist_ok=True, parents=True) + + self.context_length = kwargs.pop('context_length', None) # only used in gpt dataset atm + # self.attention_prior_strength = attention_prior_strength + self.transformer_type = kwargs.pop('transformer_type', 'T5') + self.skip_datasets = skip_datasets + + self.beta_binomial_interpolator = ( + BetaBinomialInterpolator(scaling_factor=self.attention_prior_scaling_factor) + if use_beta_binomial_interpolator + else None + ) + self.context_slice_method = context_slice_method + self.encoder_type = encoder_type + super().__init__( + datasets=datasets, + tokenizer=tokenizer, + virtual_prompt_source=virtual_prompt_source, + task_templates=task_templates, + pseudo_tokens=pseudo_tokens, + pad_token_id=pad_token_id, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + for_train=for_train, + ) + + def load_data(self, dataset): + """ + Loads a dataset by filling in the task templates specified in the config file + with the information from each training/inference example. Converts all input + text into token ids. Also replaces the <|VIRTUAL_PROMPT_#|> placeholders in + the task templates with the actual virtual prompt token ids. + + params: + dataset: A list of json objects or a dictionary objects each + containing the information needed for a training example + """ + copy_dataset = list(dataset) + audio_filelist = [] + # This loop is needed to calculate self.base_data_dir. + for json_line in copy_dataset: + if type(json_line) == dict: + doc = json_line + else: + doc = json.loads(json_line) + taskname = doc["taskname"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + + for p in prompt_template_fields: + if f"{p}_type" in doc and doc[f"{p}_type"] == "SPEECH": + audio_filelist.append(doc[p]) + self.base_data_dir = get_base_dir(audio_filelist) + + skipped = 0 + tts = 0 + asr = 0 + i = 0 + logging.info(f"copy_dataset len === {len(copy_dataset)}") + examples = [] + for json_line in tqdm(copy_dataset): + i += 1 + + # Read example dict or load the information for a single example from .json file + if type(json_line) == dict: + doc = json_line + else: + doc = json.loads(json_line) + + if self.context_conditioning == "decoder": + # Modify doc to make combine context and anwer + assert ";" not in doc['context'], "Multiple contexts not supported in decoder conditioning" + doc['answer'] = "{};{}".format(doc['context'], doc['answer']) + doc['answer_duration'] = self.context_duration_min + doc['answer_duration'] + doc['answer_type'] = "CONTEXTANSWER" + doc['context_type'] = "DUMMYCONTEXT" + doc['context'] = "DUMMYCONTEXT" + + question_in_manifest = doc['question'] + + if "Text to speech this" in question_in_manifest or "Phoneme TTS" in question_in_manifest: + tts += 1 + if self.train_task not in ['tts', 'all']: + continue + elif "Next token prediction" in question_in_manifest: + if self.train_task != 'tts': + asr += 1 + else: + tts += 1 + continue + else: + if self.train_task == 'tts': + continue + asr += 1 + + if doc["context_type"] == "SPEECH": + assert "context_duration" in doc, f"context_duration key not in document {doc}" + approx_context_len = 3 * (self.codebook_fps + 1) # +1 just to be safe + if self.context_length is not None and doc["context_duration"] < self.context_length: + logging.debug( + f"skipped as context_length of {doc['context_duration']} is less than {self.context_length}" + ) + skipped += 1 + continue + elif "Remove Noise" in question_in_manifest: + approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1) + elif "Extract Speaker Audio" in question_in_manifest: + approx_context_len = ( + doc["answer_duration"] * (self.codebook_fps + 1) + 400 + ) # 400 is the max ref speaker audio + elif ("Text to speech this" in question_in_manifest) or ('Phoneme TTS' in question_in_manifest): + # approx_context_len = 400 + approx_context_len = 5 * ( + self.codebook_fps + 1 + ) # better than 400. TODO: pneekhara: Need to change things for multi-encoder vs single encoder based filtering. + elif "Edit Speech" in question_in_manifest: + approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1) + else: + raise NotImplementedError(f"Unknown context type {doc['context_type']}") + + approx_question_len = len(doc["question"].split(' ')) + 3 + if 'Phoneme TTS' in question_in_manifest: + # approx len is equal to num of characters + approx_question_len = len(question_in_manifest) + + if doc["answer_type"] in ["SPEECH", "AUDIOCODEC", "CONTEXTANSWER"]: + assert "answer_duration" in doc, f"answer_duration key not in document {doc}" + approx_answer_len = doc["answer_duration"] * (self.codebook_fps + 1) + 3 # +3 for EOS, BOS padding + if self.seq_pattern == "delay_parallel": + # In delay parallel, there is padding so add 8 frames + approx_answer_len = approx_answer_len + self.num_speech_codebooks + else: + approx_answer_len = len(doc["answer"].split(' ')) + 3 + + skip_record = False + for skip_dataset in self.skip_datasets: + if skip_dataset in doc['answer']: + skip_record = True + + if not skip_record: + if (self.transformer_type == "GPT") and ( + self.min_seq_length + < approx_context_len + approx_question_len + approx_answer_len + < self.max_seq_length + ): + examples.append(doc) + elif (self.transformer_type == "T5") and ( + self.min_seq_length < approx_context_len + approx_question_len < self.max_seq_length + and self.min_seq_length < approx_answer_len < self.max_seq_length + ): + examples.append(doc) + else: + logging.debug(f"skipped for {approx_context_len + approx_question_len} {approx_answer_len} len") + skipped += 1 + else: + print("Skipping", doc['answer']) + logging.debug(f"skipped for {doc['answer']} as it is in skip_datasets") + skipped += 1 + + logging.info(f'Skipped {skipped} sentences, sequence length too short or too long even after truncation') + + return examples + + def __getitem__(self, idx): + doc = self.examples[idx] + taskname = doc["taskname"] + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + truncation_field = self.task_templates[taskname]['truncate_field'] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + self._input_sanity_checks( + total_virtual_tokens=total_virtual_tokens, + virtual_token_splits=virtual_token_splits, + prompt_template=prompt_template, + prompt_template_fields=doc.keys(), # Skip this check as we don't need it for TTS + truncation_field=truncation_field, + answer_field=answer_field, + doc=doc, + ) + question_in_manifest = doc['question'] + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + # TODO @xueyang: declare the instructions when initializing the dataset so that they can be re-used. Temporally + # hardcode them here. + question_text = doc["question"].strip() + instructions = ["Phoneme TTS", "Text to speech this"] + for prefix in instructions: + if doc["question"].startswith(prefix): + question_text = doc["question"][len(prefix) :].strip() + break + + input_dict = self._insert_data_in_template(prompt_template_fields, doc, answer_field) + lang = Lang[doc.get("lang", "en")] + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. + # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + if ("Text to speech this" in question_in_manifest) and (doc["context_type"] == "SPEECH"): + total_context_len = context_tokens[0].size()[1] + reduced_len = min( + 400, + ( + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)) + ), + ) + start_token_index = random.randint( + 0, total_context_len - reduced_len + ) # start index can be greater than 440 + context_tokens[0] = context_tokens[0][ + :, start_token_index : min(start_token_index + 440, start_token_index + reduced_len) + ] + elif "Next token prediction" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + end_token_index = int(total_context_len * random.uniform(0.01, 0.2)) + context_tokens[0] = context_tokens[0][:, :end_token_index] + + # Get virtual tokens + # `virtual_tokens` is "". + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + + # a trick to align with the data format in t5 pretraining + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + question_tokens = question_tokens + [self.tokenizer.eos_id] + + # Try to truncate input text to fit into the max sequence length + if self._get_len(context_tokens, question_tokens, virtual_tokens) > self.max_seq_length: + context_tokens, question_tokens, virtual_tokens = self._truncate_input_speech( + context_tokens, question_tokens, virtual_tokens + ) + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] == "TEXT" and doc["context_type"] != "TEXT": + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + + # context_tokens: tensor, (num_speech_codebooks, audio_context_len) + # question_tokens: tensor, (num_speech_codebooks, instruction token len + question token len + 1 ( + 1 ([SEP])), only first row includes token ids while all other rows are all zeros (pad) + if self.encoder_type == "multi_transformer": + context_and_question_tokens = [context_tokens, question_tokens] + else: + context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) + + # get answer ids + if answer_field in doc.keys(): # training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + if end_token_index > -1: + answer_ids[0] = answer_ids[0][:, end_token_index:] + + if self.decoder_starts_with_pad: + answer_text_ids = [self.tokenizer.pad_id] + else: + answer_text_ids = [self.tokenizer.bos_id] + # a trick to align with the data format in t5 pretraining + # if self.add_sentinel_to_input: + # answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + answer_text_ids += answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + taskname_id = self.tokenizer.text_to_ids(taskname) + elif ( + self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT + ): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged. + taskname_id = -1 + else: + raise ValueError("Invalid virtual prompt source specified") + + dec_input = None + dec_labels = None + + # if single-encoder and context_condition is decoder, answer_text_ids = [CLS_id, context audio code tensors, zero-pad, answer audio code tensor, SEP_id] + # if multi-encoder, answer_text_ids = [CLS_id, answer audio codec tensor, SEP_id], so dec_input will not include audio context anymore. + if answer_field in doc.keys(): # training and validation + dec_input = answer_text_ids[:-1] + dec_labels = answer_text_ids[1:] + + # if single-encoder and context_condition is decoder: + # dec_input: shape=(self.num_speech_codebooks, 1([CLS]) + len(context audio frames) + 1([PAD]) + len(answer audio frames)) + # dec_labels: shape=(self.num_speech_codebooks, len(context audio frames) + 1([PAD]) + len(answer audio frames) + 1([SEP])) + # if multi-encoder: + # dec_input: (num_speech_codebooks, 1([CLS]) + len(answer audio frames)) + # dec_labels: (num_speech_codebooks, len(answer audio frames) + 1([SEP])) + dec_input, dec_input_len = self.list_to_tensor(dec_input, True) + dec_labels, dec_labels_len = self.list_to_tensor(dec_labels, True) + is_speech = True if doc["answer_type"] != "TEXT" else False + if is_speech: + assert dec_input.dim() == 2 and dec_labels.dim() == 2 + if self.seq_pattern == "delay_parallel": + num_codebooks = dec_input.shape[0] + dec_input_padded = torch.cat( + [ + torch.zeros_like(dec_input[:, 0:num_codebooks]), + dec_input, + torch.zeros_like(dec_input[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_labels_padded = torch.cat( + [ + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + dec_labels, + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + dec_labels_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dec_input_padded.shape[1] - _c + et_decoder_labels = dec_labels_padded.shape[1] - _c + dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) + dec_labels_new.append(dec_labels_padded[_c, st:et_decoder_labels]) + dec_input = torch.stack(dec_input_new, dim=0) + dec_labels = torch.stack(dec_labels_new, dim=0) + dec_input_len = torch.tensor(dec_input.shape[1]).long() + dec_labels_len = torch.tensor(dec_labels.shape[1]).long() + + if self.encoder_type == "multi_transformer": + enc_len = question_tokens_len + virtual_tokens_len + else: + enc_len = context_tokens_len + question_tokens_len + virtual_tokens_len + # TODO: Remove hardcoding + start_of_question_offset = 4 # For both "Text to Speech this" and "Phoneme TTS" + end_of_question_offset = 2 + cross_attention_prior = torch.zeros(dec_labels_len, enc_len) + self.cross_attention_epsilon + if self.use_attention_prior: + prior_dec_len = dec_labels_len.item() + prior_dec_start_idx = 0 + if self.context_conditioning == "decoder": + prior_dec_len = dec_labels_len.item() - (self.decoder_context_len + 1) + prior_dec_start_idx = self.decoder_context_len + 1 + text_len = question_tokens_len.item() - start_of_question_offset - end_of_question_offset + audio_len = prior_dec_len + if self.beta_binomial_interpolator is not None: + cross_attention_question_prior = torch.from_numpy(self.beta_binomial_interpolator(audio_len, text_len)) + else: + cross_attention_question_prior = torch.from_numpy( + beta_binomial_prior_distribution( + text_len, + audio_len, + scaling_factor=self.attention_prior_scaling_factor, + ) + ) + if self.encoder_type == "multi_transformer": + cross_attention_prior[ + prior_dec_start_idx:, virtual_tokens_len + start_of_question_offset : -end_of_question_offset + ] = cross_attention_question_prior + else: + cross_attention_prior[ + prior_dec_start_idx:, + virtual_tokens_len + context_tokens_len + start_of_question_offset : -end_of_question_offset, + ] = cross_attention_question_prior + + if self.encoder_type == "multi_transformer": + context_and_question_len = [context_tokens_len, question_tokens_len] + else: + context_and_question_len = context_tokens_len + question_tokens_len + return ( + taskname_id, # List, only one item. token id for "squad" + virtual_tokens, # Tensor, shape=(3,). token id for ['', '', ''] + virtual_tokens_len, # tensor, 3 + context_tokens_len, # tensor, 1 + # tensor if single encoder and context_condition is encoder, shape=(self.num_speech_codebooks, 1(context) + question len + 1() + 1([SEP])). only first row includes token ids while all other rows are all zeros (pad). + # list if multi-encoder and context_condition is encoder. + context_and_question_tokens, + # tensor scalar if single encoder and context_condition is decoder, 1 + (question len + 1 + 1). + # list if multi-encoder and context_condition is encoder. + context_and_question_len, + dec_input, # tensor, shape=(self.num_speech_codebooks, 1 CLS + context audio frame len + 1 pad + answer audio frame len), first column is [CLS_id, 0*7]^T + dec_input_len, # scalar tensor, 1 CLS + context audio frame len + 1 pad + answer audio frame len. 1 corresponds to CLS id + dec_labels, # tensor, shape=(self.num_speech_codebooks, context audio frame len + 1 pad + answer frame len + 1 SEP). + dec_labels_len, # tensor, context audio frame len + 1 PAD + answer frame len + 1 SEP. 1 corresponds to SEP id. + is_speech, # True + cross_attention_prior, # tensor, shape=(dec_labels_len, context_tokens_len + question_tokens_len + virtual_tokens_len). + lang.value, # int, + question_text, # str, answer transcript without question type (Phoneme TTS or Text to speech this). + ) + + def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens): + total_len = self._get_len(context_tokens, question_tokens, virtual_tokens) + context_len = self._get_element_len(context_tokens) + truncation_length = total_len - self.max_seq_length + 1 + context_tokens[0] = context_tokens[0][:, min(truncation_length, context_len) :] + return context_tokens, question_tokens, virtual_tokens + + def list_to_tensor(self, element, fill=False): + """ + Convert list to tensor. The list might contain integers, 2D-tensors (speech tokens) and combination of two. + If all of them are ints, simply convert to tensor + If combination of 2D-tensor and ints. Convert int to the dimension of the tensor. + example: [2, 4, 5] -> torch.tensor([2, 4, 5]) + example: [2, torch.tensor([[4, 5, 6], [6, 7, 8]])] -> torch.tensor( [[-1, 4, 5, 6], [2, 6, 7, 8]] ) + """ + ret, ln = None, None + if element is None: + return ret, ln + + max_len = max([1 if isinstance(item, int) else len(item) for item in element]) + if max_len == 1: + ret = torch.as_tensor(element).long() + ln = torch.tensor(ret.size()[0]).long() + else: + ret = [] + for e in element: + if isinstance(e, int): + tmp = torch.full((self.num_speech_codebooks, 1), e if fill else -1) + tmp[self.num_speech_codebooks - 1] = e + if self.add_special_tokens_to_only_first_codebook: + # Fill zeros in all other codebooks (to avoid out of range when getting embeddings) + tmp[1:] = 0 + else: + tmp = e + ret.append(tmp) + ret = torch.cat(ret, dim=1) + ln = torch.tensor(ret.size()[1]).long() + return ret, ln + + def _get_text_tokens(self, text): + input_ids = self.tokenizer.text_to_ids(text) + return input_ids + + def _get_phoneme_tokens(self, text, lang="en"): + if self.english_only_model: + input_ids = self.phoneme_tokenizer.encode(text) + input_ids_adjusted = [_id + self.lm_vocab_size for _id in input_ids] + return input_ids_adjusted + else: + text = any_locale_text_preprocessing(text) + input_ids = self.g2p[lang](text) + input_ids_adjusted = [] + for i in input_ids: + input_ids_adjusted.append(f"p{{{i}}}") + input_ids_adjusted = self.tokenizer.text_to_ids("".join(input_ids_adjusted)) + return input_ids_adjusted + + def _pad_wav_to_multiple(self, wav): + if self.pad_multiple > 1: + if wav.shape[0] % self.pad_multiple != 0: + wav = torch.cat( + [wav, torch.zeros(self.pad_multiple - wav.shape[0] % self.pad_multiple, dtype=torch.float)] + ) + return wav + + def _get_element_len(self, element): + length = 0 + if isinstance(element, list): + for e in element: + if isinstance(e, int): + length += 1 + else: + if e.dim() > 1: + length += e.size()[1] + else: + length += e.size()[0] + else: + if element.dim() > 1: + length += element.size()[1] + else: + length += element.size()[0] + return length + + def _get_len(self, context_tokens, question_tokens, virtual_tokens): + length = 0 + length += self._get_element_len(context_tokens) + length += self._get_element_len(question_tokens) + length += self._get_element_len(virtual_tokens) + return length + + def _load_audio(self, audio_filepath, dur=-1): + if self.segment_max_duration is not None and dur > 0 and dur > self.segment_max_duration: + # this case has been added for segmenting audio for speaker verification task of SSLDisentangler + n_segments = int(self.segment_max_duration * self.sample_rate) + features = AudioSegment.segment_from_file( + audio_filepath, target_sr=self.sample_rate, n_segments=n_segments, trim=self.trim + ) + + features = torch.tensor(features.samples) + if self.pad_multiple > 1: + features = self._pad_wav_to_multiple(features) + audio, audio_length = features, torch.tensor(features.shape[0]).long() + else: + features = self.featurizer.process( + audio_filepath, + trim=self.trim, + trim_ref=self.trim_ref, + trim_top_db=self.trim_top_db, + trim_frame_length=self.trim_frame_length, + trim_hop_length=self.trim_hop_length, + ) + + if self.pad_multiple > 1: + features = self._pad_wav_to_multiple(features) + + audio, audio_length = features, torch.tensor(features.shape[0]).long() + + return audio, audio_length + + def convert_audio(self, audio, sample_rate, target_sample_rate, target_channels): + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + assert audio.shape[1] in [1, 2], "Audio must be mono or stereo." + # assert sample_rate == target_sample_rate, "sample rate of FastPitch and Encodec model has to be same" + if target_channels == 2: + *shape, _, length = audio.shape + audio = audio.expand(*shape, target_channels, length) + return audio + + def get_codec(self, audio): + wav1 = self.convert_audio(audio, self.sample_rate, self.encodec_model.sample_rate, self.encodec_model.channels) + encoded_frames = self.encodec_model.encode(wav1) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) + return codes.squeeze(0) + + def get_quantizer_codebook(self, reference_codec, reference_codec_length): + out = torch.zeros((1, 128, reference_codec_length.item())) + for i in range(reference_codec.size()[0]): + out += self.encodec_model.quantizer.vq.layers[i].decode(reference_codec[i, :].unsqueeze(0)) + return out.squeeze(0) + + def _get_speech_tokens(self, audio_filepath, dur=-1): + # Let's keep audio name and all internal directories in rel_audio_path_as_text_id to avoid any collisions + rel_audio_path = Path(audio_filepath).relative_to(self.base_data_dir).with_suffix("") + rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_") + + # Load audio features + audio, audio_length = self._load_audio(audio_filepath, dur) + + # Convert to codes + codec_path = self.codec_folder / f"{rel_audio_path_as_text_id}.pt" + + if codec_path.exists(): + try: + codec_codes = torch.load(codec_path).long() + except Exception as e: + print(f"[ERROR IN LOADING {codec_path}] e") + codec_codes = self.get_codec(audio).long() + torch.save(codec_codes, codec_path) + else: + codec_codes = self.get_codec(audio).long() + torch.save(codec_codes, codec_path) + + # Convert codes to codes corresponding to megatron embedding layer + codec_codes[0] = (codec_codes[0] + self.speech_offset).long() + + return codec_codes + + def _get_tokens(self, doc, field, field_data): + if self.context_slice_method == "random": + # During training, we want a random slice of the context + rng = random.Random() # Custom random generator (since random uses fixed seeds) + elif self.context_slice_method == "fixed": + # During inference, we want a fixed slice of the context + rng = random + else: + raise ValueError(f"Invalid context_slice_method {self.context_slice_method}") + if f"{field}_type" not in doc.keys(): + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'TEXT': + _text = field_data.strip(" ") + if _text.startswith("Phoneme TTS"): + lang = doc.get("lang", "en") + instruction_tokens = self._get_text_tokens("Phoneme TTS") + field_tokens = self._get_phoneme_tokens(_text[len("Phoneme TTS") :].strip(), lang=lang) + field_tokens = instruction_tokens + field_tokens + elif _text.startswith("Edit Speech"): + # Always use phoneme tokenizer for edit speech + instruction_tokens = self._get_text_tokens("Edit Speech") + field_tokens = self._get_phoneme_tokens(_text[len("Edit Speech") :].strip()) + field_tokens = instruction_tokens + field_tokens + elif _text.startswith("TEXT CONTEXT:"): + # Speaker id conditioning + field_tokens = self._get_text_tokens(_text) + # pad field tokens to fixed length + # assert self.context_duration_min == self.context_duration_max, "TEXT CONTEXT only supports fixed context duration" + # To keep context length the same for audio or tex context + # _fixed_context_len = int(self.context_duration_min * self.codebook_fps) + field_tokens = field_tokens + [self.tokenizer.eos_id] + else: + # if starts with Text to speech this + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'SPEECH': + dur = -1 + if f"{field}_duration" in doc: + dur = doc[f"{field}_duration"] + field_tokens = self._get_speech_tokens(field_data, dur) # list of ids + if not isinstance(field_tokens, list): + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'AUDIOCODEC': + reference_codec_paths = field_data.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + if self.codec_folder is not None: + reference_codec_path = self.codec_folder / reference_codec_path + field_tokens = torch.load(reference_codec_path).long() + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + # print("AUDIOCODEC", field_tokens.shape) + elif doc[f"{field}_type"] == 'REFSPEAKERCODEC': + reference_codec_paths = field_data.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + if self.codec_folder is not None: + reference_codec_path = self.codec_folder / reference_codec_path + field_tokens = torch.load(reference_codec_path).long() + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + _min_len = int(self.context_duration_min * self.codebook_fps) + _max_len = int(self.context_duration_max * self.codebook_fps) + reference_codec_len = rng.randint(_min_len, _max_len) + reference_codec_len = min(reference_codec_len, field_tokens.shape[1]) + si = rng.randint(0, field_tokens.shape[1] - reference_codec_len) + field_tokens = field_tokens[:, si : si + reference_codec_len] + if self.context_pattern == "delay_parallel": + field_tokens = torch.cat( + [ + torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(), + field_tokens, + torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(), + ], + dim=1, + ) + new_field_tokens = [] + for _c in range(self.num_speech_codebooks): + st = self.num_speech_codebooks - _c + et = field_tokens.shape[1] - _c + new_field_tokens.append(field_tokens[_c, st:et]) + field_tokens = torch.stack(new_field_tokens, dim=0) + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'DUMMYCONTEXT': + field_tokens = torch.zeros(self.num_speech_codebooks, 1).long() + return [field_tokens] + elif doc[f"{field}_type"] == 'CONTEXTANSWER': + # Both Context and Answer are in the field + context_info, answer_codec_path = field_data.split(";") + if self.codec_folder is not None: + context_codec_path = self.codec_folder / context_info + answer_codec_path = self.codec_folder / answer_codec_path + if context_info.startswith("TEXT CONTEXT:"): + context_tokens = self._get_text_tokens(context_info.strip(" ")) + # pad field tokens to fixed length + assert ( + self.context_duration_min == self.context_duration_max + ), "TEXT CONTEXT only supports fixed context duration" + _fixed_context_len = int(self.context_duration_min * self.codebook_fps) + context_tokens = context_tokens + [self.tokenizer.pad_id] * (_fixed_context_len - len(context_tokens)) + + answer_tokens = torch.load(answer_codec_path).long() + answer_tokens[0] = (answer_tokens[0] + self.speech_offset).long() + field_tokens = context_tokens + [self.tokenizer.pad_id] + [answer_tokens] + else: + context_tokens = torch.load(context_codec_path).long() + context_tokens[0] = (context_tokens[0] + self.speech_offset).long() + assert ( + self.context_duration_min == self.context_duration_max + ), "CONTEXTANSWER only supports fixed context duration" + reference_codec_len = int(self.context_duration_min * self.codebook_fps) + if context_tokens.shape[1] < reference_codec_len: + # Repeat the context to match the reference_codec_len + context_tokens = torch.cat( + [context_tokens] * (reference_codec_len // context_tokens.shape[1] + 1), dim=1 + ) + assert ( + context_tokens.shape[1] >= reference_codec_len + ), "CONTEXTANSWER context duration is less than min duration {} {} {}".format( + context_tokens.shape[1], reference_codec_len, context_codec_path + ) + si = rng.randint(0, context_tokens.shape[1] - reference_codec_len) + context_tokens = context_tokens[:, si : si + reference_codec_len] + + answer_tokens = torch.load(answer_codec_path).long() + answer_tokens[0] = (answer_tokens[0] + self.speech_offset).long() + pad_tokens = torch.zeros(self.num_speech_codebooks, 1).long() + # padding between context and answer + field_tokens = torch.cat([context_tokens, pad_tokens, answer_tokens], dim=1) + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'SEPARATIONCODECS': + mixed_codec_path, reference_codec_paths = field_data.split(",") + reference_codec_paths = reference_codec_paths.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + mixed_codec = torch.load(mixed_codec_path).long() + reference_codec = torch.load(reference_codec_path).long() + reference_codec_len = rng.randint(240, 400) + reference_codec = reference_codec[:, :reference_codec_len] + # MIXED AUDIO AND REF AUDIO ARE SEPARATED BY 8 TIMESTEPS OF 1023 TOKENS IN ALL CODEBOOKS + mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long() + field_tokens = torch.cat([mixed_codec, mask_tokens, reference_codec], dim=1) + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'EDITINGCODECS': + reference_audio_path = field_data + reference_codec = torch.load(reference_audio_path).long() + assert reference_codec.shape[1] > 80 # ensure reference audio is atleast 1 second + mask_len = rng.randint(40, 320) # ~0.5 second to 4 seconds + mask_len = min(mask_len, reference_codec.shape[1] - 80) + mask_start = rng.randint(0, reference_codec.shape[1] - mask_len) + mask_end = mask_start + mask_len + mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long() + seg1 = reference_codec[:, :mask_start] + seg2 = reference_codec[:, mask_end:] + field_tokens = torch.cat([seg1, mask_tokens, seg2], dim=1) + # MISSING AUDIO IS REPLACED WITH 8 TIMESTEPS OF 1023 TOKENS IN ALL CODEBOOKS + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + else: + raise Exception(f"{field}_type not recognized") + return field_tokens + + def _insert_data_in_template(self, prompt_template_fields, doc, answer_field): + """Format the input example according to the template""" + out_dict = {} + for field in prompt_template_fields: + # discard the last one, {label} / {answer} + # Or if some fields from the template aren't present, e.g. {answer} during inference + # just remove that field from the template, leaving the space blank + if field == answer_field or field not in doc.keys(): + continue + # out_dict[field] = "" + + elif field in doc.keys(): + field_data = doc[field] + if f"{field}_type" not in doc.keys(): + doc[f"{field}_type"] = "TEXT" + raise Exception(f"{field}_type does not exist in doc") + else: + out_dict[field] = self._get_tokens(doc, field, field_data) + return out_dict + + def get_position_ids(self, virtual_token, context_and_qquestion): + enc_input = [] + enc_input.append(virtual_token) + if context_and_qquestion.dim() > 2: + enc_input.append(context_and_qquestion[:, 0, :]) + else: + enc_input.append(context_and_qquestion) + + enc_input = torch.cat(enc_input, dim=1) + + enc_input_p = enc_input[:, 0, :] if enc_input.dim() == 3 else enc_input + return build_position_ids(enc_input_p).contiguous() + + def collate_fn(self, batch): + """Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch""" + + data_dict = self.pad_batch_and_build_loss_mask(batch) + + if self.encoder_type == "multi_transformer": + position_ids = [ + self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][0]), + self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][1]), + ] + else: + position_ids = self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens']) + + return ( + data_dict['virtual_tokens'], + data_dict['context_and_question_tokens'], + data_dict['enc_mask'], + data_dict['dec_input'], + data_dict['dec_input_mask'], + data_dict['dec_labels'], + data_dict['dec_labels_mask'], + position_ids, + data_dict['taskname_id'], + data_dict['speech_mask'], + data_dict['context_and_question_tokens_lens'], + data_dict['cross_attention_prior'], + data_dict['text_limits'], + data_dict['lang'], + data_dict['question_texts'], + ) + + def pad_batch_and_build_loss_mask(self, batch): + """Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask""" + ( + taskname_ids, + _, + virtual_tokens_len, + _, + _, + context_and_question_tokens_len, + _, + dec_input_len, + _, + dec_labels_len, + _, + _, + _, + question_texts, + ) = zip(*batch) + + taskname_ids = self.pad_taskname_ids(taskname_ids) + + max_virtual_tokens_len = max(virtual_tokens_len).item() if virtual_tokens_len is not None else 0 + if isinstance(virtual_tokens_len, tuple): + virtual_tokens_len = torch.stack(virtual_tokens_len) + virtual_mask = get_mask_from_lengths(virtual_tokens_len) + + if self.encoder_type == "multi_transformer": + max_context_len = ( + max(_c[0] for _c in context_and_question_tokens_len) + if context_and_question_tokens_len is not None + else 0 + ) + max_question_len = ( + max(_c[1] for _c in context_and_question_tokens_len) + if context_and_question_tokens_len is not None + else 0 + ) + max_context_and_question_tokens_len = [max_context_len, max_question_len] + context_len = torch.stack([_c[0] for _c in context_and_question_tokens_len]) + question_len = torch.stack([_c[1] for _c in context_and_question_tokens_len]) + context_mask = get_mask_from_lengths(context_len) + question_mask = get_mask_from_lengths(question_len) + context_and_question_tokens_len = [context_len, question_len] + context_and_question_mask = [context_mask, question_mask] + enc_mask = [ + torch.cat([virtual_mask, context_and_question_mask[0]], dim=1), + torch.cat([virtual_mask, context_and_question_mask[1]], dim=1), + ] + # import ipdb; ipdb.set_trace() + else: + max_context_and_question_tokens_len = ( + max(context_and_question_tokens_len).item() if context_and_question_tokens_len is not None else 0 + ) + if isinstance(context_and_question_tokens_len, tuple): + context_and_question_tokens_len = torch.stack(context_and_question_tokens_len) + context_and_question_mask = get_mask_from_lengths(context_and_question_tokens_len) + enc_mask = torch.cat([virtual_mask, context_and_question_mask], dim=1) + + max_dec_input_len = max(dec_input_len).item() if dec_input_len is not None else 0 + max_dec_labels_len = max(dec_labels_len).item() if dec_labels_len is not None else 0 + + ( + virtual_tokens_list, + context_question_tokens_list, + dec_input_list, + dec_input_mask_list, + dec_labels_list, + dec_labels_mask_list, + speech_mask_list, + cross_attention_prior_list, + text_limits, + lang_list, + ) = ( + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], + ) + + for i, sample_tuple in enumerate(batch): + ( + _, + virtual_token, + virtual_token_len, + context_token_len, + context_and_question_token, + context_and_question_token_len, + dec_input, + dec_input_len, + dec_label, + dec_label_len, + is_speech, + cross_attention_prior, + lang, + _, + ) = sample_tuple + + virtual_tokens_list.append( + general_padding( + virtual_token, virtual_token_len.item(), max_virtual_tokens_len, pad_value=self.tokenizer.pad_id + ) + ) + + if self.encoder_type == "multi_transformer": + context_tokens_padded = general_padding( + context_and_question_token[0], + context_and_question_token_len[0].item(), + max_context_and_question_tokens_len[0], + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims( + context_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + question_tokens_padded = general_padding( + context_and_question_token[1], + context_and_question_token_len[1].item(), + max_context_and_question_tokens_len[1], + pad_value=self.tokenizer.pad_id, + ) + if len(question_tokens_padded.shape) < 2: + question_tokens_padded = pad_text_to_speech_dims( + question_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + context_question_tokens_list.append([context_tokens_padded, question_tokens_padded]) + else: + # This means context and questions are concatenated together + context_tokens_padded = general_padding( + context_and_question_token, + context_and_question_token_len.item(), + max_context_and_question_tokens_len, + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims( + context_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + context_question_tokens_list.append(context_tokens_padded) + + if max_dec_input_len > 0: + dec_input_padded = general_padding( + dec_input, dec_input_len.item(), max_dec_input_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_input_padded.shape) < 2: + dec_input_padded = pad_text_to_speech_dims( + dec_input_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + dec_input_list.append(dec_input_padded) + dec_mask = ( + torch.as_tensor(([1] * dec_input_len) + ([0] * (max_dec_input_len - dec_input_len))) + .long() + .contiguous() + ) + dec_input_mask_list.append(dec_mask) + speech_mask = dec_mask if is_speech else torch.zeros(dec_mask.shape) + speech_mask_list.append(speech_mask) + + if max_dec_labels_len > 0: + loss_mask = ( + torch.as_tensor(([1] * dec_label_len) + ([0] * (max_dec_labels_len - dec_label_len))) + .long() + .contiguous() + ) + dec_label_padded = general_padding( + dec_label, dec_label_len.item(), max_dec_labels_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_label_padded.shape) < 2: + dec_label_padded = pad_text_to_speech_dims( + dec_label_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + dec_labels_list.append(dec_label_padded) + dec_labels_mask_list.append(loss_mask) + + _p0 = max_dec_labels_len - dec_label_len + if self.encoder_type == "multi_transformer": + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len[1] + - context_and_question_token_len[1] + - virtual_token_len + ) + else: + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len + - context_and_question_token_len + - virtual_token_len + ) + + cross_attention_prior_padded = torch.nn.functional.pad( + cross_attention_prior, + pad=(0, _p1, 0, _p0), + mode="constant", + value=1, + ) + cross_attention_prior_list.append(cross_attention_prior_padded) + + if self.encoder_type == "multi_transformer": + _start_of_text_id = virtual_token_len + 4 + _end_of_text_id = _start_of_text_id + ( + context_and_question_token_len[1] - 2 - 4 + ) # -2 for some end tokens + else: + _start_of_text_id = virtual_token_len + context_token_len + 4 + _end_of_text_id = _start_of_text_id + ( + context_and_question_token_len - context_token_len - 2 - 4 + ) # -2 for some end tokens + text_limits.append(torch.tensor([_start_of_text_id.item(), _end_of_text_id.item()])) + lang_list.append(torch.tensor(lang)) + + dec_labels_mask = torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None + if dec_labels_mask is not None and self.context_conditioning == 'decoder': + # Mask out context tokens from loss computation. +1 for bos/pad in the beginning + dec_labels_mask[:, : self.decoder_context_len + 1] = 0 + + if self.encoder_type == "multi_transformer": + context_batch = torch.stack([c[0] for c in context_question_tokens_list]) + question_batch = torch.stack([c[1] for c in context_question_tokens_list]) + context_and_question_tokens = [context_batch, question_batch] + else: + context_and_question_tokens = torch.stack(context_question_tokens_list) + + data_dict = { + "taskname_id": taskname_ids, + "virtual_tokens": torch.stack(virtual_tokens_list), + "context_and_question_tokens": context_and_question_tokens, + "enc_mask": enc_mask, + "dec_input": torch.stack(dec_input_list) if len(dec_input_list) > 0 else None, + "dec_input_mask": torch.stack(dec_input_mask_list) if len(dec_input_mask_list) > 0 else None, + "dec_labels": torch.stack(dec_labels_list) if len(dec_labels_list) > 0 else None, + "dec_labels_mask": dec_labels_mask, + "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, + "context_and_question_tokens_lens": context_and_question_tokens_len, + "cross_attention_prior": ( + torch.stack(cross_attention_prior_list) if len(cross_attention_prior_list) > 0 else None + ), + "text_limits": ( + torch.stack(text_limits) if len(text_limits) > 0 else None + ), # tensor, valid range of answer transcripts without virtual/instruction/end tokens. + "lang": torch.stack(lang_list), + "question_texts": question_texts, + } + + return data_dict diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py new file mode 100644 index 000000000000..9b0a4f8d06c2 --- /dev/null +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py @@ -0,0 +1,986 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import random +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +import webdataset as wd +from omegaconf import OmegaConf + +from nemo.collections.asr.data.audio_to_text import ( + _speech_collate_fn, + cache_datastore_manifests, + expand_sharded_filepaths, + shard_manifests_if_needed, +) +from nemo.collections.common.parts.preprocessing import collections +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import T5Sentinel +from nemo.collections.nlp.modules.common import VirtualPromptSource +from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.collections.tts.parts.utils.tts_dataset_utils import beta_binomial_prior_distribution, general_padding +from nemo.core.classes import IterableDataset +from nemo.utils import logging + +__all__ = ['T5SpeechLMTarredDataset'] + + +@dataclass +class G2PConfig: + _target_: str = "nemo.collections.tts.g2p.models.en_us_arpabet.EnglishG2p" + phoneme_dict: str = "scripts/tts_dataset_files/cmudict-0.7b_nv22.10" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + + +@dataclass +class TextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer" + punct: bool = True + stresses: bool = True + chars: bool = True + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: G2PConfig = G2PConfig() + + +@dataclass +class TextTokenizerConfig: + text_tokenizer: TextTokenizer = TextTokenizer() + + +def _get_default_text_tokenizer_conf(): + text_tokenizer: TextTokenizerConfig = TextTokenizerConfig() + return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) + + +def pad_text_to_speech_dims(text_tensor, pad_id): + token_len = text_tensor.shape[0] + empty_padding = torch.ones((7, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id + return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0) + + +class InstructionTuningManifestProcessor: + """ + Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds). + Each new line is a different sample. Example below: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + Args: + manifest_filepath: Path to manifest json as described above. Can be comma-separated paths. + parser: Str for a language specific preprocessor or a callable. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + bos_id: Id of beginning of sequence symbol to append if not None. + eos_id: Id of end of sequence symbol to append if not None. + pad_id: Id of pad symbol. Defaults to 0. + """ + + def __init__( + self, + manifest_filepath: str, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + max_utts: int = 0, + index_by_file_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + + # ASRAudioText( + self.collection = collections.InstructionTuningAudioText( + manifests_files=manifest_filepath, + min_duration=min_duration, + max_duration=max_duration, + max_seq_length=max_seq_length, + max_number=max_utts, + index_by_file_id=index_by_file_id, + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + +class _TarredInstructionTuningDataset(IterableDataset): + """ + A similar Dataset to the AudioToCharDataset/AudioToBPEDataset, but which loads tarred audio files. + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + sample_rate: int, + shuffle_n: int = 0, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + self.shard_manifests = shard_manifests + + # Shard manifests if necessary and possible and then expand the paths + manifest_filepath = shard_manifests_if_needed( + shard_manifests=shard_manifests, + shard_strategy=shard_strategy, + manifest_filepaths=manifest_filepath, + world_size=world_size, + global_rank=global_rank, + ) + + # If necessary, cache manifests from object store + cache_datastore_manifests(manifest_filepaths=manifest_filepath) + + self.manifest_processor = InstructionTuningManifestProcessor( + manifest_filepath=manifest_filepath, + max_duration=max_duration, + min_duration=min_duration, + max_seq_length=max_seq_length, + max_utts=0, + index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + self.len = self._compute_len() + self.return_sample_id = return_sample_id + + audio_tar_filepaths = expand_sharded_filepaths( + sharded_filepaths=audio_tar_filepaths, + shard_strategy=shard_strategy, + world_size=world_size, + global_rank=global_rank, + ) + + if shuffle_n > 0: + # Only shuffle training data tar files + logging.info("Shuffling Tar files") + custom_rng = random.Random() + custom_rng.shuffle(audio_tar_filepaths) + logging.info("Done shuffling Tar files") + logging.info(audio_tar_filepaths[:10]) + + self.sample_rate = sample_rate + + # Put together WebDataset + self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None) + + if shuffle_n > 0: + self._dataset = self._dataset.shuffle(shuffle_n) + else: + logging.info("WebDataset will not shuffle files within the tar files.") + + self._dataset = ( + self._dataset.rename(key='__key__', answer='pt', context='context.pt') + .to_tuple('key', 'answer', 'context') + .pipe(self._filter) + .pipe(self._loop_offsets) + .map(f=self._build_sample) + ) + + def _filter(self, iterator): + """This function is used to remove samples that have been filtered out by ASRAudioText already. + Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample + that was filtered out (e.g. for duration). + Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard, + which may make your code hang as one process will finish before the other. + """ + + class TarredAudioFilter: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + + def __iter__(self): + return self + + def __next__(self): + while True: + audio_filename, answer_bytes, context_bytes = next(self.iterator) + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + if file_id in self.collection.mapping: + return audio_filename, answer_bytes, context_bytes + + return TarredAudioFilter(self.manifest_processor.collection) + + def _loop_offsets(self, iterator): + """This function is used to iterate through utterances with different offsets for each file.""" + + class TarredAudioLoopOffsets: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + self.current_fn = None + self.current_bytes = None + self.current_context_bytes = None + self.offset_id = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.current_fn is None: + self.current_fn, self.current_bytes, self.current_context_bytes = next(self.iterator) + self.offset_id = 0 + else: + offset_list = self.collection.mapping[self.current_fn] + if len(offset_list) == self.offset_id + 1: + self.current_fn, self.current_bytes, self.current_context_bytes = next(self.iterator) + self.offset_id = 0 + else: + self.offset_id += 1 + + return self.current_fn, self.current_bytes, self.current_context_bytes, self.offset_id + + return TarredAudioLoopOffsets(self.manifest_processor.collection) + + def _collate_fn(self, batch): + return _speech_collate_fn(batch) + + def _build_sample(self, tup): + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" + audio_filename, encodec, ref_encodec, offset_id = tup + return audio_filename, encodec, ref_encodec, offset_id + + def get_manifest_sample(self, sample_id): + return self.manifest_processor.collection[sample_id] + + def __iter__(self): + return self._dataset.__iter__() + + def _compute_len(self): + if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized(): + my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda() + torch.distributed.all_reduce(my_len) + my_len = my_len.int() + logging.info(f'Sharded manifests: Total length: {my_len}') + else: + my_len = len(self.manifest_processor.collection) + + return my_len + + def __len__(self): + return self.len + + +class T5SpeechLMTarredDataset(_TarredInstructionTuningDataset): + """ + The dataset class for prompt-tuning or p-tuning pretrained T5 SpeechLM models. + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + tokenizer, + virtual_prompt_source: VirtualPromptSource, + task_templates: dict, + pseudo_tokens, + pad_token_id: str, + max_seq_length: int, + sample_rate: int, + shuffle_n: int = 0, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + for_train: bool = True, + decoder_starts_with_pad: bool = False, + add_eos_to_decoder_output: bool = True, + add_sentinel_to_input: bool = True, + ul2_prompt_token: str = None, + segment_max_duration: Optional[int] = None, + trim: bool = False, + trim_ref: Optional[float] = None, + trim_top_db: Optional[int] = None, + trim_frame_length: Optional[int] = None, + trim_hop_length: Optional[int] = None, + pad_multiple: int = 1, + pitch_augment: bool = False, + speech_offset: Optional[int] = None, + train_task: Optional[str] = None, + seq_pattern: Optional[str] = "parallel", + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: Optional[bool] = False, + lm_vocab_size: Optional[int] = None, + use_attention_prior: Optional[bool] = False, + attention_prior_scaling_factor: Optional[float] = 1.0, + cross_attention_epsilon: Optional[float] = 0.0, + num_speech_codebooks: Optional[int] = 8, + **kwargs, + ): + """ + Only speech parameters are explained here. + segment_max_duration: Optional[int] = None, - Speech max segment duration + trim: bool = False, - speech parameter + trim_ref: Optional[float] = None, - speech parameter + trim_top_db: Optional[int] = None, - speech parameter + trim_frame_length: Optional[int] = None, - speech parameter + trim_hop_length: Optional[int] = None, - speech parameter + pad_multiple: int = 1, - speech parameter + pitch_augment: bool = False, - speech parameter + speech_offset: Optional[int] = None, - if speech tokens then add this offset to the token indices to distinguish between text and speech tokens. + **kwargs, + """ + # These two variables need to be set before calling super().__init__() because the parent class calls `load_data()` which requires these attributes. + self.decoder_starts_with_pad = decoder_starts_with_pad + self.add_eos_to_decoder_output = add_eos_to_decoder_output + self.add_sentinel_to_input = add_sentinel_to_input + self.ul2_prompt_token = ul2_prompt_token + # Speech related variables + # self.encodec_model = EncodecModel.encodec_model_24khz() + # self.encodec_model.set_target_bandwidth(6.0) + self.base_data_dir = None + self.segment_max_duration = segment_max_duration + self.sample_rate = sample_rate + # self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate) + self.pad_multiple = pad_multiple + self.pitch_augment = pitch_augment + self.trim = trim + self.trim_ref = trim_ref if trim_ref is not None else np.max + self.trim_top_db = trim_top_db if trim_top_db is not None else 60 + self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048 + self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512 + self.speech_offset = speech_offset if speech_offset is not None else 3 + self.seq_pattern = seq_pattern + self.min_duration = kwargs.get('min_duration', 0.1) + self.max_duration = kwargs.get('max_duration', 20) + self.use_attention_prior = use_attention_prior + self.attention_prior_scaling_factor = attention_prior_scaling_factor + self.cross_attention_epsilon = cross_attention_epsilon # value of prior for context tokens (b/w 0 and 1) + assert self.cross_attention_epsilon >= 0.0 and self.cross_attention_epsilon <= 1.0 + + self.train_task = train_task + + # Initialized super part + self.tokenizer = tokenizer + self.virtual_prompt_source = virtual_prompt_source + self.task_templates = task_templates + self.pseudo_tokens = pseudo_tokens + self.pseudo_token_ids = set(self.tokenizer.tokens_to_ids(self.pseudo_tokens)) + self.pad_token_id = pad_token_id + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.add_bos = add_bos + self.add_eos = add_eos + self.for_train = for_train + self.use_phoneme_tokenizer = use_phoneme_tokenizer + self.examples = [] + self.lm_vocab_size = tokenizer.vocab_size if lm_vocab_size is None else lm_vocab_size + self.num_speech_codebooks = num_speech_codebooks + + assert self.min_seq_length <= max_seq_length, "Min sequence length should be less than or equal to max" + assert self.max_seq_length > 0, "Max sequence length should be greater than 0" + + self.context_length = kwargs.pop('context_length', None) # only used in gpt dataset atm + + logging.info("Loading and tokenizing dataset ... ") + + super().__init__( + audio_tar_filepaths=audio_tar_filepaths, + manifest_filepath=manifest_filepath, + sample_rate=sample_rate, + shuffle_n=shuffle_n, + min_duration=self.min_duration, + max_duration=self.max_duration, + max_seq_length=max_seq_length, + shard_strategy=shard_strategy, + shard_manifests=shard_manifests, + global_rank=global_rank, + world_size=world_size, + return_sample_id=return_sample_id, + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + self.encodec, self.ref_encodec = None, None + + def _insert_virtual_token_placeholders(self, input_example, virtual_token_splits): + """Insert the correct number of pseudo tokens at the <|VIRTUAL_PROMPT_n|> markers""" + total_inserted_tokens = 0 + + for idx in range(len(virtual_token_splits)): + split_start = total_inserted_tokens + split_end = total_inserted_tokens + virtual_token_splits[idx] + pseudo_tokens_for_split = "".join(self.pseudo_tokens[split_start:split_end]) + input_example = input_example.replace(f'<|VIRTUAL_PROMPT_{idx}|>', pseudo_tokens_for_split) + total_inserted_tokens = split_end + + return input_example + + def pad_taskname_ids(self, taskname_ids): + # Pad taskname_ids to be the same length for the prompt encoder + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + max_taskname_length = max(len(ids) for ids in taskname_ids) + taskname_ids = [ids + [self.pad_token_id] * (max_taskname_length - len(ids)) for ids in taskname_ids] + taskname_ids = torch.tensor(taskname_ids) + + # Task ids are just used for a look up embeddings for prompt-table + elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT: + taskname_ids = torch.tensor(taskname_ids) + + return taskname_ids + + def _build_sample(self, tup): + audio_filename, self.encodec, self.ref_encodec, offset_id = tup + + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id] + manifest_entry = self.manifest_processor.collection[manifest_idx] + doc = {} + doc['context'] = manifest_entry.context + doc['context_type'] = manifest_entry.context_type + doc['context_duration'] = manifest_entry.context_duration + doc['answer'] = manifest_entry.answer + doc['answer_type'] = manifest_entry.answer_type + doc['answer_duration'] = manifest_entry.answer_duration + doc['question'] = manifest_entry.question + doc['question_type'] = manifest_entry.question_type + + taskname = "squad" + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + question_in_manifest = manifest_entry.question + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + input_dict = self._insert_data_in_template(input_example, prompt_template_fields, doc, answer_field) + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. + # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + if "Text to speech this" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + reduced_len = min( + 400, + ( + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)) + ), + ) + start_token_index = random.randint( + 0, total_context_len - reduced_len + ) # start index can be greater than 440 + context_tokens[0] = context_tokens[0][ + :, start_token_index : min(start_token_index + 440, start_token_index + reduced_len) + ] + elif "Next token prediction" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + end_token_index = int(total_context_len * random.uniform(0.01, 0.2)) + context_tokens[0] = context_tokens[0][:, :end_token_index] + + # Get virtual tokens + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + + # a trick to align with the data format in t5 pretraining + # new + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + question_tokens = question_tokens + [self.tokenizer.eos_id] + + # Try to truncate input text to fit into the max sequence length + if self._get_len(context_tokens, question_tokens, virtual_tokens) > self.max_seq_length: + context_tokens, question_tokens, virtual_tokens = self._truncate_input_speech( + context_tokens, question_tokens, virtual_tokens + ) + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] != "SPEECH" and doc["context_type"] == "SPEECH": + question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id) + if doc["context_type"] != "SPEECH" and doc["question_type"] == "SPEECH": + context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id) + context_tokens = context_tokens.to(question_tokens.device) + context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) + + # get answer ids + if answer_field in doc.keys(): # training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + if end_token_index > -1: + answer_ids[0] = answer_ids[0][:, end_token_index:] + + if self.decoder_starts_with_pad: + answer_text_ids = [self.tokenizer.pad_id] + else: + answer_text_ids = [self.tokenizer.bos_id] + + answer_text_ids += answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + # Skip example if the final length doesn't fit length requirements even after truncation + if ( + self.min_seq_length + <= self._get_element_len(context_and_question_tokens) + self._get_element_len(virtual_tokens) + <= self.max_seq_length + and self.min_seq_length <= self._get_element_len(answer_text_ids) <= self.max_seq_length + ): + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + taskname_id = self.tokenizer.text_to_ids(taskname) + elif ( + self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT + ): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged. + taskname_id = -1 + else: + raise ValueError("Invalid virtual prompt source specified") + + dec_input = None + dec_labels = None + + if answer_field in doc.keys(): # training and validation + dec_input = answer_text_ids[:-1] + dec_labels = answer_text_ids[1:] + + dec_input, dec_input_len = self.list_to_tensor(dec_input, True) + dec_labels, dec_labels_len = self.list_to_tensor(dec_labels, True) + is_speech = True if doc["answer_type"] == "SPEECH" else False + if is_speech: + assert dec_input.dim() == 2 and dec_labels.dim() == 2 + if self.seq_pattern == "delay_parallel": + num_codebooks = dec_input.shape[0] + dec_input_padded = torch.cat( + [ + torch.zeros_like(dec_input[:, 0:num_codebooks]), + dec_input, + torch.zeros_like(dec_input[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_labels_padded = torch.cat( + [ + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + dec_labels, + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + dec_labels_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dec_input_padded.shape[1] - _c + et_decoder_labels = dec_labels_padded.shape[1] - _c + dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) + dec_labels_new.append(dec_labels_padded[_c, st:et_decoder_labels]) + dec_input = torch.stack(dec_input_new, dim=0) + dec_labels = torch.stack(dec_labels_new, dim=0) + dec_input_len = torch.tensor(dec_input.shape[1]).long() + dec_labels_len = torch.tensor(dec_labels.shape[1]).long() + + enc_len = context_tokens_len + question_tokens_len + virtual_tokens_len + # TODO: Remove hardcoding + num_question_offset = 4 # For "Text to Speech this" + + cross_attention_prior = torch.zeros(dec_labels_len, enc_len) + self.cross_attention_epsilon + if self.use_attention_prior: + cross_attention_question_prior = torch.from_numpy( + beta_binomial_prior_distribution( + question_tokens_len.item() - num_question_offset, + dec_labels_len.item(), + scaling_factor=self.attention_prior_scaling_factor, + ) + ) + cross_attention_prior[:, virtual_tokens_len + context_tokens_len + num_question_offset :] = ( + cross_attention_question_prior + ) + + return ( + taskname_id, + virtual_tokens, + virtual_tokens_len, + context_and_question_tokens, + context_tokens_len + question_tokens_len, + dec_input, + dec_input_len, + dec_labels, + dec_labels_len, + is_speech, + cross_attention_prior, + ) + else: + return None + + def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens): + total_len = self._get_len(context_tokens, question_tokens, virtual_tokens) + context_len = self._get_element_len(context_tokens) + truncation_length = total_len - self.max_seq_length + 1 + context_tokens[0] = context_tokens[0][:, min(truncation_length, context_len) :] + return context_tokens, question_tokens, virtual_tokens + + def list_to_tensor(self, element, fill=False): + """ + Convert list to tensor. The list might contain integers, 2D-tensors (speech tokens) and combination of two. + If all of them are ints, simply convert to tensor + If combination of 2D-tensor and ints. Convert int to the dimension of the tensor. + example: [2, 4, 5] -> torch.tensor([2, 4, 5]) + example: [2, torch.tensor([[4, 5, 6], [6, 7, 8]])] -> torch.tensor( [[-1, 4, 5, 6], [2, 6, 7, 8]] ) + """ + ret, ln = None, None + if element is None: + return ret, ln + + max_len = max([1 if isinstance(item, int) else len(item) for item in element]) + if max_len == 1: + ret = torch.as_tensor(element).long() + ln = torch.tensor(ret.size()[0]).long() + else: + ret = [] + for e in element: + if isinstance(e, int): + tmp = torch.full((8, 1), e if fill else -1) + tmp[7] = e + else: + tmp = e + ret.append(tmp) + ret = torch.cat(ret, dim=1) + ln = torch.tensor(ret.size()[1]).long() + return ret, ln + + def _get_text_tokens(self, text): + input_ids = self.tokenizer.text_to_ids(text) + return input_ids + + def _get_phoneme_tokens(self, text): + input_ids = phoneme_tokenizer.encode(text) + input_ids_adjusted = [_id + self.lm_vocab_size for _id in input_ids] + return input_ids_adjusted + + def _pad_wav_to_multiple(self, wav): + if self.pad_multiple > 1: + if wav.shape[0] % self.pad_multiple != 0: + wav = torch.cat( + [wav, torch.zeros(self.pad_multiple - wav.shape[0] % self.pad_multiple, dtype=torch.float)] + ) + return wav + + def _get_element_len(self, element): + length = 0 + if isinstance(element, list): + for e in element: + if isinstance(e, int): + length += 1 + else: + if e.dim() > 1: + length += e.size()[1] + else: + length += e.size()[0] + else: + if element.dim() > 1: + length += element.size()[1] + else: + length += element.size()[0] + return length + + def _get_len(self, context_tokens, question_tokens, virtual_tokens): + length = 0 + length += self._get_element_len(context_tokens) + length += self._get_element_len(question_tokens) + length += self._get_element_len(virtual_tokens) + return length + + def _get_speech_tokens(self, field): + + # Convert to codes + codec_codes, codec_codes_length = None, None # Codes + + if self.train_task == 'tts': + if field == 'context': + self.ref_encodec = torch.load(io.BytesIO(self.ref_encodec), map_location="cpu").long() + codec_codes = self.ref_encodec + elif field == 'answer': + self.encodec = torch.load(io.BytesIO(self.encodec), map_location="cpu").long() + codec_codes = self.encodec + elif self.train_task == 'asr': + if field == 'context': + self.ref_encodec = torch.load(io.BytesIO(self.ref_encodec), map_location="cpu").long() + codec_codes = self.ref_encodec + + # codec_codes_length = torch.tensor(codec_codes.shape[1]).long() + + # Convert codes to codes corresponding to megatron embedding layer + codec_codes[0] = (codec_codes[0] + self.speech_offset).long() + + return codec_codes + + def _get_tokens(self, doc, field, field_data): + if f"{field}_type" not in doc.keys(): + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'TEXT': + _text = field_data.strip(" ") + if self.use_phoneme_tokenizer: + instruction_tokens = self._get_text_tokens("Phoneme TTS") + field_tokens = self._get_phoneme_tokens(_text.replace("Text to speech this ", "")) + field_tokens = instruction_tokens + field_tokens + else: + field_tokens = self._get_text_tokens(_text) # list of ids + elif doc[f"{field}_type"] == 'SPEECH': + field_tokens = self._get_speech_tokens(field) # list of ids + if not isinstance(field_tokens, list): + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'TOKENS': + # Do nothing; already tokenized + field_tokens = field_data + else: + raise Exception(f"{field}_type not recognized") + return field_tokens + + def _insert_data_in_template(self, input_example, prompt_template_fields, doc, answer_field): + """Format the input example according to the template""" + out_dict = {} + for field in prompt_template_fields: + # discard the last one, {label} / {answer} + # Or if some fields from the template aren't present, e.g. {answer} during inference + # just remove that field from the template, leaving the space blank + if field == answer_field or field not in doc.keys(): + continue + # out_dict[field] = "" + + elif field in doc.keys(): + field_data = doc[field] + if f"{field}_type" not in doc.keys(): + doc[f"{field}_type"] = "TEXT" + raise Exception(f"{field}_type does not exist in doc") + else: + out_dict[field] = self._get_tokens(doc, field, field_data) + return out_dict + + def get_position_ids(self, virtual_token, context_and_qquestion): + enc_input = [] + enc_input.append(virtual_token) + if context_and_qquestion.dim() > 2: + enc_input.append(context_and_qquestion[:, 0, :]) + else: + enc_input.append(context_and_qquestion) + + enc_input = torch.cat(enc_input, dim=1) + + enc_input_p = enc_input[:, 0, :] if enc_input.dim() == 3 else enc_input + return build_position_ids(enc_input_p).contiguous() + + def collate_fn(self, batch): + """Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch""" + + data_dict = self.pad_batch_and_build_loss_mask(batch) + + position_ids = self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens']) + + return ( + data_dict['virtual_tokens'], + data_dict['context_and_question_tokens'], + data_dict['enc_mask'], + data_dict['dec_input'], + data_dict['dec_input_mask'], + data_dict['dec_labels'], + data_dict['dec_labels_mask'], + position_ids, + data_dict['taskname_id'], + data_dict['speech_mask'], + data_dict['context_and_question_tokens_lens'], + data_dict['cross_attention_prior'], + ) + + def pad_batch_and_build_loss_mask(self, batch): + """Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask""" + ( + taskname_ids, + _, + virtual_tokens_len, + _, + context_and_question_tokens_len, + _, + dec_input_len, + _, + dec_labels_len, + _, + _, + ) = zip(*batch) + + taskname_ids = self.pad_taskname_ids(taskname_ids) + + max_virtual_tokens_len = max(virtual_tokens_len).item() if virtual_tokens_len is not None else 0 + if isinstance(virtual_tokens_len, tuple): + virtual_tokens_len = torch.stack(virtual_tokens_len) + virtual_mask = get_mask_from_lengths(virtual_tokens_len) + + max_context_and_question_tokens_len = ( + max(context_and_question_tokens_len).item() if context_and_question_tokens_len is not None else 0 + ) + if isinstance(context_and_question_tokens_len, tuple): + context_and_question_tokens_len = torch.stack(context_and_question_tokens_len) + context_and_question_mask = get_mask_from_lengths(context_and_question_tokens_len) + + max_dec_input_len = max(dec_input_len).item() if dec_input_len is not None else 0 + max_dec_labels_len = max(dec_labels_len).item() if dec_labels_len is not None else 0 + enc_mask = torch.cat([virtual_mask, context_and_question_mask], dim=1) + + ( + virtual_tokens_list, + context_question_tokens_list, + dec_input_list, + dec_input_mask_list, + dec_labels_list, + dec_labels_mask_list, + speech_mask_list, + cross_attention_prior_list, + ) = ( + [], + [], + [], + [], + [], + [], + [], + [], + ) + + for i, sample_tuple in enumerate(batch): + ( + _, + virtual_token, + virtual_token_len, + context_and_question_token, + context_and_question_token_len, + dec_input, + dec_input_len, + dec_label, + dec_label_len, + is_speech, + cross_attention_prior, + ) = sample_tuple + + virtual_tokens_list.append( + general_padding( + virtual_token, virtual_token_len.item(), max_virtual_tokens_len, pad_value=self.tokenizer.pad_id + ) + ) + + context_tokens_padded = general_padding( + context_and_question_token, + context_and_question_token_len.item(), + max_context_and_question_tokens_len, + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims(context_tokens_padded, self.tokenizer.pad_id) + context_question_tokens_list.append(context_tokens_padded) + + if max_dec_input_len > 0: + dec_input_padded = general_padding( + dec_input, dec_input_len.item(), max_dec_input_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_input_padded.shape) < 2: + dec_input_padded = pad_text_to_speech_dims(dec_input_padded, self.tokenizer.pad_id) + dec_input_list.append(dec_input_padded) + dec_mask = ( + torch.as_tensor(([1] * dec_input_len) + ([0] * (max_dec_input_len - dec_input_len))) + .long() + .contiguous() + ) + dec_input_mask_list.append(dec_mask) + speech_mask = dec_mask if is_speech else torch.zeros(dec_mask.shape) + speech_mask_list.append(speech_mask) + + if max_dec_labels_len > 0: + loss_mask = ( + torch.as_tensor(([1] * dec_label_len) + ([0] * (max_dec_labels_len - dec_label_len))) + .long() + .contiguous() + ) + dec_label_padded = general_padding( + dec_label, dec_label_len.item(), max_dec_labels_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_label_padded.shape) < 2: + dec_label_padded = pad_text_to_speech_dims(dec_label_padded, self.tokenizer.pad_id) + dec_labels_list.append(dec_label_padded) + dec_labels_mask_list.append(loss_mask) + + _p0 = max_dec_labels_len - dec_label_len + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len + - context_and_question_token_len + - virtual_token_len + ) + + cross_attention_prior_padded = torch.nn.functional.pad( + cross_attention_prior, + pad=(0, _p1, 0, _p0), + mode="constant", + value=1, + ) + cross_attention_prior_list.append(cross_attention_prior_padded) + + data_dict = { + "taskname_id": taskname_ids, + "virtual_tokens": torch.stack(virtual_tokens_list), + "context_and_question_tokens": torch.stack(context_question_tokens_list), + "enc_mask": enc_mask, + "dec_input": torch.stack(dec_input_list) if len(dec_input_list) > 0 else None, + "dec_input_mask": torch.stack(dec_input_mask_list) if len(dec_input_mask_list) > 0 else None, + "dec_labels": torch.stack(dec_labels_list) if len(dec_labels_list) > 0 else None, + "dec_labels_mask": torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None, + "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, + "context_and_question_tokens_lens": context_and_question_tokens_len, + "cross_attention_prior": ( + torch.stack(cross_attention_prior_list) if len(cross_attention_prior_list) > 0 else None + ), + } + + return data_dict diff --git a/nemo/collections/tts/g2p/models/ctc.py b/nemo/collections/tts/g2p/models/ctc.py index 2e180e766211..1859b09594ff 100644 --- a/nemo/collections/tts/g2p/models/ctc.py +++ b/nemo/collections/tts/g2p/models/ctc.py @@ -19,8 +19,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from transformers import AutoConfig, AutoModel, AutoTokenizer from nemo.collections.tts.g2p.data.ctc import CTCG2PBPEDataset @@ -101,11 +101,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer) - self.wer = WER(decoding=self.decoding, use_cer=False, log_prediction=False, dist_sync_on_step=True,) - self.per = WER(decoding=self.decoding, use_cer=True, log_prediction=False, dist_sync_on_step=True,) + self.wer = WER( + decoding=self.decoding, + use_cer=False, + log_prediction=False, + dist_sync_on_step=True, + ) + self.per = WER( + decoding=self.decoding, + use_cer=True, + log_prediction=False, + dist_sync_on_step=True, + ) def setup_grapheme_tokenizer(self, cfg): - """ Initialized grapheme tokenizer """ + """Initialized grapheme tokenizer""" if self.mode == "byt5": # Load appropriate tokenizer from HuggingFace @@ -315,7 +325,10 @@ def _setup_infer_dataloader(self, cfg: DictConfig) -> 'torch.utils.data.DataLoad ) @torch.no_grad() - def _infer(self, config: DictConfig,) -> List[int]: + def _infer( + self, + config: DictConfig, + ) -> List[int]: """ Runs model inference. diff --git a/nemo/collections/tts/g2p/models/heteronym_classification.py b/nemo/collections/tts/g2p/models/heteronym_classification.py index 54b9a8b07413..47d08eb16e17 100644 --- a/nemo/collections/tts/g2p/models/heteronym_classification.py +++ b/nemo/collections/tts/g2p/models/heteronym_classification.py @@ -19,8 +19,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.metrics.classification_report import ClassificationReport @@ -113,9 +113,9 @@ def make_step(self, batch): def training_step(self, batch, batch_idx): """ - Lightning calls this inside the training loop with the data from the training dataloader - passed in as `batch`. - """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. + """ loss, logits = self.make_step(batch) self.log('train_loss', loss) @@ -267,7 +267,11 @@ def disambiguate( item = {"text_graphemes": cur_sentence, "start_end": cur_start_ends, "heteronym_span": cur_heteronyms} f.write(json.dumps(item, ensure_ascii=False) + '\n') - all_preds = self._disambiguate(manifest=tmp_manifest, batch_size=batch_size, num_workers=num_workers,) + all_preds = self._disambiguate( + manifest=tmp_manifest, + batch_size=batch_size, + num_workers=num_workers, + ) if wordid_to_phonemes_file is not None: self.set_wordid_to_phonemes(wordid_to_phonemes_file) diff --git a/nemo/collections/tts/g2p/models/t5.py b/nemo/collections/tts/g2p/models/t5.py index 19f976081687..4c673b18dc4a 100644 --- a/nemo/collections/tts/g2p/models/t5.py +++ b/nemo/collections/tts/g2p/models/t5.py @@ -17,8 +17,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from transformers import AutoTokenizer, T5ForConditionalGeneration from nemo.collections.asr.metrics.wer import word_error_rate diff --git a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py index 985897d8df3f..2fe0ac3f6077 100644 --- a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py +++ b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py @@ -93,7 +93,7 @@ def __init__( self.ascii_letter_dict = { x: ascii_letter_prefix + x for x in get_grapheme_character_set(locale="en-US", case=ascii_letter_case) } - self.ascii_letter_list = sorted(self.ascii_letter_dict) + self.ascii_letter_list = sorted(self.ascii_letter_dict.values()) self.ascii_letter_case = ascii_letter_case if apply_to_oov_word is None: @@ -181,6 +181,7 @@ def __call__(self, text: str) -> List[str]: `['wo3', 'jin1', 'tian1', 'qu4', 'le5', 'A', 'p', 'p', 'l', 'e', ' ', 'S', 't', 'o', 'r', 'e', ',', ' ', 'mai3', 'le5', 'yi2', 'ge4', 'i', 'P', 'h', 'o', 'n', 'e', '。']` """ + err = False text = set_grapheme_case(text, case=self.ascii_letter_case) pinyin_seq = [] @@ -201,7 +202,15 @@ def __call__(self, text: str) -> List[str]: tone_hyp = pinyin[-1] if tone_hyp in self.tone_dict: syllable = pinyin[:-1] - assert syllable in self.phoneme_dict, f"Syllable <{syllable}> does not exist in the dictionary." + # TODO: skipping the syllable that does not exist in the dictionary will lead to deletion errors in the + # synthesized speech. Even though this case is uncommon, it should be fixed in future. + if syllable not in self.phoneme_dict: + err = True + logging.error( + f"Syllable <{syllable}> does not exist in the dictionary. You should expect symbol " + f"deletion risks!!" + ) + continue phoneme_seq += self.phoneme_dict[syllable] phoneme_seq.append(self.tone_dict[tone_hyp]) # All pinyin would end up with a number in 1-5, which represents tones of the pinyin. @@ -211,4 +220,6 @@ def __call__(self, text: str) -> List[str]: phoneme_seq.append(self.ascii_letter_dict[tone_hyp]) else: phoneme_seq.append(pinyin) + if err: + logging.error(f"|{text}| contained unknown syllables") return phoneme_seq diff --git a/nemo/collections/tts/models/aligner.py b/nemo/collections/tts/models/aligner.py index d8e65d6e6821..5fea8615f7f2 100644 --- a/nemo/collections/tts/models/aligner.py +++ b/nemo/collections/tts/models/aligner.py @@ -18,9 +18,9 @@ import omegaconf import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import WandbLogger from torch import nn from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 0c5e41157613..230a24e36cb0 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -21,8 +21,8 @@ import torch.nn.functional as F from einops import rearrange from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.tts.losses.audio_codec_loss import ( FeatureMatchingLoss, diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index b1e702c89124..34213303abf4 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -18,9 +18,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss diff --git a/nemo/collections/tts/models/fastpitch_ssl.py b/nemo/collections/tts/models/fastpitch_ssl.py index fe743edf8783..f2384c41c5b5 100644 --- a/nemo/collections/tts/models/fastpitch_ssl.py +++ b/nemo/collections/tts/models/fastpitch_ssl.py @@ -16,9 +16,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.tts.losses.fastpitchloss import DurationLoss, MelLoss, PitchLoss from nemo.collections.tts.modules.fastpitch import FastPitchSSLModule, average_features @@ -34,7 +34,7 @@ class FastPitchModel_SSL(ModelPT): """ FastPitch based model that can synthesize mel spectrograms from content and speaker embeddings - obtained from SSLDisentangler. This model can be used for voice conversion by swapping the speaker embedding + obtained from SSLDisentangler. This model can be used for voice conversion by swapping the speaker embedding of a given source utterance, with the speaker embedding of a target speaker. """ @@ -133,9 +133,21 @@ def tb_logger(self): return self._tb_logger def forward( - self, *, enc_out=None, enc_mask=None, durs=None, pitch=None, pace=1.0, + self, + *, + enc_out=None, + enc_mask=None, + durs=None, + pitch=None, + pace=1.0, ): - return self.fastpitch(enc_out=enc_out, enc_mask=enc_mask, durs=durs, pitch=pitch, pace=pace,) + return self.fastpitch( + enc_out=enc_out, + enc_mask=enc_mask, + durs=durs, + pitch=pitch, + pace=pace, + ) def compute_encoding(self, content_embedding, speaker_embedding, dataset_id=None): # content embedding is (B, C, T) @@ -177,7 +189,11 @@ def training_step(self, batch, batch_idx): enc_mask = enc_mask[:, :, None] mels_pred, _, _, log_durs_pred, pitch_pred, pitch = self( - enc_out=enc_out, enc_mask=enc_mask, durs=durs, pitch=pitch, pace=1.0, + enc_out=enc_out, + enc_mask=enc_mask, + durs=durs, + pitch=pitch, + pace=1.0, ) loss = 0 @@ -208,7 +224,10 @@ def training_step(self, batch, batch_idx): ) spec_predict = mels_pred[0].data.cpu().float().numpy() self.tb_logger.add_image( - "train_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC", + "train_mel_predicted", + plot_spectrogram_to_numpy(spec_predict), + self.global_step, + dataformats="HWC", ) return loss @@ -286,7 +305,10 @@ def on_validation_epoch_end(self, outputs): ) spec_predict = spec_predict[_rand_idx].data.cpu().float().numpy() self.tb_logger.add_image( - "val_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC", + "val_mel_predicted", + plot_spectrogram_to_numpy(spec_predict), + self.global_step, + dataformats="HWC", ) if self.pitch_conditioning: @@ -321,10 +343,10 @@ def generate_wav( ): """ Args: - content_embedding : Content embedding from SSL backbone (B, C, T) + content_embedding : Content embedding from SSL backbone (B, C, T) speaker_embedding : Speaker embedding from SSL backbone (B, C) pitch_contour : Normalized Pitch contour derived from the mel spectrogram - encoded_len: Length of each content embedding, optional if batch size is 1. + encoded_len: Length of each content embedding, optional if batch size is 1. compute_pitch: if true, predict pitch contour from content and speaker embedding. compute_duration: if true, predict duration from content and speaker embedding. durs_gt: Ground truth duration of each content embedding, ignored if compute_duration is True. diff --git a/nemo/collections/tts/models/hifigan.py b/nemo/collections/tts/models/hifigan.py index 7a9a6d30671f..1a5462349c4d 100644 --- a/nemo/collections/tts/models/hifigan.py +++ b/nemo/collections/tts/models/hifigan.py @@ -18,8 +18,8 @@ import torch import torch.nn.functional as F from hydra.utils import instantiate +from lightning.pytorch.loggers.wandb import WandbLogger from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.loggers.wandb import WandbLogger from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss from nemo.collections.tts.models.base import Vocoder @@ -313,7 +313,7 @@ def stft(x): comp = torch.stft(x.squeeze(1), n_fft=1024, hop_length=256, win_length=1024, return_complex=True) comp = torch.view_as_real(comp) real, imag = comp[..., 0], comp[..., 1] - mags = torch.sqrt(real ** 2 + imag ** 2) + mags = torch.sqrt(real**2 + imag**2) phase = torch.atan2(imag, real) return mags, phase diff --git a/nemo/collections/tts/models/mixer_tts.py b/nemo/collections/tts/models/mixer_tts.py index c260df22e3c0..58b7f6f9706b 100644 --- a/nemo/collections/tts/models/mixer_tts.py +++ b/nemo/collections/tts/models/mixer_tts.py @@ -20,9 +20,9 @@ import transformers import wandb from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import WandbLogger from torch import nn from torch.nn import functional as F from transformers import AlbertTokenizer diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index 82f85d1ed6a2..3f04f2ca3908 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -15,9 +15,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import BaseTokenizer from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss diff --git a/nemo/collections/tts/models/spectrogram_enhancer.py b/nemo/collections/tts/models/spectrogram_enhancer.py index 65934d9a10ce..3644a77eb6fe 100644 --- a/nemo/collections/tts/models/spectrogram_enhancer.py +++ b/nemo/collections/tts/models/spectrogram_enhancer.py @@ -43,9 +43,9 @@ import torch.nn.functional as F from einops import rearrange from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from torch.utils.tensorboard.writer import SummaryWriter from nemo.collections.common.parts.utils import mask_sequence_tensor diff --git a/nemo/collections/tts/models/speechllm/__init__.py b/nemo/collections/tts/models/speechllm/__init__.py new file mode 100644 index 000000000000..9df65818d226 --- /dev/null +++ b/nemo/collections/tts/models/speechllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py new file mode 100644 index 000000000000..658ace21726f --- /dev/null +++ b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py @@ -0,0 +1,444 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import torch +from lightning.pytorch.trainer.trainer import Trainer +from omegaconf.dictconfig import DictConfig +from torch import Tensor + +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.collections.nlp.metrics.prompt_learning_metrics import AccuracyScore, BLEUScore, ROUGEScores +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common import ( + PromptEncoder, + PromptEncoderType, + VirtualPromptPlaceholderToken, + VirtualPromptSource, + VirtualPromptStyle, +) +from nemo.collections.nlp.modules.common.transformer.text_generation import TextGeneration +from nemo.collections.nlp.parts import utils_funcs +from nemo.utils import AppState + +try: + from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +__all__ = ['MegatronBaseSpeechLM'] + + +class MegatronBaseSpeechLM(MegatronBaseModel, TextGeneration): + """ + Model class for prompt-tuning or p-tuning a pretrained Megatron model. + + Prompt Tuning initalizes virtual prompt embeddings directly from a copy of + certain token embeddings from the the pretrained model's vocabulary + and directly tunes these embedding weights. The token embeddings used in + initalization are specified by the user in the config file. The model can + be prompt-tuned for multiple tasks at once. virtual prompts are stored in a + prompt table and can be added or deleted without disrupting virtual prompts + for other tasks. + + P-tuning initializes an LSTM encoder model that generates virtual prompt + embeddings for every task. Each task shares the same encoder. After ptuning + is compelete, the learned virtual prompts can be saved to the prompt table + using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a + new virtual prompt via p-tuning, they do not need to retrain on all previous + tasks. This gives p-tuning the same task flexiblity as prompt-tuning. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer) + self.init_model(cfg, trainer) + self.config = self.model_parallel_config + + def init_model(self, cfg: DictConfig, trainer: Trainer): + self.cfg = cfg + + self.load_frozen_model(cfg, trainer) + self.prompt_encoder = None + self.tokenizer = self.frozen_model.tokenizer + + if hasattr(self.frozen_model.cfg, "encoder") and hasattr(self.frozen_model.cfg, "decoder"): + self.hidden_size = ( + self.frozen_model.cfg.encoder.hidden_size + ) # Encoder and decoder need to have the same hidden size and we check for this in the frozen enc-dec model. + else: + self.hidden_size = self.frozen_model.cfg.hidden_size + + self.existing_tasks = list(self.cfg.get('existing_tasks', [])) + self.new_tasks = list(self.cfg.get('new_tasks', [])) + self.virtual_prompt_style = VirtualPromptStyle(cfg.virtual_prompt_style) + + # Load templates for assigning virtual prompt token positions + self.load_task_templates(self.cfg.task_templates) + + if self.first_stage_of_pipeline() and self.virtual_prompt_style in [ + VirtualPromptStyle.P_TUNING, + ]: + # TODO: Handle this when moving GPT prompt learning to the base class. + self.word_embeddings = self.frozen_model.enc_dec_model.encoder_embedding.word_embeddings + + # P-Tuning uses an LSTM Encoder to produce virtual token embeddings + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + self.virtual_prompt_source = VirtualPromptSource.PROMPT_ENCODER + elif self.virtual_prompt_style == VirtualPromptStyle.NO_PROMPT: + self.virtual_prompt_source = VirtualPromptSource.NO_PROMPT + else: + raise ValueError(f"\nvirtual prompt style '{cfg.virtual_prompt_style}'") + + self._reduced_loss_buffer = [] + self._inference_config = None + + # Prepare pseudo token ids for virtual/virtual prompt tokens + self.pseudo_tokens = get_pseudo_tokens(self.max_virtual_tokens) + if isinstance(self.tokenizer, SentencePieceTokenizer): + self.tokenizer.add_special_tokens(self.pseudo_tokens) + else: + self.tokenizer.add_special_tokens({'additional_special_tokens': self.pseudo_tokens}) + self.pseudo_token_ids = self.tokenizer.tokens_to_ids(self.pseudo_tokens) + self.pseudo_token_ids_start = self.pseudo_token_ids[0] if self.pseudo_token_ids else None + self.pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.unk_id + self.decoder_seq_length = cfg.get('decoder_seq_length', 40) + + self.autocast_dtype = utils_funcs.torch_dtype_from_precision(self.cfg.precision) # Mixed precision datatype + # make sure the default pytorch lightning gradient clipping in the basemodel + self.grad_clip_pl_default = True + self.lowest_val_loss = None + self.prompt_encoder = None + + self.enable_autocast = not self.megatron_amp_O2 and self.autocast_dtype in [torch.float16, torch.bfloat16] + + # define validation metric + if self.cfg.get('report_validation_metric', False): + validation_metric = self.cfg.get('validation_metric', 'accuracy') + if validation_metric == 'accuracy': + self.validation_metric = AccuracyScore() + elif validation_metric == 'bleu': + self.validation_metric = BLEUScore() + elif validation_metric == 'rouge': + self.validation_metric = ROUGEScores() + + def load_task_templates(self, task_templates): + """ + Takes in the task template portion of the config and turns + it into a table where each task's prompt template and + the number of virtual tokens to insert in a given part of + the prompt template are specified. + """ + self.task_templates = {} + self.task_id_num_to_name = {} + self.max_virtual_tokens = 0 + + task_id_num = 0 + for task in task_templates: + self.task_templates[task.taskname] = { + "prompt_template": task.prompt_template, + "prompt_template_fields": re.findall("\{(.*?)\}", task.prompt_template), + "answer_only_loss": task.get("answer_only_loss", False), + "answer_field": task.get("answer_field", None), + "truncate_field": task.truncate_field, + "total_virtual_tokens": task.total_virtual_tokens, + "virtual_token_splits": task.virtual_token_splits, + "task_id_num": task_id_num, + } + + self.max_virtual_tokens = max(self.max_virtual_tokens, task.total_virtual_tokens) + self.task_id_num_to_name[task_id_num] = task.taskname + task_id_num += 1 + + # Check that all new tasks have the same total num virtual tokens + # Num virtual tokens for new tasks don't need to match num used for previously tuned tasks + if self.new_tasks: + new_task_name = self.new_tasks[0] + self.total_new_task_virtual_tokens = self.task_templates[new_task_name]["total_virtual_tokens"] + + assert all( + self.task_templates[taskname]["total_virtual_tokens"] == self.total_new_task_virtual_tokens + for taskname in self.new_tasks + ), "Total virtual tokens for each task tuned simultaneously must match. If you want to use a different number of virtual tokens for different tasks, tune them separately." + + def init_prompt_encoder(self): + """ + Init the prompt encoder needed for p-tuning on a new task + """ + # Total virtual tokens should be the same across all new tasks, so just need one + new_task = self.new_tasks[0] + total_virtual_tokens = self.task_templates[new_task]["total_virtual_tokens"] + + encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "tpmlp").lower()) + self.prompt_encoder = PromptEncoder( + config=self.model_parallel_config, + encoder_type=encoder_type, + total_virtual_tokens=total_virtual_tokens, + token_dim=self.hidden_size, + hidden_size=self.cfg.p_tuning.get("encoder_hidden", self.hidden_size // 2), + lstm_dropout=self.cfg.p_tuning.get("dropout", 0.0), + num_layers=self.cfg.p_tuning.get("num_layers", 2), + init_std=self.cfg.p_tuning.get("init_std", 0.023), + taskname=new_task, + ) + + def freeze_existing_word_embeddings(self): + """Freeze params of existing virtual prompts that should not be tuned further""" + # Make sure word embeddings are frozen + for params in self.word_embeddings.parameters(): + params.requires_grad = False + + def state_dict(self): + """ + Custom state dict that only contains prompt table and prompt encoder parameters. + No frozen model parameters are stored in the state dict. Prompt encoder parameters + are only in state dict for intermediate checkpoints saved during training. Final + nemo checkpoints at the end of training will contain prompt table parameters only. + """ + state_dict_ = {} + state_dict_["frozen_model_enc_dec_model"] = self.frozen_model.enc_dec_model.state_dict() + state_dict_["word_embeddings"] = self.word_embeddings.state_dict() + if self.prompt_encoder is not None: + state_dict_["prompt_encoder"] = self.prompt_encoder.state_dict() + + return state_dict_ + + def load_state_dict(self, state_dict, strict: bool = True): + """ + Custom load state dict method that only loads prompt table and prompt encoder + parameters. Matching load method for this class' custom state dict method. + """ + self.init_prompt_encoder() + self.frozen_model.enc_dec_model.load_state_dict(state_dict["frozen_model_enc_dec_model"], strict) + self.word_embeddings.load_state_dict(state_dict["word_embeddings"], strict) + if 'prompt_encoder' in state_dict: + self.prompt_encoder.load_state_dict(state_dict["prompt_encoder"], strict) + + # Not sure why when we resume training the prompt encoder is on cpu + # Because it's not created on init - Should really be moved to init + self.prompt_encoder.to("cuda") + + def embed_input(self, input_ids: Tensor, taskname_ids: Tensor, use_cached_reps: bool): + """ + Replaces the virtual tokens in the input_ids with embeddings + calculated from either the 'prompt_table' or 'prompt_encoder'. + The virtual token placeholders have token_ids listed in + `self.pseudo_token_ids`. + + params: + input_ids: the input token ids + taskname_ids: the NLP task tag token ids + returns: + the token embedding for the LM model. + """ + # Replace virtual token ids with padding for forward pass through vocab embeddings + discrete_token_ids = input_ids.clone() + discrete_token_ids[(input_ids >= self.pseudo_token_ids_start)] = self.pad_token_id + discrete_token_embeds = self.word_embeddings(discrete_token_ids).clone() + + # Find the indicies where virtual tokens should be inserted + virtual_token_locations = input_ids >= self.pseudo_token_ids_start + + # If there are no virtual tokens, just return discrete token embeds + if not virtual_token_locations.any(): + return discrete_token_embeds + + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + # taskname_embeddings = self.word_embeddings(taskname_ids) + batch_size, _ = taskname_ids.size() + virtual_token_embeds = self.prompt_encoder(batch_size=batch_size, use_cached_reps=use_cached_reps) + else: + raise ValueError("invalid VirtualPromptSource.") + + # Create index template specifying where virtual token embeddings should be placed + batch_size, _, embedding_size = discrete_token_embeds.shape + virtual_token_index = virtual_token_locations.nonzero().reshape((batch_size, -1, 2))[:, :, 1][:, :, None] + virtual_token_index = virtual_token_index.expand( + batch_size, self.total_new_task_virtual_tokens, embedding_size + ) + + # Make sure discrete_token_embeds and virtual_token_embeds share the same dtype + discrete_token_embeds = discrete_token_embeds.type(virtual_token_embeds.dtype) + + # Insert virtual token embeddings where they belong amoung the discrete token embeddings + discrete_token_embeds.scatter_(1, virtual_token_index, virtual_token_embeds) + input_embeds = discrete_token_embeds + + return input_embeds + + def on_train_end(self): + # Save p-tuned prompts to prompt table for inference or future task training + self.save_to(save_path=self.cfg.nemo_path) + + def setup(self, stage=None): + if stage == 'predict' and self.first_stage_of_pipeline(): + return + + self.setup_test_data() + if stage == 'test': + return + + if self.first_stage_of_pipeline(): + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + if self.prompt_encoder is None: + self.init_prompt_encoder() + + self.setup_training_data() + self.setup_validation_data() + + def setup_training_data(self, training_data_config=None): + if self.cfg.data.get('train_ds', None): + self._train_ds, self._train_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.train_ds, + batch_size=self.cfg.global_batch_size, + for_train=True, + drop_last=True, + shuffle=True, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('train_manifest', None): + self._train_ds, self._train_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.train_manifest, + audio_path=self.cfg.data.train_audio_path, + batch_size=self.cfg.global_batch_size, + for_train=True, + drop_last=True, + shuffle=self.cfg.data.shuffle, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def setup_validation_data(self, validation_data_config=None): + if self.cfg.data.get('validation_ds', None): + self._validation_ds, self._validation_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.validation_ds, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=True, + drop_last=self.cfg.get("validation_drop_last", True), + shuffle=False, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('validation_manifest', None): + self._validation_ds, self._validation_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.validation_manifest, + audio_path=self.cfg.data.validation_audio_path, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=True, + drop_last=self.cfg.get("validation_drop_last", True), + shuffle=0, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def setup_test_data(self, test_data_config=None): + if self.cfg.data.get('test_ds', None): + self._test_ds, self._test_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.test_ds, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=False, + drop_last=False, + shuffle=False, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('test_manifest', None): + self._test_ds, self._test_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.test_manifest, + audio_path=self.cfg.data.test_audio_path, + batch_size=self.cfg.global_batch_size, + for_train=False, + drop_last=False, + shuffle=0, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gbs): + # This should happen only on the last batch of the dataset. + if global_batch_size_per_gpu != gbs // parallel_state.get_data_parallel_world_size(): + # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def _reconfigure_batch_sizes(self, gbs: int, mbs: int): + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=gbs, + micro_batch_size=mbs, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def set_inference_config(self, inference_config): + self._inference_config = inference_config + + def get_inference_config(self): + return self._inference_config + + def set_input_tensor(self, input_tensor): + pass + + def first_stage_of_pipeline(self): + pass + + @classmethod + def list_available_models(cls): + pass + + def load_frozen_model(self, cfg, trainer): + pass + + +def get_pseudo_tokens(num_virtual_tokens): + """ + Takes in an integer and returns a list of strings where each string + is a numbered virtual token placeholder. If + num_virtual_tokens = 3, then this function returns: + + ["", "", ""] + + Args: + num_virtual_tokens: (int) Number of virtual token strings you want to make + + returns a list of string. + + """ + pseudo_tokens = [ + VirtualPromptPlaceholderToken.BASE.value + str(i) + VirtualPromptPlaceholderToken.END.value + for i in range(num_virtual_tokens) + ] + + return pseudo_tokens diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py new file mode 100644 index 000000000000..d35d53b3cac7 --- /dev/null +++ b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -0,0 +1,2672 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json +import os +import random +import string +from functools import partial +from typing import Any, List + +import editdistance +import imageio +import numpy as np +import soundfile as sf +import torch +from lightning.pytorch.trainer.trainer import Trainer +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from omegaconf.omegaconf import open_dict + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceSpeechLLMTTSTokenizer +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model +from nemo.collections.nlp.models.language_modeling.megatron_t5_sft_model import MegatronT5SFTModel +from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import ( + MegatronTokenLevelEncoderDecoderSpeechLLMModule, +) +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + init_method_normal, +) +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.collections.tts.data.speechllm.t5_speechllm_dataset import Lang, T5SpeechLMDataset +from nemo.collections.tts.data.speechllm.t5_speechllm_tarred_dataset import T5SpeechLMTarredDataset +from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.speechllm.megatron_base_speechllm_prompt_model import MegatronBaseSpeechLM +from nemo.collections.tts.parts.utils.helpers import plot_alignment_to_numpy_for_speechllm, plot_codec_to_numpy +from nemo.utils import AppState, logging + +try: + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.enums import ModelType + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +import time + +import librosa +from torchaudio.pipelines import SQUIM_SUBJECTIVE +from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector + +__all__ = ['MegatronT5SpeechLMModel'] + + +class MegatronT5OverrideModel(MegatronT5Model): + def _build_tokenizer(self): + if self._cfg.tokenizer.library == "sentencepiece": + if hasattr(self._cfg.tokenizer, "sentencepiece_legacy"): + legacy = self._cfg.tokenizer.sentencepiece_legacy + else: + legacy = True if self._cfg.tokenizer.library == 'sentencepiece' else False + self.tokenizer = SentencePieceSpeechLLMTTSTokenizer( + model_path=self.register_artifact("tokenizer.model", self._cfg.tokenizer.get('model', None)), + legacy=legacy, + ) + + if self._cfg.tokenizer.get('additional_special_tokens', None) is not None: + tokens_list = OmegaConf.to_object(self._cfg.tokenizer.additional_special_tokens) + self.tokenizer.add_special_tokens(tokens_list) + else: + super()._build_tokenizer() + + def model_provider_func(self, pre_process, post_process, add_encoder, add_decoder): + if not hasattr(self.cfg, 'encoder') or not hasattr(self.cfg, 'decoder'): + logging.warning( + 'Could not find encoder or decoder in config. This is probably because of restoring an old checkpoint. Copying shared model configs to encoder and decoder configs.' + ) + # After the call below, self.cfg.encoder and self.cfg.decoder will be populated with the cfg.model configs from old checkpoints. + self._populate_encoder_decoder_configs_for_backward_compatibility(self.cfg) + + if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.encoder.arch == 'perceiver': + raise ValueError(f"Perceivers with pipeline parallel > 1 is not supported yet.") + + if not hasattr(self.cfg, 'embedding_init_method_std'): + embedding_init_method_std = self.cfg.encoder.init_method_std + else: + embedding_init_method_std = self.cfg.embedding_init_method_std + + if not hasattr(self.cfg, 'embedding_dropout'): + embedding_dropout = self.cfg.encoder.hidden_dropout + else: + embedding_dropout = self.cfg.embedding_dropout + + model = MegatronTokenLevelEncoderDecoderSpeechLLMModule( + config=self.model_parallel_config, + encoder_cfg=self.cfg.encoder, + decoder_cfg=self.cfg.decoder, + vocab_size=self.padded_vocab_size, + max_position_embeddings=self.cfg.max_position_embeddings, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + fp16_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False), + precision=self.cfg.get('precision', 16), + embedding_init_method_std=embedding_init_method_std, + embedding_dropout=embedding_dropout, + label_smoothing=self.cfg.get('label_smoothing', 0.0), + add_encoder=add_encoder, + add_decoder=add_decoder, + share_token_embeddings=self.cfg.get('share_token_embeddings', True), + share_decoder_tokens_head_embeddings=self.cfg.get('share_decoder_tokens_head_embeddings', True), + tokens_head_bias=self.cfg.get('tokens_head_bias', True), + hiddens_cfg=self.cfg.get('hiddens', None), + ) + return model + + +class MegatronT5SpeechLMModel(MegatronBaseSpeechLM): + """ + Model class for prompt-tuning or p-tuning a pretrained Megatron T5 model. + + Prompt Tuning initializes virtual prompt embeddings directly from a copy of + certain token embeddings from the pretrained T5 model's vocabulary + and directly tunes these embedding weights. The token embeddings used in + initialization are specified by the user in the config file. The model can + be prompt-tuned for multiple tasks at once. Virtual prompts are stored in a + prompt table and can be added or deleted without disrupting virtual prompts + for other tasks. + + P-tuning initializes an LSTM encoder model that generates virtual prompt + embeddings for every task. Each task shares the same encoder. After p-tuning + is complete, the learned virtual prompts can be saved to the prompt table + using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a + new virtual prompt via p-tuning, they do not need to retrain on all previous + tasks. This gives p-tuning the same task flexibility as prompt-tuning. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer) + self.model_type = ModelType.encoder_and_decoder + speech_codebook_size = cfg.data.get('speech_codebook_size', 1024) + num_speech_codebooks = cfg.data.get('num_speech_codebooks', 8) + speech_offset = cfg.data.get('speech_offset', 30000) + codecmodel_type = cfg.get('codecmodel_type', 'nemo_codec') + attn_prior_scaledown_start_step = cfg.get('attn_prior_scaledown_start_step', 10000) + attn_prior_end_step = cfg.get('attn_prior_end_step', 11000) + num_cross_attention_heads = cfg.get('num_cross_attention_heads', 12) + self.lm_vocab_size = cfg.get('lm_vocab_size', 30000) + self.context_pattern = cfg.data.get('context_pattern', 'parallel') + self.context_conditioning = cfg.get('context_conditioning', "decoder") + self.context_duration_min = cfg.data.get('context_duration_min', 2.9) + self.context_duration_max = cfg.data.get('context_duration_max', 2.9) + self.codebook_fps = cfg.data.get('codebook_fps', 86) + self.decoder_context_len = 0 + if self.context_conditioning == "decoder": + assert self.context_duration_min == self.context_duration_max, "Decoder context duration must be fixed" + self.decoder_context_len = int(self.codebook_fps * self.context_duration_min) + + self.speech_offset = speech_offset + self.speech_codebook_size = speech_codebook_size + self.num_speech_codebooks = num_speech_codebooks + self.codecmodel_type = codecmodel_type + self.enc_output_to_layers = cfg.get('enc_output_to_layers', None) + if self.enc_output_to_layers is not None: + # Convert from listconfig to list + self.enc_output_to_layers = [[l for l in encoder_layer] for encoder_layer in self.enc_output_to_layers] + + self.frozen_model.enc_dec_model.speech_offset = speech_offset + self.frozen_model.enc_dec_model.speech_codebook_size = speech_codebook_size + self.frozen_model.enc_dec_model.num_speech_codebooks = num_speech_codebooks + self.frozen_model.enc_dec_model.seq_pattern = cfg.get('seq_pattern', 'parallel') + self.frozen_model.enc_dec_model.attn_prior_scaledown_start_step = attn_prior_scaledown_start_step + self.frozen_model.enc_dec_model.attn_prior_end_step = attn_prior_end_step + self.frozen_model.enc_dec_model.alignment_decoder_layerids = cfg.get( + 'alignment_decoder_layerids', list(range(0, 12)) + ) + self.frozen_model.enc_dec_model.return_all_crossattention_probs = cfg.get( + 'return_all_crossattention_probs', False + ) + self.frozen_model.enc_dec_model.num_cross_attention_heads = num_cross_attention_heads + self.frozen_model.enc_dec_model.context_conditioning = self.context_conditioning + self.frozen_model.enc_dec_model.decoder_context_len = self.decoder_context_len + self.frozen_model.enc_dec_model.enc_output_to_layers = self.enc_output_to_layers + + self.alignment_loss_start_step = 0 + self.alignment_loss_end_step = float('inf') + self.use_alignment_loss = cfg.get('use_alignment_loss', False) + if self.use_alignment_loss: + alignment_loss_scale = cfg.get('alignment_loss_scale', 1.0) + self.frozen_model.enc_dec_model.use_alignment_loss = True + self.frozen_model.enc_dec_model.forward_sum_loss = ForwardSumLoss(loss_scale=alignment_loss_scale) + self.frozen_model.enc_dec_model.alignment_text_end_offset = cfg.get('alignment_text_end_offset', 0) + self.frozen_model.enc_dec_model.align_every_n_head = cfg.get('align_every_n_head', 1) + self.alignment_loss_start_step = cfg.get('alignment_loss_start_step', 0) + self.alignment_loss_end_step = cfg.get('alignment_loss_end_step', float('inf')) + + # Need to explicitly set this since it is already initialized + self.frozen_model.enc_dec_model.tokens_head.parallel_output = self.frozen_model.enc_dec_model.parallel_output + + list_of_speech_heads = [] + list_of_speech_tokens_embeddings = [] + for _ in range(self.num_speech_codebooks - 1): + # init is NOT used since we overwrite the weight below anyways + _speech_head_embedding = tensor_parallel.VocabParallelEmbedding( + speech_codebook_size, + embedding_dim=self.word_embeddings.embedding_dim, + init_method=lambda x: x.data.fill_(0), + config=self.model_parallel_config, + ) + _speech_head_embedding.weight.data.fill_(0) + _speech_head_embedding.shared = True + list_of_speech_tokens_embeddings.append(_speech_head_embedding) + # Linear layer that maps from hidden size to speech codebook size + hidden_size = self.frozen_model.enc_dec_model.decoder_cfg.hidden_size + init_method_std = self.frozen_model.enc_dec_model.decoder_cfg.init_method_std + # Changing to ColumnParallelLinear instead of Linear to support 3b Tensor Parallelism + _speech_head = tensor_parallel.ColumnParallelLinear( + input_size=hidden_size, + output_size=speech_codebook_size, + bias=True, + gather_output=not self.frozen_model.enc_dec_model.parallel_output, + init_method=init_method_normal(init_method_std), + config=self.model_parallel_config, + ) + list_of_speech_heads.append(_speech_head) + + self.frozen_model.enc_dec_model.speech_tokens_heads = torch.nn.ModuleList(list_of_speech_heads) + self.frozen_model.enc_dec_model.speech_tokens_embeddings = torch.nn.ModuleList( + list_of_speech_tokens_embeddings + ) + + self.sample_rate = 24000 + if codecmodel_type == 'nemo_codec': + codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path')) + codec_model.to('cuda') + codec_model.eval() + self.sample_rate = 22050 + else: + raise NotImplementedError() + + self.additional_models = {'codec': codec_model} + self.train_check_interval = self.cfg.get('train_check_interval', 500) + self.plot_alignments_sliced = self.cfg.get('plot_alignments_sliced', True) + app_state = AppState() + self.is_rank_zero = app_state.global_rank == 0 + self.predict_step_outputs = [] + self.phoneme_tokenizer = None + + # classifier-free guidance (CFG) option during training. The probability (0.0 <= ε <= 1.0) is used to trigger the action that the + # text or audio tokens in a batch are replaced by [UNK], such that mimicking the text- or audio-free scenario. + # If a random number is greater than ε, then keep text or audio tokens as-is, otherwise, the text or audio tokens are + # replaced by [UNK]. Default to 0.0, meaning CFG is disabled. + self.train_text_cfg_prob = cfg.get('train_text_cfg_prob', 0.0) + self.train_audio_cfg_prob = cfg.get('train_audio_cfg_prob', 0.0) + self._rng = random.Random() + + # control the strength of the classifier guidance during inference, Logits_cfg = w*Logits_cond + (1-w)*Logits_uncond, + # equivalent to Logits_cfg = Logits_cond + alpha*(Logits_cond - Logits_uncond) where alpha=w-1. + # Default w to 1.O, indicating no interpolation is applied. + self.inference_cfg_interpolation_scale = cfg.get('inference_cfg_interpolation_scale', 1.0) + self.inference_apply_text_cfg = cfg.get('inference_apply_text_cfg', False) + self.inference_apply_audio_cfg = cfg.get('inference_apply_audio_cfg', False) + if self.inference_cfg_interpolation_scale == 1.0: + self.inference_apply_text_cfg = False + self.inference_apply_audio_cfg = False + + # whether to apply cfg filter to address faster speech rate. + self.inference_apply_cfg_filter = cfg.get("inference_apply_cfg_filter", False) + + # this scale is suggested to be smaller than `self.question_guidance_scale` and it is used to balance the weights + # between the conditioned logits after applying cfg filter and the original unconditioned logits. Default to 1.0, + # indicating only conditioned logits are used. + if not self.inference_apply_cfg_filter: + self.inference_cfg_filter_interpolation_scale = None + else: + self.inference_cfg_filter_interpolation_scale = cfg.get('inference_cfg_filter_interpolation_scale', 1.0) + + # whether to estimate MOS in predict_step. + self.estimate_mos = cfg.get('estimate_mos', True) + if self.estimate_mos: + # requires to specify a non-matching high-quality and clean reference audio file. It is used to estimate MOS. + self.non_matching_ref_audio_filepath = cfg.get('non_matching_ref_audio_filepath', None) + if self.non_matching_ref_audio_filepath is None: + raise ValueError( + f"Please provide a high-quality reference audio to estimate the MOS. Alternatively, " + f"set `model.estimate_mos=False` to disable MOS estimation." + ) + if not os.path.exists(self.non_matching_ref_audio_filepath): + raise FileNotFoundError( + f"Please provide a valid file path for a high-quality reference audio to estimate" + f" the MOS. Alternatively, set `model.estimate_mos=False` to disable MOS estimation." + ) + + def decode_wav_from_codec_model(self, codes): + codec_model = self.additional_models['codec'] + if self.codecmodel_type == 'nemo_codec': + codec_len = torch.Tensor([codes.shape[1]]).long().cuda() + if codec_len < 10: + # return a one-second silence + return torch.zeros(24000).cuda() + wav, _ = codec_model.decode(tokens=codes.unsqueeze(0), tokens_len=codec_len) + wav = wav[0] + else: + raise NotImplementedError() + return wav + + def first_stage_of_pipeline(self): + if self.frozen_model.enc_dec_model.pre_process and parallel_state.get_pipeline_model_parallel_rank() == 0: + return True + return False + + def forward( + self, + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_mask, + position_ids, + taskname_ids, + labels=None, + speech_mask=None, + inference=False, + inference_step=0, + cross_attention_prior=None, + text_limits=None, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + ): + """ + Special forward method for p-tuning/prompt-tuning pretrained + T5 style models. + """ + if isinstance(context_and_question_tokens, list): + multi_encoder = True + assert isinstance(enc_mask, list) + assert isinstance(position_ids, list) + if cross_attention_prior is None: + cross_attention_prior = [None for _ in range(len(context_and_question_tokens))] + assert isinstance(cross_attention_prior, list) + assert len(context_and_question_tokens) == len(enc_mask) == len(position_ids) == len(cross_attention_prior) + else: + multi_encoder = False + context_and_question_tokens = [context_and_question_tokens] + enc_mask = [enc_mask] + position_ids = [position_ids] + cross_attention_prior = [cross_attention_prior] + + enc_output = None + logging.debug( + f"self.first_stage_of_pipeline()={self.first_stage_of_pipeline()}\tinference_step={inference_step}" + ) + if self.first_stage_of_pipeline() and inference_step == 0: + # Get embeddings for text tokens and insert virtual token embeddings + encoder_input_list = [] + for ei in range(len(context_and_question_tokens)): + input_embeds = self.get_embeddings_and_combine( + [virtual_tokens, context_and_question_tokens[ei]], taskname_ids, inference + ) + # TODO: This check needs to be revisited with PP support. + if hasattr(self.frozen_model.enc_dec_model.encoder_embedding, 'position_embeddings'): + position_embeddings = self.frozen_model.enc_dec_model.encoder_embedding.position_embeddings( + position_ids[ei] + ) + encoder_input = input_embeds + position_embeddings + else: + encoder_input = input_embeds + encoder_input_list.append(encoder_input) + else: + encoder_input_list = None + encoder_input = None + if inference_step != 0: + enc_output = context_and_question_tokens if multi_encoder else context_and_question_tokens[0] + + # If the decoder input starts with instead of , which is the case for huggingface T5 models, we don't want to mask the first token. + # For NeMo-Megatron, the sequence starts with , which is never masked so we can always set index 0 to be unmasked. + dec_mask[:, 0] = 1 + + if not self.cfg.data.get('use_attention_prior', False): + cross_attention_prior = [None for _ in range(len(cross_attention_prior))] + + _encoder_input = encoder_input_list + if not multi_encoder: + enc_mask = enc_mask[0] + cross_attention_prior = cross_attention_prior[0] + _encoder_input = encoder_input_list[0] if encoder_input_list is not None else None + + # Call forward on T5 model with preprocessed embeddings + if inference and inference_step == 0: + set_inference_key_value_memory = True + else: + set_inference_key_value_memory = False + + if self.autocast_dtype == torch.float32: + output, out_logits = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=_encoder_input, + enc_output=enc_output, + speech_mask=speech_mask, + cross_attention_prior=cross_attention_prior, + text_limits=text_limits, + global_step=self.global_step, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + else: + with torch.autocast(device_type="cuda", dtype=self.autocast_dtype): + output, out_logits = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=_encoder_input, + enc_output=enc_output, + speech_mask=speech_mask, + cross_attention_prior=cross_attention_prior, + text_limits=text_limits, + global_step=self.global_step, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + + return output, encoder_input, out_logits + + def load_frozen_model(self, cfg, trainer): + self.megatron_amp_O2 = cfg.get('megatron_amp_o2', False) + + # TODO: Fix this once apex patches FusedScaledMaskedSoftmax. + # This is a workaround for the fact that `masked_softmax_fusion` has issues with certain input sizes that may be present while finetuning. + cfg_language_model_path = cfg.get('language_model_path', None) + cfg_frozen_model = cfg.get('frozen_model', None) + if not (bool(cfg_language_model_path) ^ bool(cfg_frozen_model)): + raise ValueError( + "T5-TTS requires either 'language_model_path' or 'frozen_model' in its config, but not both." + ) + + if cfg_language_model_path: + t5_cfg = MegatronT5Model.restore_from(cfg_language_model_path, trainer=trainer, return_config=True) + else: + t5_cfg = cfg_frozen_model + + OmegaConf.set_struct(t5_cfg, True) + with open_dict(t5_cfg): + if hasattr(t5_cfg, 'encoder') and hasattr(t5_cfg, 'decoder'): + t5_cfg.encoder.masked_softmax_fusion = False + t5_cfg.decoder.masked_softmax_fusion = False + else: + t5_cfg.masked_softmax_fusion = False + t5_cfg.megatron_amp_O2 = self.megatron_amp_O2 + # hack to make the _GLOBAL_NUM_MICROBATCHES_CALCULATOR initialize + t5_cfg.micro_batch_size = cfg.get('micro_batch_size', 4) + t5_cfg.global_batch_size = cfg.get('global_batch_size', 4) + t5_cfg.precision = trainer.precision + t5_cfg.tokenizer.num_sentinel_tokens = cfg.get('num_sentinel_tokens', 39184 - 29056) + t5_cfg.seq_length = cfg.data.max_seq_length + if cfg.get('max_position_embeddings', None) is None: + t5_cfg.max_position_embeddings = cfg.data.max_seq_length + else: + t5_cfg.max_position_embeddings = cfg.get('max_position_embeddings') + t5_cfg.use_flash_attention = cfg.get('use_flash_attention', False) + if cfg.get('override_token_model', None): + t5_cfg.tokenizer.model = cfg['override_token_model'] + if cfg.get('override_tokenizer_vocab_file', None): + t5_cfg.tokenizer.vocab_file = cfg['override_tokenizer_vocab_file'] + + if cfg.get('train_from_scratch', False): + print("Training from scratch!") + # Defaults for 220m model + # To override any of these, add +model.override_= to the config file. + # Eg. +model.override_hidden_size=1024 + overide_keys = [ + 'hidden_size', # 768 + 'num_layers', # 12 + 'num_attention_heads', # 12 + 'hidden_dropout', # 0.1 + 'attention_dropout', # 0.1 + 'kv_channels', # 64 + 'ffn_hidden_size', # 2048 + ] + # Defaults for 220m model + for k in overide_keys: + if cfg.get(f'override_{k}') is not None: + t5_cfg[k] = cfg.get(f'override_{k}') + + self.frozen_model = MegatronT5OverrideModel(t5_cfg, trainer=trainer) + num_params = sum(p.numel() for p in self.frozen_model.parameters() if p.requires_grad) + print(f"Number of parameters: {num_params}") + else: + print(f"Loading from pretrained checkpoint: {cfg_language_model_path}") + if cfg_language_model_path is None: + raise ValueError( + "T5-TTS SFT on pretrained model checkpoint requires `langauge_model_path` in its config." + ) + + self.frozen_model = MegatronT5OverrideModel.restore_from( + cfg_language_model_path, + trainer=trainer, + override_config_path=t5_cfg, + save_restore_connector=NLPSaveRestoreConnector(), + ) + + if not cfg.get('english_only_model', False): + self.frozen_model.tokenizer.add_phone_tokens_to_special_tokens() + + logging.info(f"self.frozen_model {self.frozen_model}") + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + """ + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + """ + # Get seq length of batch + batch = next(dataloader_iter) + _, seq_length = batch[0].shape + if batch[4].dim() > 2: + _, _, dec_seq_length = batch[4].shape + else: + _, dec_seq_length = batch[4].shape + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only), + data_iterator=data_iter, + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=get_micro_batch_size(), + decoder_seq_length=dec_seq_length, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # we're not on the last pipeline stage so no losses + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean + + def convert_tokens_to_range(self, tokens, apply_offset_correction=True, pattern=None): + # convert tokens to range [0, 1024] + output_tokens = tokens.clone() + if apply_offset_correction: + output_tokens[0] = output_tokens[0] - self.speech_offset + output_tokens = torch.clamp(output_tokens, min=0, max=self.speech_codebook_size - 1) + if pattern is None: + pattern = self.cfg.get('seq_pattern', 'delay_parallel') + if pattern == "delay_parallel": + output_tokens_new = [] + for _c in range(output_tokens.shape[0]): + si = _c + ei = _c + output_tokens.shape[1] - self.num_speech_codebooks + output_tokens_new.append(output_tokens[_c, si:ei]) + output_tokens_new = torch.stack(output_tokens_new) + output_tokens = output_tokens_new + + return output_tokens + + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + _batch = [] + for x in batch: + if isinstance(x, torch.Tensor): + x = x.cuda(non_blocking=True) + elif isinstance(x, list): + if isinstance(x[0], torch.Tensor): + x = [y.cuda(non_blocking=True) for y in x] + _batch.append(x) + batch = _batch + # batch = [x.cuda(non_blocking=True) if isinstance(x, torch.Tensor) else x for x in batch] + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, + _, # TODO: text limit and lang not in tarred dataset + _, + ) = batch + + if self.trainer.global_step % self.train_check_interval == 0 and not validation_step and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = True + + _cross_attention_prior = cross_attention_prior + if isinstance(context_and_question_tokens, list): + # None for context and prior for question + _cross_attention_prior = [None, cross_attention_prior] + + output_tensor, encoder_input, out_logits = model( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=labels, + speech_mask=speech_mask, + cross_attention_prior=_cross_attention_prior, + text_limits=text_limits, + inference=False, + ) + output_tensor = output_tensor.contiguous() + + alignment_loss = out_logits[3] + if alignment_loss is not None: + self.logger.experiment.add_scalar('train_alignment_loss', alignment_loss, self.global_step) + + if self.trainer.global_step % self.train_check_interval == 0 and not validation_step and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = False + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if torch.count_nonzero(speech_mask) == 0: + text_labels = labels[:, 0, :] # [B, 8, T] -> [B, T] + token_logits = out_logits[0] * 1 # [T, B, V] + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + token_logits = token_logits.argmax(dim=2) # [T, B] + token_logits = token_logits.t() # [B, T] + score = 0 + for i in range(text_labels.size()[0]): + r = text_labels[i].long() + nzm = r != 0 + r = r.tolist() + h = token_logits[i].long() * nzm + h = h.tolist() + score += editdistance.eval(r, h) + score /= text_labels.size()[0] + logging.info(f"wer score : {score}") + self.logger.experiment.add_scalar('WER', score, self.global_step) + else: + audio_len = ( + self.decoder_context_len + (labels[0][0][self.decoder_context_len :] != 0).sum().item() + ) + labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) + label_wav = self.decode_wav_from_codec_model(labels_to_1024) + dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024) + self.logger.experiment.add_audio( + "train_label_wav", label_wav, self.global_step, self.sample_rate + ) + self.logger.experiment.add_audio( + "train_dec_input_wav", dec_input_wav, self.global_step, self.sample_rate + ) + if isinstance(context_and_question_tokens, list): + context_tokens = context_and_question_tokens[0] + question_tokens = context_and_question_tokens[1] + input_token_list_all = [ + question_tokens[0, 0, i].item() for i in range(question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) + for ti, t in enumerate(input_token_list_all) + if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][0].item() + _context_tokens = context_tokens[0, :, :context_end_step] + else: + input_token_list_all = [ + context_and_question_tokens[0, 0, i].item() + for i in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) + for ti, t in enumerate(input_token_list_all) + if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + _context_tokens = context_and_question_tokens[0, :, :context_end_step] + + if context_end_step > 1: + is_speech_context = _context_tokens[1, :].sum().item() > 0 + if is_speech_context: + _context_tokens = self.convert_tokens_to_range( + _context_tokens, pattern=self.context_pattern + ) + _context_wav = self.decode_wav_from_codec_model(_context_tokens) + self.logger.experiment.add_audio( + "train_context_wav", _context_wav, self.global_step, self.sample_rate + ) + else: + _context_token_list = [v.item() for v in _context_tokens[0, :]] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text( + "train_context_text", _context_text, self.global_step + ) + + question_si = text_limits[0, 0].item() - virtual_tokens.shape[1] + question_ei = text_limits[0, 1].item() - virtual_tokens.shape[1] + text_si = text_limits[0, 0].item() + text_ei = text_limits[0, 1].item() + input_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Train Input Text", input_text, self.global_step) + + input_phoneme_tokens = [ + v - self.lm_vocab_size + for v in input_token_list_all[question_si:question_ei] + if v >= self.lm_vocab_size + ] + + if len(input_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(input_phoneme_tokens) + self.logger.experiment.add_text( + "Train Input Phoneme Text", phoneme_text, self.global_step + ) + + token_logits = out_logits[0] + speech_logits_list = out_logits[1] + + attention_probs_list = out_logits[2] # list of (BS, 12, out_length, in_length) + if attention_probs_list is not None: + attention_sliced_list = [] + for lidx in range(len(attention_probs_list)): + attention_probs = attention_probs_list[lidx] + for _i in range(attention_probs.shape[1]): + name = f"Attention Probs Layer {lidx} Head {_i}" + attention_to_plot = attention_probs[0, _i, :audio_len, :text_ei] + if self.plot_alignments_sliced: + attention_to_plot = attention_probs[0, _i, 0:audio_len, text_si:text_ei] + # 4 to offset "Text to Speech this" + name += " Sliced" + alignment_image = plot_alignment_to_numpy_for_speechllm( + attention_to_plot.cpu().float().numpy().T, + phoneme_ver=0 if self.plot_alignments_sliced else 1, + phoneme_seq=None if self.plot_alignments_sliced else [text_si], + ) + self.logger.experiment.add_image( + name, + alignment_image, + self.global_step, + dataformats="HWC", + ) + attention_sliced_list.append( + attention_probs[ + 0, _i, self.decoder_context_len : audio_len, text_si:text_ei + ] + ) + attention_sliced = torch.stack(attention_sliced_list) + attention_sliced = torch.mean(attention_sliced, 0) + text = None + if len(input_text) > 0: + text = self.frozen_model.tokenizer.ids_to_tokens( + [ + v + for v in input_token_list_all[question_si:question_ei] + if v < self.lm_vocab_size + ] + ) + if len(input_phoneme_tokens) > 0: + text = phoneme_text.split("|") + alignment_image_sliced = plot_alignment_to_numpy_for_speechllm( + attention_sliced.cpu().float().numpy().T, + phoneme_seq=text, + phoneme_ver=2, + vmin=0.0, + phone_offset=0, + h_offset=False, + ) + self.logger.experiment.add_image( + f"Attention Probs Average Sliced", + alignment_image_sliced, + self.global_step, + dataformats="HWC", + ) + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + token_logits_example = token_logits[:, 0, :] * 1 + speech_logits_example = speech_logits[:, 0, :, :] * 1 + first_layer_tokens = token_logits_example.argmax(dim=1) - self.speech_offset + other_layer_tokens = [] + for _i in range(speech_logits_example.shape[2]): + other_layer_tokens.append(speech_logits_example[:, :, _i].argmax(dim=1)) + + all_layer_tokens = torch.stack([first_layer_tokens] + other_layer_tokens) # (8, t) + all_layer_tokens = self.convert_tokens_to_range( + all_layer_tokens, apply_offset_correction=False + ) + # all_layer_tokens = torch.clip(all_layer_tokens, 0, 1023) + predicted_wav = self.decode_wav_from_codec_model(all_layer_tokens) + self.logger.experiment.add_audio( + "train_tf_pred_wav", predicted_wav, self.global_step, self.sample_rate + ) + + def loss_func(loss_args): + output_tensor, out_logits, curr_step = loss_args + alignment_loss = out_logits[3] + loss = self.frozen_model.loss_func(loss_mask, output_tensor) + if ( + (alignment_loss is not None) + and (curr_step > self.alignment_loss_start_step) + and (curr_step < self.alignment_loss_end_step) + ): + logging.debug(f"Adding alignment loss. cur:{curr_step} start:{self.alignment_loss_start_step}") + loss = loss + alignment_loss + reduced_loss = average_losses_across_data_parallel_group([loss]) + return loss, {'avg': reduced_loss} + + return [output_tensor, out_logits, self.global_step], loss_func + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + """Used in inference / predict""" + + def fwd_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + _batch = [] + for x in batch: + if isinstance(x, torch.Tensor): + x = x.cuda(non_blocking=True) + elif isinstance(x, list): + if isinstance(x[0], torch.Tensor): + x = [y.cuda(non_blocking=True) for y in x] + _batch.append(x) + batch = _batch + # batch = [x.cuda(non_blocking=True) if isinstance(x, torch.Tensor) else x for x in batch] + ( + decoder_max_sequence_len, + encoder_max_sequence_len, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + speech_mask, + ) = batch + + output_logits, _, token_and_speech_logits = model( + context_and_question_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=None, + speech_mask=speech_mask, + inference=True, + inference_step=1, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + output_tensor = [output_logits, token_and_speech_logits] + + def id_func(output_tensor): + return 0, {'output_logits': output_tensor[0], 'token_and_speech_logits': output_tensor[1]} + + return output_tensor, id_func + + return fwd_output_only_func + + def backward(self, *args, **kwargs): + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. + """ + return + + def optimizer_zero_grad(self, *args, **kwargs): + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. + """ + return + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + When using pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.frozen_model.enc_dec_model.set_input_tensor(input_tensor) + + def on_train_epoch_start(self) -> None: + gbs = self.cfg.global_batch_size + mbs = self.cfg.micro_batch_size + self._reconfigure_batch_sizes(gbs, mbs) + return super().on_train_epoch_start() + + def on_validation_epoch_start(self) -> None: + gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) + mbs = self.cfg.get('validation_micro_batch_size', self.cfg.micro_batch_size) + self._reconfigure_batch_sizes(gbs, mbs) + return super().on_validation_epoch_start() + + def training_step(self, dataloader_iter, batch_idx): + self._optimizer.zero_grad() + batch = next(dataloader_iter) + + # apply text classifier-free guidance by replacing input question tokens with [UNK]. + if self.train_text_cfg_prob > 0.0: + if self._rng.random() < self.train_text_cfg_prob: + logging.info(f"Text Classifier-Free Guidance is triggered for the {batch_idx}-th batch.") + + # temporally disable computing CTC alignment loss. + if self.use_alignment_loss: + self.frozen_model.enc_dec_model.use_alignment_loss = False + + # make cross-attention prior to None to remove the prior. + batch[11] = None + + # replace question token IDs with [UNK]'s id. No speech offset for Phoneme's [UNK]. Same op as train. + # instruction token IDs are bpe token IDs directly obtained from self.tokenizer without any offset. + # question token IDs are phoneme and grapheme token IDs and are offset by self.lm_vocab_size + # if under "Phoneme TTS" instruction, so existing no overlaps between instruction and question token IDs. + # question token IDs are bpe token IDs without any offset + # if under "Text to speech this" instruction, so existing overlaps between instruction and question token IDs. + context_and_question_tokens = batch[ + 1 + ] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + text_limits = batch[12] + virtual_tokens = batch[0] + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type=multi_transformers. + context_tokens, question_tokens = context_and_question_tokens + question_tokens_unconditioned = question_tokens.clone() + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + batch[1] = [context_tokens, question_tokens_unconditioned] + else: + context_and_question_tokens_unconditioned = ( + context_and_question_tokens.clone() + ) # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + batch[1] = context_and_question_tokens_unconditioned + + del question_limits, question_start, question_end, time_range, question_mask + else: + # recover to original alignment loss config. + self.frozen_model.enc_dec_model.use_alignment_loss = self.use_alignment_loss + + # apply audio context classifier-free guidance by replacing audio codec with [UNK] + if self.train_audio_cfg_prob > 0.0: + if self._rng.random() < self.train_audio_cfg_prob: + logging.info(f"Audio Classifier-Free Guidance is triggered for the {batch_idx}-th batch.") + + context_and_question_tokens = batch[ + 1 + ] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type=multi_transformers. + context_tokens, question_tokens = context_and_question_tokens + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + batch[1] = [context_tokens_unconditioned, question_tokens] + else: + # dec_input + dec_input = batch[3] + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. + batch[3] = dec_input_unconditioned + + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=False) + self.allreduce_gradients() + + ## logging + # we can only log on one rank if it is rank zero so we broadcast from last rank + # we can avoid this broadcast by updating the PTL log function to accept specific ranks + torch.distributed.broadcast(loss_mean, get_last_rank()) + + if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"): + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, batch_size=1) + + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) + return loss_mean + + def get_predictions(self, input_ids, enc_mask, encoder_input, labels): + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=input_ids, + enc_mask=enc_mask, + num_tokens_to_generate=self.decoder_seq_length, + encoder_input=encoder_input, + bos_id=( + self.tokenizer.pad_id if self.cfg.data.get('decoder_starts_with_pad', False) else self.tokenizer.bos_id + ), + ) + # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. + preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) + labels_text = MegatronT5SFTModel.ids_to_text(labels, self.tokenizer) + input_text = MegatronT5SFTModel.ids_to_text(input_ids, self.tokenizer) + return { + 'predicted_token_ids': preds_text, + 'labels': labels_text, + 'enc_inputs': input_text, + } + + def get_embeddings(self, tokens, taskname_ids, inference=False): + out = None + if tokens.dim() > 2: + for i in range(tokens.size()[1]): # for 8 channels + if i == 0: + # Embed first layer using word embeddings + out = self.embed_input(tokens[:, i, :], taskname_ids, inference) # (B, T, D) + else: + # Embed other layers using speech embeddings + cur = self.frozen_model.enc_dec_model.speech_tokens_embeddings[i - 1](tokens[:, i, :]) + # do not add embeddings of zero tokens of other channels (except the first channel) + non_zero_flag = tokens[:, i, :] != 0 # (B, T) + cur = cur * non_zero_flag.unsqueeze(2) + out = out + cur + else: + out = self.embed_input(tokens, taskname_ids, inference) + return out + + def get_embeddings_and_combine(self, token_list, taskname_ids, inference): + embedding_list = [] + for tokens in token_list: + embedding_list.append(self.get_embeddings(tokens, taskname_ids, inference)) + return torch.cat(embedding_list, dim=1) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, + _, + _, + ) = batch + # loss_mask (b, t) + # does not use dataloader_iter due to device placement issues arising from PTL + + mode = self.training + self.eval() + gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) + self._reconfigure_and_process_inference_batch(virtual_tokens.size(0), gbs) + + loss_mean = self.fwd_bwd_step( + itertools.chain([batch]), batch_idx, forward_only=True + ) # comment this out and add custom forward function to calculate WER + # # logging.info (f'loss_mean {loss_mean}') + + if batch_idx == 0 and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = True + self.predict_step_outputs = [] + # log_scalars=False avoids logging scalar TTS metrics in the predict_step + # Images, audio and texts will still be logged + self.predict_step(batch=batch, batch_idx=batch_idx, log_scalars=False, global_step=self.global_step) + for inf_key in self.predict_step_outputs[0]: + if self.predict_step_outputs[0][inf_key] is not None: + self.logger.experiment.add_scalar( + f'Val_{inf_key}', self.predict_step_outputs[0][inf_key], self.global_step + ) + + labels_original = labels.clone() # (b, 8, t) + + _cross_attention_prior = cross_attention_prior + if isinstance(context_and_question_tokens, list): + _cross_attention_prior = [None, cross_attention_prior] + + output_loss, _, output_logits = self.forward( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=labels, + speech_mask=speech_mask, + cross_attention_prior=_cross_attention_prior, + text_limits=text_limits, + inference=False, + ) + + if batch_idx == 0 and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = False + with torch.cuda.amp.autocast(enabled=False): + if torch.count_nonzero(speech_mask) == 0: + text_labels = labels[:, 0, :] # [B, 8, T] -> [B, T] + token_logits = output_logits[0] * 1 # [T, B, V] + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + token_logits = token_logits.argmax(dim=2) # [T, B] + token_logits = token_logits.t() # [B, T] + score = 0 + for i in range(text_labels.size()[0]): + r = text_labels[i].long() + nzm = r != 0 + r = r.tolist() + h = token_logits[i].long() * nzm + h = h.tolist() + score += editdistance.eval(r, h) + score /= text_labels.size()[0] + logging.info(f"wer score : {score}") + self.logger.experiment.add_scalar('WER', score, self.global_step) + else: + audio_len = self.decoder_context_len + (labels[0][0][self.decoder_context_len :] != 0).sum().item() + labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) + label_wav = self.decode_wav_from_codec_model(labels_to_1024) + dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024) + self.logger.experiment.add_audio("val_label_wav", label_wav, self.global_step, self.sample_rate) + self.logger.experiment.add_audio( + "val_dec_input_wav", dec_input_wav, self.global_step, self.sample_rate + ) + + if isinstance(context_and_question_tokens, list): + context_tokens = context_and_question_tokens[0] + question_tokens = context_and_question_tokens[1] + input_token_list_all = [ + question_tokens[0, 0, i].item() for i in range(question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list_all) if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][0].item() + _context_tokens = context_tokens[0, :, :context_end_step] + + else: + input_token_list_all = [ + context_and_question_tokens[0, 0, i].item() + for i in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list_all) if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + _context_tokens = context_and_question_tokens[0, :, :context_end_step] + if context_end_step > 1: + is_speech_context = _context_tokens[1, :].sum().item() > 0 + if is_speech_context: + _context_tokens = self.convert_tokens_to_range( + _context_tokens, pattern=self.context_pattern + ) + _context_wav = self.decode_wav_from_codec_model(_context_tokens) + self.logger.experiment.add_audio( + "val_context_wav", _context_wav, self.global_step, self.sample_rate + ) + else: + _context_token_list = [v.item() for v in _context_tokens[0, :]] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("val_context_text", _context_text, self.global_step) + + question_si = text_limits[0, 0].item() - virtual_tokens.shape[1] + question_ei = text_limits[0, 1].item() - virtual_tokens.shape[1] + + text_si = text_limits[0, 0].item() + text_ei = text_limits[0, 1].item() + + input_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Val Input Text", input_text, self.global_step) + + input_phoneme_tokens = [ + v - self.lm_vocab_size + for v in input_token_list_all[question_si:question_ei] + if v >= self.lm_vocab_size + ] + if len(input_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(input_phoneme_tokens) + self.logger.experiment.add_text("Val Input Phoneme Text", phoneme_text, self.global_step) + + token_logits = output_logits[0] + speech_logits_list = output_logits[1] + + # if self.trainer.global_step % 500 == 0: + attention_probs_list = output_logits[2] # list of (BS, 12, out_length, in_length) + if attention_probs_list is not None: + attention_sliced_list = [] + for lidx in range(len(attention_probs_list)): + attention_probs = attention_probs_list[lidx] + for _i in range(attention_probs.shape[1]): + attention_sliced_list.append( + attention_probs[0, _i, self.decoder_context_len : audio_len, text_si:text_ei] + ) + attention_sliced = torch.stack(attention_sliced_list) + attention_sliced = torch.mean(attention_sliced, 0) + text = None + if len(input_text) > 0: + text = self.frozen_model.tokenizer.ids_to_tokens( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + if len(input_phoneme_tokens) > 0: + text = phoneme_text.split("|") + alignment_image_sliced = plot_alignment_to_numpy_for_speechllm( + attention_sliced.cpu().float().numpy().T, + phoneme_seq=text, + phoneme_ver=2, + vmin=0.0, + phone_offset=0, + h_offset=False, + ) + self.logger.experiment.add_image( + f"Val Attention Probs Average Sliced", + alignment_image_sliced, + self.global_step, + dataformats="HWC", + ) + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + token_logits_example = token_logits[:, 0, :] * 1 + speech_logits_example = speech_logits[:, 0, :, :] * 1 + first_layer_tokens = token_logits_example.argmax(dim=1) - self.speech_offset + other_layer_tokens = [] + for _i in range(speech_logits_example.shape[2]): + other_layer_tokens.append(speech_logits_example[:, :, _i].argmax(dim=1)) + + all_layer_tokens = torch.stack([first_layer_tokens] + other_layer_tokens) # (8, t) + all_layer_tokens = self.convert_tokens_to_range(all_layer_tokens, apply_offset_correction=False) + all_layer_tokens = torch.clip(all_layer_tokens, 0, self.speech_codebook_size - 1) + predicted_wav = self.decode_wav_from_codec_model(all_layer_tokens) + self.logger.experiment.add_audio( + "val_tf_pred_wav", predicted_wav, self.global_step, self.sample_rate + ) + + first_layer_logits = output_logits[0] + speech_logits_list = output_logits[1] + + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + first_layer_logits = tensor_parallel.gather_from_tensor_model_parallel_region(first_layer_logits) + if torch.count_nonzero(speech_mask) > 0: + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + first_layer_preds = first_layer_logits.argmax(dim=2) # (t,bs) + first_layer_preds = first_layer_preds.transpose(0, 1) # (bs,t) + labels_first_layer = labels_original[:, 0, :] # (bs,t) + correct_predictions = first_layer_preds == labels_first_layer # (bs,t) + correct_predictions = correct_predictions * loss_mask # (bs,t) + total_correct_predictions = torch.sum(correct_predictions) + total_predictions = torch.sum(loss_mask) + first_layer_accuracy = total_correct_predictions / total_predictions + first_layer_loss = torch.nn.functional.cross_entropy( + first_layer_logits.permute(1, 2, 0), labels_first_layer, reduction='none' + ) # (bs,t) + first_layer_loss = torch.sum(first_layer_loss * loss_mask) / total_predictions + + metrics = { + 'loss': loss_mean, + 'first_layer_accuracy': first_layer_accuracy, + 'first_layer_loss': first_layer_loss, + } + loss_total = first_layer_loss + for i in range(self.num_speech_codebooks - 1): + if torch.count_nonzero(speech_mask) > 0: + speech_logits_i = speech_logits[:, :, :, i] + speech_preds_i = speech_logits_i.argmax(dim=2) # (t,bs) + speech_preds_i = speech_preds_i.transpose(0, 1) # (bs,t) + labels_i = labels_original[:, i + 1, :] # (bs,t) + correct_predictions_i = speech_preds_i == labels_i # (bs,t) + correct_predictions_i = correct_predictions_i * loss_mask * speech_mask # (bs,t) + total_correct_predictions_i = torch.sum(correct_predictions_i) + total_predictions_i = torch.sum(loss_mask * speech_mask) + speech_accuracy_i = total_correct_predictions_i / total_predictions_i + loss_i = torch.nn.functional.cross_entropy( + speech_logits_i.permute(1, 2, 0), labels_i, reduction='none' + ) # (bs,t) + loss_i = torch.sum(loss_i * loss_mask * speech_mask) / total_predictions_i + else: + speech_accuracy_i = torch.tensor(0.0) + loss_i = torch.tensor(0.0) + metrics[f'speech_accuracy_{i+1}'] = speech_accuracy_i + metrics[f'speech_loss_{i+1}'] = loss_i + loss_total += loss_i + + metrics['loss_total_check'] = loss_total + self.validation_step_outputs.append(metrics) + self.train(mode=mode) + self.frozen_model.train() + return metrics['loss'] + + def on_validation_epoch_end(self): + outputs = self.validation_step_outputs + if self.cfg.get('pipeline_model_parallel_size', 1) > 1: + if parallel_state.is_pipeline_last_stage(): + # only the last pipeline parallel stages return loss + averaged_loss = torch.stack([item['loss'] for item in outputs]).mean() + averaged_loss_total_check = torch.stack([item['loss_total_check'] for item in outputs]).mean() + averaged_first_layer_accuracy = torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() + + self.log( + 'val_loss_total_check', + averaged_loss_total_check, + prog_bar=False, + rank_zero_only=True, + batch_size=1, + ) + self.log( + 'val_first_layer_accuracy', + averaged_first_layer_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + logging.info(f'Validation first_layer_accuracy: {averaged_first_layer_accuracy}') + logging.info(f'Validation loss_total_check: {averaged_loss_total_check}') + + for i in range(1, self.num_speech_codebooks): + averaged_speech_accuracy = torch.stack([item[f'speech_accuracy_{i}'] for item in outputs]).mean() + averaged_speech_loss = torch.stack([item[f'speech_loss_{i}'] for item in outputs]).mean() + self.log( + f'val_speech_accuracy_{i}', + averaged_speech_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.log( + f'val_speech_loss_{i}', averaged_speech_loss, prog_bar=True, rank_zero_only=True, batch_size=1 + ) + logging.info(f'Validation speech_accuracy_{i}: {averaged_speech_accuracy}') + logging.info(f'Validation speech_loss_{i}: {averaged_speech_loss}') + else: + averaged_loss = torch.tensor(0.0).cuda() + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(averaged_loss, get_last_rank()) + + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + logging.info(f'Validation loss: {averaged_loss}') + + else: + if len(outputs) > 0: + averaged_loss = torch.stack([item['loss'] for item in outputs]).mean() + averaged_loss_total_check = torch.stack([item['loss_total_check'] for item in outputs]).mean() + logging.info(f'Validation loss: {averaged_loss}') + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'val_loss_total_check', + averaged_loss_total_check, + prog_bar=False, + rank_zero_only=True, + batch_size=1, + ) + + averaged_first_layer_accuracy = torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() + logging.info(f'Validation first_layer_accuracy: {averaged_first_layer_accuracy}') + self.log( + 'val_first_layer_accuracy', + averaged_first_layer_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + + for i in range(1, self.num_speech_codebooks): + averaged_speech_accuracy = torch.stack([item[f'speech_accuracy_{i}'] for item in outputs]).mean() + averaged_speech_loss = torch.stack([item[f'speech_loss_{i}'] for item in outputs]).mean() + logging.info(f'Validation speech_accuracy_{i}: {averaged_speech_accuracy}') + logging.info(f'Validation speech_loss_{i}: {averaged_speech_loss}') + self.log( + f'val_speech_accuracy_{i}', + averaged_speech_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.log( + f'val_speech_loss_{i}', averaged_speech_loss, prog_bar=True, rank_zero_only=True, batch_size=1 + ) + + if self.cfg.get("report_validation_metric", False): + gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] + + all_preds = list(itertools.chain(*[item['predicted_token_ids'] for item in outputs])) + all_labels = list(itertools.chain(*[item['labels'] for item in outputs])) + all_inputs = list(itertools.chain(*[item['enc_inputs'] for item in outputs])) + + assert len(all_preds) == len(all_labels) + assert len(all_preds) == len(all_inputs) + + # Gather inputs, preds, labels from all workers + torch.distributed.all_gather_object( + gather_results, + [(input, pred, label) for (input, pred, label) in zip(all_inputs, all_preds, all_labels)], + group=parallel_state.get_data_parallel_group(), + ) + + # Deduplicate sentences that may have been distributed across multiple data parallel ranks. + if parallel_state.get_data_parallel_rank() == 0: + + gather_results_dedup = list(set(itertools.chain(*gather_results))) + + val_metric_dict = self.validation_metric.get_score( + [i[2] for i in gather_results_dedup], + [i[1] for i in gather_results_dedup], + ) + + for metric, val in val_metric_dict.items(): + logging.info(f'Validation {metric}: {val}') + val_metric = list(val_metric_dict.items())[0][1] + metric_name = list(val_metric_dict.items())[0][0] + else: + val_metric = torch.tensor(0.0).cuda() + metric_name = '' + + self.log(f'val_{metric_name}', val_metric, prog_bar=True, rank_zero_only=True, batch_size=1) + + gbs = self.cfg.global_batch_size + mbs = self.cfg.micro_batch_size + self._reconfigure_batch_sizes(gbs, mbs) + self.validation_step_outputs.clear() + + def test_step(self, batch, batch_idx): + result = self.predict_step(batch, batch_idx) + return result + + def on_test_epoch_end(self): + """ + This might still be broken for lightning 2.0. to fix: see + https://github.com/NVIDIA/NeMo/blob/9bdf4d12276ee8f95a340cf2f7f340e9b5b74a7e/docs/source/starthere/migration-guide.rst + """ + outputs = self.predict_step_outputs + average_metrics = {} + for output in outputs: + for key in output: + if key not in average_metrics: + average_metrics[key] = [] + if isinstance(output[key], torch.Tensor): + average_metrics[key].append(output[key].item()) + elif output[key] is None: + continue + else: + average_metrics[key].append(output[key]) + + for key in average_metrics: + average_metrics[key] = np.mean(average_metrics[key]).item() + logging.info(f'Test {key}: {average_metrics[key]}') + self.log(f'test_{key}', average_metrics[key], prog_bar=True, rank_zero_only=True, batch_size=1) + self.logger.experiment.add_scalar(f'Inf Cumulative {key}', average_metrics[key], 0) + + # save average metrics into json file + with open(os.path.join(self.logger.log_dir, 'output_metrics.json'), 'w') as f: + json.dump(average_metrics, f) + + def build_virtual_prompt_dataset( + self, dataset_paths, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory + ): + dataset = T5SpeechLMDataset( + datasets=dataset_paths, + tokenizer=self.tokenizer, + sample_rate=self.cfg.data.get('sample_rate', 24000), + virtual_prompt_source=self.virtual_prompt_source, + task_templates=self.task_templates, + pseudo_tokens=self.pseudo_tokens, + pad_token_id=self.pad_token_id, + max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings), + min_seq_length=self.cfg.data.get('min_seq_length', 1), + add_bos=self.cfg.data.get('add_bos', False), + add_eos=self.cfg.data.get('add_eos', True), + decoder_starts_with_pad=self.cfg.data.get('decoder_starts_with_pad', False), + add_eos_to_decoder_output=self.cfg.data.get('add_eos_to_decoder_output', True), + add_sentinel_to_input=self.cfg.data.get('add_sentinel_to_input', True), + ul2_prompt_token=self.cfg.data.get('ul2_prompt_token', None), + for_train=for_train, + segment_max_duration=self.cfg.data.get('segment_max_duration', None), + trim=self.cfg.data.get('trim', None), + trim_ref=self.cfg.data.get('trim_ref', None), + trim_top_db=self.cfg.data.get('trim_top_db', None), + trim_frame_length=self.cfg.data.get('trim_frame_length', None), + trim_hop_length=self.cfg.data.get('trim_hop_length', None), + pad_multiple=self.cfg.data.get('pad_multiple', 1), + pitch_augment=self.cfg.data.get('pitch_augment', None), + sup_data_path=self.cfg.data.get('sup_data_path', None), + codec_folder=self.cfg.data.get('codec_folder', None), + speech_offset=self.cfg.data.get('speech_offset', None), + train_task=self.cfg.data.get('train_task', "tts"), + seq_pattern=self.cfg.get('seq_pattern', 'delay_parallel'), + use_attention_prior=self.cfg.data.get('use_attention_prior', False), + attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0), + cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0), + lm_vocab_size=self.lm_vocab_size, + num_speech_codebooks=self.num_speech_codebooks, + codebook_fps=self.cfg.data.get('codebook_fps', 86), + add_special_tokens_to_only_first_codebook=self.cfg.data.get( + 'add_special_tokens_to_only_first_codebook', False + ), + context_pattern=self.cfg.data.get('context_pattern', 'parallel'), + context_duration_min=self.cfg.data.get('context_duration_min', 3.0), + context_duration_max=self.cfg.data.get('context_duration_max', 5.0), + g2p=self.cfg.data.get('g2p', None), + skip_datasets=self.cfg.data.get('skip_datasets', []), + english_only_model=self.cfg.get('english_only_model', False), + use_ipa=self.cfg.data.get('use_ipa', False), + context_conditioning=self.cfg.get('context_conditioning', "decoder"), + use_beta_binomial_interpolator=self.cfg.get('use_beta_binomial_interpolator', False), + context_slice_method=self.cfg.data.get('context_slice_method', 'random'), + phoneme_probability=self.cfg.data.get('phoneme_probability', 0.5), + encoder_type=self.cfg.data.get('encoder_type', 'single_transformer'), + ) + + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=self.cfg.seed + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + sampler=sampler, + batch_size=batch_size // world_size, + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=( + True if num_workers > 0 else False + ), # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + ) + logging.info(f'build success: {len(dataloader)} {dataset_paths}') + if self.phoneme_tokenizer is None: + self.phoneme_tokenizer = dataset.phoneme_tokenizer + return dataset, dataloader + + def build_virtual_prompt_tarred_dataset( + self, dataset_paths, audio_path, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory + ): + dataset = T5SpeechLMTarredDataset( + audio_tar_filepaths=audio_path, + manifest_filepath=dataset_paths, + tokenizer=self.tokenizer, + sample_rate=self.cfg.data.get('sample_rate', 24000), + virtual_prompt_source=self.virtual_prompt_source, + task_templates=self.task_templates, + pseudo_tokens=self.pseudo_tokens, + pad_token_id=self.pad_token_id, + max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings), + min_seq_length=self.cfg.data.get('min_seq_length', 1), + shuffle_n=shuffle, + add_bos=self.cfg.data.get('add_bos', False), + add_eos=self.cfg.data.get('add_eos', True), + decoder_starts_with_pad=self.cfg.data.get('decoder_starts_with_pad', False), + add_eos_to_decoder_output=self.cfg.data.get('add_eos_to_decoder_output', True), + add_sentinel_to_input=self.cfg.data.get('add_sentinel_to_input', True), + ul2_prompt_token=self.cfg.data.get('ul2_prompt_token', None), + for_train=for_train, + segment_max_duration=self.cfg.data.get('segment_max_duration', None), + trim=self.cfg.data.get('trim', None), + trim_ref=self.cfg.data.get('trim_ref', None), + trim_top_db=self.cfg.data.get('trim_top_db', None), + trim_frame_length=self.cfg.data.get('trim_frame_length', None), + trim_hop_length=self.cfg.data.get('trim_hop_length', None), + pad_multiple=self.cfg.data.get('pad_multiple', 1), + pitch_augment=self.cfg.data.get('pitch_augment', None), + speech_offset=self.cfg.data.get('speech_offset', None), + train_task=self.cfg.data.get('train_task', "tts"), + seq_pattern=self.cfg.get('seq_pattern', 'delay_parallel'), + use_attention_prior=self.cfg.data.get('use_attention_prior', False), + attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0), + cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0), + lm_vocab_size=self.lm_vocab_size, + num_speech_codebooks=self.num_speech_codebooks, + ) + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + batch_size=batch_size // world_size, + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=( + True if num_workers > 0 else False + ), # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + ) + logging.info(f'build success: {len(dataloader)} {dataset_paths}') + + return dataset, dataloader + + def process_text(self, input_text): + """ + Normalizes text for CER/WER calculation. + Taken from hallucination_eval.py + """ + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + return single_space_text + + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_scalars=True, global_step=None + ) -> Any: + + with torch.no_grad(): + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input_raw, + dec_input_mask_raw, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, # [start of question token, question token len) in [0, enc_mask.size(1)) + lang, + question_texts, + ) = batch + + batch_size = virtual_tokens.size(0) + dec_input = ( + dec_input_raw * 1 + ) # (B, 8, T) # TODO @xueyang: apply clone() method bypasses this unnecessary computation. + dec_input_mask = dec_input_mask_raw * 1 # (B, T) + dec_input_mask[:, :] = 1 # Does not really matter + output_token_list = [] + + end_indices = {} + # pad dec_input (B, 8, T) to 1000 timesteps + max_inference_timesteps = self.cfg.get('max_inference_timesteps', 2000) + # TODO @xueyang: potential bug when max_inference_timesteps < dec_input.shape[2], then dec_input is clipped. + dec_input = torch.nn.functional.pad(dec_input, (0, max_inference_timesteps - dec_input.shape[2]), value=0) + dec_input[:, :, self.decoder_context_len + 1 :].zero_() + # TODO @xueyang: why not just declare torch.ones(dec_input_raw.size(0), max_inference_timesteps)? + dec_input_mask = torch.nn.functional.pad( + dec_input_mask, (0, max_inference_timesteps - dec_input_mask.shape[1]), value=1 + ) + + if self.inference_apply_text_cfg and self.inference_apply_audio_cfg: + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type = "multi_transformers". + context_tokens, question_tokens = context_and_question_tokens + + # text + question_tokens_unconditioned = question_tokens.clone() + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # audio + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens_unconditioned), dim=0), + torch.cat((question_tokens, question_tokens_unconditioned), dim=0), + ] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + + # text + context_and_question_tokens_unconditioned = context_and_question_tokens.clone() + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # audio + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0 + ) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + + # clean up useless variables. + del question_limits, question_start, question_end, time_range, question_mask + elif self.inference_apply_text_cfg: + # replace question token IDs with [UNK]'s id. No speech offset for Phoneme's [UNK]. Same op as train. + # instruction token IDs are bpe token IDs directly obtained from self.tokenizer without any offset. + # question token IDs are phoneme and grapheme token IDs and are offset by self.lm_vocab_size + # if under "Phoneme TTS" instruction, so exising no overlaps between instruction and question token IDs. + # question token IDs are bpe token IDs without any offset + # if under "Text to speech this" instruction, so existing overlaps between instruction and question token IDs. + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type = "multi_transformers". + context_tokens, question_tokens = context_and_question_tokens + question_tokens_unconditioned = question_tokens.clone() + + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens), dim=0), + torch.cat((question_tokens, question_tokens_unconditioned), dim=0), + ] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + context_and_question_tokens_unconditioned = context_and_question_tokens.clone() + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0 + ) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + + # clean up useless variables. + del question_limits, question_start, question_end, time_range, question_mask + elif self.inference_apply_audio_cfg: + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance( + context_and_question_tokens, list + ): # indicate that self.encoder_type = "multi_transformers" + context_tokens, question_tokens = context_and_question_tokens + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens_unconditioned), dim=0), + torch.cat((question_tokens, question_tokens), dim=0), + ] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens), dim=0 + ) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + else: + logging.debug( + f"Neither text or audio cfg logits are applied:" + f" self.inference_apply_text_cfg={self.inference_apply_text_cfg}," + f" self.inference_apply_audio_cfg={self.inference_apply_audio_cfg}" + ) + + end_inference_loop_at = None + fwd_bwd_function = get_forward_backward_func() + encoder_output = None + attention_probs_all = [] + start_time = time.time() + for t in range(self.decoder_context_len + 1, dec_input.shape[2] - 1): + # Start at 0 if encoder context, else context_len + if t % 100 == 0: + logging.info("Timestep {}".format(t)) + if t == end_inference_loop_at: + print("All ends detected") + break + + if isinstance(enc_mask, list): + encoder_max_sequence_len = [e.size(1) for e in enc_mask] + else: + encoder_max_sequence_len = enc_mask.size(1) + + # if context_condition is decoder, then t starts at [PAD] token represented as [0] * 8. + # if context_condition is encoder, then t starts at [CLS]. + if t == self.decoder_context_len + 1: + # Run first step manually + output_logits, _, token_and_speech_logits = self.forward( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input[ + :, :, : t + 1 + ], # tensors representing [CLS] + context audio tokens + [PAD] if context_condition is decoder, otherwise, tensors representing [CLS]. + dec_input_mask[:, : t + 1], # doesn't matter because of all ones. + position_ids, + taskname_ids, + labels=None, + speech_mask=speech_mask, + inference=True, + inference_step=0, + decoder_max_sequence_len=max_inference_timesteps, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + encoder_output = token_and_speech_logits[-1] + + if isinstance(encoder_output, list): + encoder_output = [e.transpose(0, 1) for e in encoder_output] + else: + encoder_output = encoder_output.transpose(0, 1) + + else: + # Prepare batch + batch = [ + max_inference_timesteps, + encoder_max_sequence_len, + encoder_output, + enc_mask, + dec_input[:, :, : t + 1], + dec_input_mask[:, : t + 1], + position_ids, + taskname_ids, + speech_mask, + ] + + output_tensor = fwd_bwd_function( + forward_step_func=self.get_forward_output_only_func(), + data_iterator=iter( + [ + batch, + ] + ), + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=True, + seq_length=t, + micro_batch_size=dec_input.shape[0], + ) + output_logits = output_tensor[0]['output_logits'] # (B, T, V, 8) or (2B, T, V, 8) + token_and_speech_logits = output_tensor[0]['token_and_speech_logits'] + + # when return_all_crossattention is False, attention_probs is None. + if self.frozen_model.enc_dec_model.return_all_crossattention_probs: + attention_probs = token_and_speech_logits[2] + attention_probs_mean = torch.stack(attention_probs).mean(dim=0) # B, 12, 1, enc_timesteps + attention_probs_all.append(attention_probs_mean) + + if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: + # interpolate conditioned and unconditioned logits + token_logits = ( + self.inference_cfg_interpolation_scale * token_and_speech_logits[0][:batch_size] + + (1 - self.inference_cfg_interpolation_scale) * token_and_speech_logits[0][batch_size:] + ) + output_speech_logits = ( + self.inference_cfg_interpolation_scale * output_logits[:batch_size] + + (1 - self.inference_cfg_interpolation_scale) * output_logits[batch_size:] + ) + else: + token_logits = token_and_speech_logits[0] # (B, T, V) + output_speech_logits = output_logits + + token_logits_currtimestep = token_logits[:, -1, :] # (B, V) + token_preds = token_logits_currtimestep.argmax(dim=1) # (B,) + + if torch.count_nonzero(speech_mask) > 0: + output_logits_currtimestep = ( + output_speech_logits[:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) + ) # (B*8, V) + output_logits_currtimestep_conditioned = ( + output_logits[:batch_size][:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) + ) + output_logits_currtimestep_unconditioned = ( + output_logits[batch_size:][:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) + ) + else: + output_logits_currtimestep = token_logits_currtimestep # (B, V) + output_logits_currtimestep_conditioned = token_logits_currtimestep + output_logits_currtimestep_unconditioned = token_logits_currtimestep + + top_k = self.cfg.get('top_k', 80) + + # (B*8, 80) or (B, 80) + output_logits_currtimestep_topk = torch.topk(output_logits_currtimestep, top_k, dim=1)[0] + + # find indices which are not top k + indices_to_remove = output_logits_currtimestep < output_logits_currtimestep_topk[:, -1].unsqueeze(1) + # (B*8, 1024) or (B, 1024) + + if self.inference_apply_cfg_filter: + output_logits_currtimestep_rescored = output_logits_currtimestep_conditioned.clone() + else: + output_logits_currtimestep_rescored = output_logits_currtimestep.clone() + + output_logits_currtimestep_rescored[indices_to_remove] = -float('Inf') + + # logits interpolation between conditioned and unconditioned logits. + if ( + self.inference_apply_text_cfg or self.inference_apply_audio_cfg + ) and self.inference_apply_cfg_filter: + output_logits_currtimestep_rescored = ( + self.inference_cfg_filter_interpolation_scale * output_logits_currtimestep_rescored + + (1 - self.inference_cfg_filter_interpolation_scale) + * output_logits_currtimestep_unconditioned + ) + + temperature = self.cfg.get('temperature', 0.85) # Set temp 0.01 for greedy decoding + output_logits_currtimestep_rescored = output_logits_currtimestep_rescored / temperature + output_logits_currtimestep_rescored = torch.nn.functional.softmax( + output_logits_currtimestep_rescored, dim=1 + ) + + output_tokens_curr_timestep = torch.multinomial( + output_logits_currtimestep_rescored, num_samples=1 + ) # (B*8, 1) + + if torch.count_nonzero(speech_mask) > 0: + # Convert back to (B, 8) + output_tokens_curr_timestep = output_tokens_curr_timestep.view( + batch_size, self.num_speech_codebooks + ) + + for _b in range(token_preds.shape[0]): + if t > self.decoder_context_len + 10 and token_preds[_b] == self.tokenizer.eos_id: + if _b not in end_indices: + logging.info("End detected for item {}".format(_b) + " at timestep {}".format(t)) + end_indices[_b] = t + if len(end_indices) == token_preds.shape[0]: + end_inference_loop_at = t + self.num_speech_codebooks + + output_token_list.append(output_tokens_curr_timestep) + + # duplicate to 2b dim as input for the next iteration if enabling cfg. + if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: + output_tokens_curr_timestep = torch.cat( + (output_tokens_curr_timestep, output_tokens_curr_timestep), dim=0 + ) + + if torch.count_nonzero(speech_mask) > 0: + dec_input_next_timestep = output_tokens_curr_timestep * 1 # (B,8) + dec_input_next_timestep[:, 0] = ( + dec_input_next_timestep[:, 0] + self.speech_offset + ) # add offset to first codebook + dec_input[:, :, t + 1] = dec_input_next_timestep * 1 + else: + dec_input[:, 0, t + 1] = output_tokens_curr_timestep.squeeze(1) + + # end of for loop + output_tokens_combined = torch.stack(output_token_list) # (T, B, 8) if speech else (T, B) + if torch.count_nonzero(speech_mask) > 0: + output_tokens_combined = output_tokens_combined.permute(1, 2, 0) # (B, 8, T) + else: + output_tokens_combined = output_tokens_combined.squeeze(2) + output_tokens_combined = output_tokens_combined.permute(1, 0) # (B, T) + + # consider only autoregressive time, disconsider loading eval models for RTF time + total_process_time = time.time() - start_time + + # Layerwise token error rate + ter_dict = {} + for i in range(self.num_speech_codebooks): + ter_dict[i] = {'hypothesis': [], 'gt': []} + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if 'nemo_sv_model' not in self.additional_models: + nemo_sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + nemo_sv_model = nemo_sv_model.to(device) + nemo_sv_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + nemo_sv_model.eval() + self.additional_models['nemo_sv_model'] = nemo_sv_model + logging.info(f"Loaded SV Model: {nemo_sv_model}") + else: + nemo_sv_model = self.additional_models['nemo_sv_model'] + + if 'asr_model' not in self.additional_models: + asr_model = self.cfg.get("asr_model_name", "stt_multilingual_fastconformer_hybrid_large_pc_blend_eu") + + if "hybrid" in asr_model: + model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel + else: + model = nemo_asr.models.EncDecRNNTBPEModel + asr_model = model.from_pretrained(model_name=asr_model) + asr_model = asr_model.to(device) + asr_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + asr_model.eval() + self.additional_models['asr_model'] = asr_model + logging.info(f"Loaded ASR Model: {asr_model}") + else: + asr_model = self.additional_models['asr_model'] + + asr_model_zh = None + if Lang.zh.value in lang: + if 'asr_model_zh' not in self.additional_models: + asr_model_zh = nemo_asr.models.EncDecRNNTModel.from_pretrained( + model_name="stt_zh_conformer_transducer_large" + ) + asr_model_zh = asr_model_zh.to(device) + asr_model_zh.eval() + self.additional_models['asr_model_zh'] = asr_model_zh + else: + asr_model_zh = self.additional_models['asr_model_zh'] + + if 'wavlm_sv_model' not in self.additional_models: + wavlm_sv_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv') + wavlm_sv_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv') + wavlm_sv_model = wavlm_sv_model.to(device) + wavlm_sv_model = wavlm_sv_model.eval() + self.additional_models['wavlm_sv_model'] = wavlm_sv_model + self.additional_models['wavlm_sv_extractor'] = wavlm_sv_extractor + logging.info(f"Loaded SV Model: {wavlm_sv_model}") + else: + wavlm_sv_model = self.additional_models['wavlm_sv_model'] + wavlm_sv_extractor = self.additional_models['wavlm_sv_extractor'] + + # load MOS estimator model only if True. + if self.estimate_mos: + # load mos estimator. + if 'squim_mos_model' not in self.additional_models: + squim_mos_model_full = SQUIM_SUBJECTIVE.get_model().to(device) + self.additional_models['squim_mos_model'] = squim_mos_model_full + else: + squim_mos_model_full = self.additional_models['squim_mos_model'] + + # load non-matching reference clean audio. + ref_16khz_wav, _ = librosa.load(self.non_matching_ref_audio_filepath, sr=16000) + + # prepare MOS estimator by taking a single audio example as an input. + squim_mos_model = partial( + squim_mos_model_full, reference=torch.from_numpy(ref_16khz_wav).to(device).unsqueeze(0) + ) + + _exp_dir_path = self.logger.log_dir + _exp_dir_path = _exp_dir_path + '/Sample_Audios' + if not os.path.exists(_exp_dir_path): + os.mkdir(_exp_dir_path) + + squim_mos_list_pred = [] + squim_mos_list_context = [] + squim_mos_list_gt = [] + similarity_list = [] + similarity_list_wavlm = [] + pred_context_similarity_list = [] + pred_context_similarity_list_wavlm = [] + gt_context_similarity_list = [] + gt_context_similarity_list_wavlm = [] + question_type = [] + + # predicting audio + batch_size = output_tokens_combined.shape[0] + test_dataloader_batch_size = batch_size + # self.test_dataloader() is not defined during validation + if isinstance(self.test_dataloader(), torch.utils.data.DataLoader): + test_dataloader_batch_size = self.test_dataloader().batch_size + + # logging attention maps. + # empty attention_probs_all indicates self.frozen_model.enc_dec_model.return_all_crossattention_probs is False. + if len(attention_probs_all) != 0: + attention_probs_all = torch.cat(attention_probs_all, dim=2) # B, 12, dec_timesteps, enc_timesteps + attention_probs_all = attention_probs_all.mean(dim=1) # B, dec_timesteps, enc_timesteps + + for i in range(batch_size): + text_end_step = text_limits[i, 1].item() + text_start_step = text_limits[i, 0].item() + end_index = end_indices.get(i, output_tokens_combined.shape[2]) + if len(attention_probs_all) != 0: + attention_probs_example = attention_probs_all[i][ + : end_index - (1 + self.decoder_context_len), text_start_step:text_end_step + ] # T, enc_timesteps + attention_map = attention_probs_example.float().cpu().numpy().T + alignment_image = plot_alignment_to_numpy_for_speechllm( + attention_map, + phoneme_ver=1, + phoneme_seq=None, + ) + + if global_step is not None: + # During validation, step is simply global_step + i + step = global_step + i + else: + # During inference, step is the index of the sample + step = batch_idx * test_dataloader_batch_size + i + + self.logger.experiment.add_image( + "Inf Attention Map", + alignment_image, + step, + dataformats="HWC", + ) + # Save attention image to file + alignment_fp = os.path.join(_exp_dir_path, f'attention_map_{step}.png') + imageio.imwrite(alignment_fp, alignment_image) + + wer_score = 0 + audio_to_pred = [] + audio_to_pred_zh = [] + total_audio_seconds = 0 + for i in range(batch_size): + if global_step is not None: + # During validation, step is simply global_step + i + step = global_step + i + else: + # During inference, step is the index of the sample + step = batch_idx * test_dataloader_batch_size + i + + audio_len = self.decoder_context_len + (labels[i][0][self.decoder_context_len :] != 0).sum().item() + + if torch.count_nonzero(speech_mask) > 0: + dec_input_to_1024 = self.convert_tokens_to_range(dec_input_raw[i, :, 0:audio_len]) + dec_input_to_1024_answer = dec_input_to_1024[:, self.decoder_context_len + 1 :] + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024_answer) + self.logger.experiment.add_audio("Inf Dec Input Wav", dec_input_wav, step, self.sample_rate) + + predicted_tokens = output_tokens_combined[i] # Should not contain context even if decoder context + if i in end_indices: + logging.info(f"Clipping until end index for audio {i}") + if self.cfg.get('seq_pattern', 'parallel') == 'delay_parallel': + predicted_tokens = predicted_tokens[ + :, 0 : end_indices[i] - (1 + self.decoder_context_len) + self.num_speech_codebooks + ] # trim to audio length + else: + predicted_tokens = predicted_tokens[ + :, 0 : end_indices[i] - (1 + self.decoder_context_len) + ] # trim to audio length + + pred_img = predicted_tokens.data.cpu().float().numpy() + dec_inp_img = dec_input_to_1024.data.cpu().float().numpy() + start_time = time.time() + predicted_tokens = self.convert_tokens_to_range(predicted_tokens, apply_offset_correction=False) + predicted_wav = self.decode_wav_from_codec_model(predicted_tokens) + # accumulate audio length in seconds and process time in seconds to the RTF + total_process_time = total_process_time + (time.time() - start_time) + total_audio_seconds = total_audio_seconds + predicted_wav.size(-1) / self.sample_rate + + self.logger.experiment.add_audio("Inf Pred Wav", predicted_wav, step, self.sample_rate) + self.logger.experiment.add_image( + "Inf Pred Tokens", + plot_codec_to_numpy(pred_img), + step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "Inf Dec Input Tokens", + plot_codec_to_numpy(dec_inp_img), + step, + dataformats="HWC", + ) + + # save predicted_wav and gt_wav to a wav files in dir_path + if global_step is not None: + # During training, overwrite the wav file from the previous validation + wav_num = i + else: + wav_num = step + + audio_fp_pred = os.path.join(_exp_dir_path, f'predicted_wav_{wav_num}.wav') + sf.write(audio_fp_pred, predicted_wav.cpu().numpy(), self.sample_rate) + audio_fp_gt = os.path.join(_exp_dir_path, f'dec_input_wav_{wav_num}.wav') + sf.write(audio_fp_gt, dec_input_wav.cpu().numpy(), self.sample_rate) + + # speaker verification evaluation using nemo model + spk_embedding_pred = nemo_sv_model.get_embedding(audio_fp_pred) + spk_embedding_pred = spk_embedding_pred.cpu().detach().numpy().flatten() + spk_embedding_gt = nemo_sv_model.get_embedding(audio_fp_gt) + spk_embedding_gt = spk_embedding_gt.cpu().detach().numpy().flatten() + similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) + ) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Cossim Individual Sample', similarity, step) + similarity_list.append(similarity) + + # speaker verification evaluation using wavlm model + gt_16khz_wav, _ = librosa.load(audio_fp_gt, sr=16000) + pred_16khz_wav, _ = librosa.load(audio_fp_pred, sr=16000) + inputs_wavlm = wavlm_sv_extractor( + [pred_16khz_wav, gt_16khz_wav], padding=True, return_tensors="pt", sampling_rate=16000 + ) + for key in inputs_wavlm.keys(): + inputs_wavlm[key] = inputs_wavlm[key].to(device) + + with torch.no_grad(): + wavlm_embeddings = wavlm_sv_model(**inputs_wavlm).embeddings + wavlm_embeddings = torch.nn.functional.normalize(wavlm_embeddings, dim=-1).cpu() + + spk_embedding_pred_wavlm = wavlm_embeddings[0].cpu().detach().numpy().flatten() + spk_embedding_gt_wavlm = wavlm_embeddings[1].cpu().detach().numpy().flatten() + similarity_wavlm = np.dot(spk_embedding_pred_wavlm, spk_embedding_gt_wavlm) / ( + np.linalg.norm(spk_embedding_pred_wavlm) * np.linalg.norm(spk_embedding_gt_wavlm) + ) + similarity_list_wavlm.append(similarity_wavlm) + + if lang[i] == Lang.zh.value: + audio_to_pred_zh.append({"step": i, "audio": audio_fp_pred}) + audio_to_pred_zh.append({"step": i, "audio": audio_fp_gt}) + else: + audio_to_pred.append({"step": i, "audio": audio_fp_pred}) + audio_to_pred.append({"step": i, "audio": audio_fp_gt}) + + if isinstance(context_and_question_tokens, list): + context_tokens, question_tokens = context_and_question_tokens + input_token_list = [ + question_tokens[i, 0, j].item() + for j in range(context_and_question_tokens_lens[1][i].item()) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][i] + context_tokens = context_tokens[i][:, :context_end_step] + else: + input_token_list = [ + context_and_question_tokens[i, 0, j].item() + for j in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + context_tokens = context_and_question_tokens[i][:, :context_end_step] + + spk_embedding_context = spk_embedding_gt + spk_embedding_context_wavlm = spk_embedding_gt_wavlm + if self.decoder_context_len > 0: + context_tokens = dec_input_to_1024[:, : self.decoder_context_len + 1] + context_wav = self.decode_wav_from_codec_model(context_tokens) + elif context_end_step > 1: + is_speech_context = context_tokens[1, :].sum().item() > 0 + if is_speech_context: + context_tokens = self.convert_tokens_to_range(context_tokens, pattern=self.context_pattern) + context_wav = self.decode_wav_from_codec_model(context_tokens) + else: + context_wav = None + _context_token_list = [v.item() for v in context_tokens[0, :]] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Context Text", _context_text, self.global_step) + + else: + context_wav = None + + if context_wav is not None: + self.logger.experiment.add_audio("Context Wav", context_wav, step, self.sample_rate) + context_wav_fp = os.path.join(_exp_dir_path, f'context_wav_{wav_num}.wav') + sf.write(context_wav_fp, context_wav.cpu().numpy(), self.sample_rate) + # titanet + spk_embedding_context = nemo_sv_model.get_embedding(context_wav_fp) + spk_embedding_context = spk_embedding_context.cpu().detach().numpy().flatten() + # wavlm + context_wavlm_wav, _ = librosa.load(context_wav_fp, sr=16000) + inputs_wavlm = wavlm_sv_extractor( + [context_wavlm_wav], padding=True, return_tensors="pt", sampling_rate=16000 + ) + for key in inputs_wavlm.keys(): + inputs_wavlm[key] = inputs_wavlm[key].to(device) + + with torch.no_grad(): + wavlm_embeddings = wavlm_sv_model(**inputs_wavlm).embeddings + wavlm_embeddings = torch.nn.functional.normalize(wavlm_embeddings, dim=-1).cpu() + + spk_embedding_context_wavlm = wavlm_embeddings[0].cpu().detach().numpy().flatten() + + pred_similarity_context = np.dot(spk_embedding_context, spk_embedding_pred) / ( + np.linalg.norm(spk_embedding_context) * np.linalg.norm(spk_embedding_pred) + ) + gt_similarity_context = np.dot(spk_embedding_context, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_context) * np.linalg.norm(spk_embedding_gt) + ) + + pred_similarity_context_wavlm = np.dot(spk_embedding_context_wavlm, spk_embedding_pred_wavlm) / ( + np.linalg.norm(spk_embedding_context_wavlm) * np.linalg.norm(spk_embedding_pred_wavlm) + ) + gt_similarity_context_wavlm = np.dot(spk_embedding_context_wavlm, spk_embedding_gt_wavlm) / ( + np.linalg.norm(spk_embedding_context_wavlm) * np.linalg.norm(spk_embedding_gt_wavlm) + ) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Cossim Context Pred', pred_similarity_context, step) + self.logger.experiment.add_scalar(f'Inf SV Cossim Context GT', gt_similarity_context, step) + pred_context_similarity_list.append(pred_similarity_context) + gt_context_similarity_list.append(gt_similarity_context) + pred_context_similarity_list_wavlm.append(pred_similarity_context_wavlm) + gt_context_similarity_list_wavlm.append(gt_similarity_context_wavlm) + + task_question = self.frozen_model.tokenizer.ids_to_text( + [v[1] for v in input_token_list if v[1] < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Inf Task Question", task_question, step) + if "Phoneme TTS" in task_question: + question_type.append("Phoneme TTS") + elif "Text to speech this" in task_question: + question_type.append("Text to speech this") + else: + question_type.append("Other") + + task_question_phoneme_tokens = [ + v[1] - self.lm_vocab_size for v in input_token_list if v[1] >= self.lm_vocab_size + ] + if len(task_question_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(task_question_phoneme_tokens) + self.logger.experiment.add_text("Inf Task Question Phoneme Text", phoneme_text, step) + + # store predicted_tokens for each layer to compute token error rate + for layer_idx in range(self.num_speech_codebooks): + ter_dict[layer_idx]['hypothesis'].append(predicted_tokens[layer_idx].cpu().numpy().tolist()) + ter_dict[layer_idx]['gt'].append(dec_input_to_1024_answer[layer_idx].cpu().numpy().tolist()) + + # estimate MOS scores. + if self.estimate_mos: + squim_mos_score_pred = squim_mos_model( + torch.from_numpy(pred_16khz_wav).to(device).unsqueeze(0) + ).item() + squim_mos_score_gt = squim_mos_model( + torch.from_numpy(gt_16khz_wav).to(device).unsqueeze(0) + ).item() + if context_wav is not None: + squim_mos_score_context = squim_mos_model(context_wav.to(device).unsqueeze(0)).item() + squim_mos_list_context.append(squim_mos_score_context) + squim_mos_list_pred.append(squim_mos_score_pred) + squim_mos_list_gt.append(squim_mos_score_gt) + else: + r = labels[i, 0].long() + nzm = r != 0 + r = r.tolist()[:-1] + nzm = nzm[:-1] + h = output_tokens_combined[i].long() * nzm + h = h.tolist() + cur_wer_score = editdistance.eval(r, h) + if log_scalars: + self.logger.experiment.add_scalar('WER', cur_wer_score, step) + logging.info(f"current wer score : {cur_wer_score}") + wer_score += cur_wer_score + if wer_score > 0: + wer_score /= batch_size + if log_scalars: + self.logger.experiment.add_scalar('AVG WER', wer_score, step) + logging.info(f"average wer score : {wer_score}") + + # compute token error rate for each layer + if log_scalars: + for layer_idx in range(self.num_speech_codebooks): + wer = word_error_rate(ter_dict[layer_idx]['hypothesis'], ter_dict[layer_idx]['gt'], use_cer=True) + self.logger.experiment.add_scalar(f'Inf TER Layer {layer_idx}', wer, 0) + + greedy_transcripts = [] + if len(audio_to_pred) > 0: + greedy_transcripts.extend(asr_model.transcribe([i["audio"] for i in audio_to_pred])[0]) + if len(audio_to_pred_zh) > 0: + greedy_transcripts.extend(asr_model_zh.transcribe([i["audio"] for i in audio_to_pred_zh])[0]) + + all_audio_to_pred = audio_to_pred + audio_to_pred_zh + # Note WER over the batch is not equal to WER(sample) / batch_size, but approx. here + + # These are between ASR outputs of GT audio and predicted audio + wer_batch = [] + cer_batch = [] + cer_phoneme = [] + wer_phoneme = [] + cer_tts = [] + wer_tts = [] + + # These are between ASR output of Pred audio and GT text + wer_batch_gt = [] + cer_batch_gt = [] + cer_phoneme_gt = [] + wer_phoneme_gt = [] + cer_tts_gt = [] + wer_tts_gt = [] + + for i in range(0, len(greedy_transcripts) - 1, 2): + assert all_audio_to_pred[i]["step"] == all_audio_to_pred[i + 1]["step"] + step = batch_idx * test_dataloader_batch_size + all_audio_to_pred[i]["step"] + question_text = question_texts[i // 2] + + # No need to process text since both are ASR outputs + cer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=True) + wer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=False) + + # Processing text since one is ASR output and the other is the GT text + cer_gt = word_error_rate( + [self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=True + ) + wer_gt = word_error_rate( + [self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=False + ) + + self.logger.experiment.add_text("Inf Predicted Text", greedy_transcripts[i], step) + self.logger.experiment.add_text("Inf GT Text", greedy_transcripts[i + 1], step) + self.logger.experiment.add_text("Inf Question Text", question_text, step) + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER Transcript', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER Transcript', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT Transcript', cer_gt, step) + cer_batch.append(cer_sample) + wer_batch.append(wer_sample) + cer_batch_gt.append(cer_gt) + wer_batch_gt.append(wer_gt) + if question_type[all_audio_to_pred[i]["step"]] == "Phoneme TTS": + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER Phoneme Task', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER Phoneme Task', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT Phoneme Task', cer_gt, step) + cer_phoneme.append(cer_sample) + wer_phoneme.append(wer_sample) + cer_phoneme_gt.append(cer_gt) + wer_phoneme_gt.append(wer_gt) + elif question_type[all_audio_to_pred[i]["step"]] == "Text to speech this": + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER TTS Task', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER TTS Task', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT TTS Task', cer_gt, step) + cer_tts.append(cer_sample) + wer_tts.append(wer_sample) + cer_tts_gt.append(cer_gt) + wer_tts_gt.append(wer_gt) + + # compute average similarity + similarity_avg = np.mean(similarity_list) + pred_context_similarity_avg = np.mean(pred_context_similarity_list) + gt_context_similarity_avg = np.mean(gt_context_similarity_list) + similarity_avg_wavlm = np.mean(similarity_list_wavlm) + pred_context_similarity_avg_wavlm = np.mean(pred_context_similarity_list_wavlm) + gt_context_similarity_avg_wavlm = np.mean(gt_context_similarity_list_wavlm) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Avg Cossim', similarity_avg, batch_idx) + self.predict_step_outputs.append( + { + 'titanet_avg_cossim': similarity_avg, + 'titanet_avg_cossim_context_pred': pred_context_similarity_avg, + 'titanet_avg_cossim_context_gt': gt_context_similarity_avg, + 'wavlm_avg_cossim': similarity_avg_wavlm, + 'wavlm_avg_cossim_context_pred': pred_context_similarity_avg_wavlm, + 'wavlm_avg_cossim_context_gt': gt_context_similarity_avg_wavlm, + 'squim_mos_pred': np.mean(squim_mos_list_pred) if len(squim_mos_list_pred) > 0 else None, + 'squim_mos_context': np.mean(squim_mos_list_context) if len(squim_mos_list_context) > 0 else None, + 'squim_mos_gt': np.mean(squim_mos_list_gt) if len(squim_mos_list_gt) > 0 else None, + 'cer_transcript': np.mean(cer_batch), + 'wer_transcript': np.mean(wer_batch), + 'cer_phoneme': np.mean(cer_phoneme) if len(cer_phoneme) > 0 else None, + 'wer_phoneme': np.mean(wer_phoneme) if len(wer_phoneme) > 0 else None, + 'cer_tts': np.mean(cer_tts) if len(cer_tts) > 0 else None, + 'wer_tts': np.mean(wer_tts) if len(wer_tts) > 0 else None, + 'cer_transcript_gt': np.mean(cer_batch_gt), + 'wer_transcript_gt': np.mean(wer_batch_gt), + 'cer_phoneme_gt': np.mean(cer_phoneme_gt) if len(cer_phoneme_gt) > 0 else None, + 'wer_phoneme_gt': np.mean(wer_phoneme_gt) if len(wer_phoneme_gt) > 0 else None, + 'cer_tts_gt': np.mean(cer_tts_gt) if len(cer_tts_gt) > 0 else None, + 'wer_tts_gt': np.mean(wer_tts_gt) if len(wer_tts_gt) > 0 else None, + "RTF": total_process_time / total_audio_seconds, + } + ) + + # TODO @xueyang: PTL 2.0+ patch. Signature of method `on_predict_epoch_end` does not match signature of the base method in PTL class 'ModelHooks'. + # Remove the `outputs` param and choose `self.predict_step_output` instead. + def on_predict_epoch_end(self, outputs: List[Any]) -> None: + + gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] + all_preds = list(itertools.chain(*[item['preds_text'] for item in outputs[0]])) + all_labels = list(itertools.chain(*[item['labels_text'] for item in outputs[0]])) + all_inputs = list(itertools.chain(*[item['input_text'] for item in outputs[0]])) + + assert len(all_preds) == len(all_labels) + assert len(all_preds) == len(all_inputs) + + # Gather inputs, predictions, and ground truths from all workers + torch.distributed.all_gather_object( + gather_results, + [(input, pred, label) for (input, pred, label) in zip(all_inputs, all_preds, all_labels)], + group=parallel_state.get_data_parallel_group(), + ) + + # Deduplicate sentences that may have been distributed across multiple data parallel ranks. + if parallel_state.get_data_parallel_rank() == 0: + gather_results_dedup = list(set(itertools.chain(*gather_results))) + + input_prediction_pair = [] + correct = 0 + for input, pred, label in gather_results_dedup: + input_prediction_pair.append((input, pred)) + if label: + if pred == label: + correct += 1 + + acc = correct / len(gather_results_dedup) if all_labels[0] else None + logging.info(f'Prediction results: {acc}') + logging.info(f'Test finish') diff --git a/nemo/collections/tts/models/ssl_tts.py b/nemo/collections/tts/models/ssl_tts.py index 298a1a599008..f2cc4f798ec5 100644 --- a/nemo/collections/tts/models/ssl_tts.py +++ b/nemo/collections/tts/models/ssl_tts.py @@ -18,10 +18,10 @@ import librosa import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.utilities.combined_loader import CombinedLoader from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.utilities.combined_loader import CombinedLoader from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss from nemo.collections.tts.data.dataset import TTSDataset @@ -38,10 +38,10 @@ class SSLDisentangler(ModelPT): """ SSLDisentangler is a Conformer based model for extracting disentangled content and speaker embeddings - from an audio waveform. This model uses a pre-trained Conformer SSL model. To extract the linguistic content - and speaker representations using a pre-trained Conformer, two randomly initialized downstream - heads are added and the entire setup is finetuned in multi-task manner for speech recognition and speaker verification. - These representations can be used by FastPitchModel_SSL for voice conversion by swapping the speaker embedding + from an audio waveform. This model uses a pre-trained Conformer SSL model. To extract the linguistic content + and speaker representations using a pre-trained Conformer, two randomly initialized downstream + heads are added and the entire setup is finetuned in multi-task manner for speech recognition and speaker verification. + These representations can be used by FastPitchModel_SSL for voice conversion by swapping the speaker embedding of a given source utterance, with the speaker embedding of a target speaker. """ @@ -92,7 +92,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): librosa_mel_filter = librosa.filters.mel( sr=stft_cfg.sample_rate, n_fft=stft_cfg.n_fft, n_mels=stft_cfg.features, fmin=0, fmax=8000 ) - fb = torch.tensor(librosa_mel_filter, dtype=torch.float,).unsqueeze(0) + fb = torch.tensor( + librosa_mel_filter, + dtype=torch.float, + ).unsqueeze(0) self.register_buffer("fb", fb) @@ -212,7 +215,10 @@ def configure_optimizers(self): sched_downstream_config = optim_downstream_config.pop("sched", None) OmegaConf.set_struct(optim_downstream_config, True) - optim_backbone = instantiate(optim_backbone_config, params=self.encoder.parameters(),) + optim_backbone = instantiate( + optim_backbone_config, + params=self.encoder.parameters(), + ) optim_downstream = instantiate( optim_downstream_config, params=itertools.chain( @@ -254,7 +260,8 @@ def configure_optimizers(self): def forward(self, input_signal=None, input_signal_length=None, normalize_content=True): processed_signal, processed_signal_length = self.preprocessor_disentangler( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) # b,c,t @@ -292,7 +299,9 @@ def forward_for_export(self, input_signal=None, input_signal_length=None, normal # Same as forward right now. Earlier version of encoder had a different forward for export. # This function is still kept for compatibility with older evaluation/inference scripts. return self.forward( - input_signal=input_signal, input_signal_length=input_signal_length, normalize_content=normalize_content, + input_signal=input_signal, + input_signal_length=input_signal_length, + normalize_content=normalize_content, ) def training_step(self, batch, batch_idx): diff --git a/nemo/collections/tts/models/tacotron2.py b/nemo/collections/tts/models/tacotron2.py index 2fb005d80ca6..33d476029011 100644 --- a/nemo/collections/tts/models/tacotron2.py +++ b/nemo/collections/tts/models/tacotron2.py @@ -18,9 +18,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from omegaconf import MISSING, DictConfig, OmegaConf, open_dict from omegaconf.errors import ConfigAttributeError -from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from torch import nn from nemo.collections.common.parts.preprocessing import parsers diff --git a/nemo/collections/tts/models/univnet.py b/nemo/collections/tts/models/univnet.py index 64ee891b0754..12500be8d180 100644 --- a/nemo/collections/tts/models/univnet.py +++ b/nemo/collections/tts/models/univnet.py @@ -18,8 +18,8 @@ import torch import torch.nn.functional as F from hydra.utils import instantiate +from lightning.pytorch.loggers.wandb import WandbLogger from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.loggers.wandb import WandbLogger from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, GeneratorLoss from nemo.collections.tts.losses.stftlosses import MultiResolutionSTFTLoss @@ -114,8 +114,14 @@ def configure_optimizers(self): if sched_config is None and 'sched' in self._cfg: sched_config = self._cfg.sched - optim_g = instantiate(optim_config, params=self.generator.parameters(),) - optim_d = instantiate(optim_config, params=itertools.chain(self.mrd.parameters(), self.mpd.parameters()),) + optim_g = instantiate( + optim_config, + params=self.generator.parameters(), + ) + optim_d = instantiate( + optim_config, + params=itertools.chain(self.mrd.parameters(), self.mpd.parameters()), + ) if sched_config is not None: max_steps = self._cfg.get("max_steps", None) @@ -290,7 +296,7 @@ def stft(x): comp = torch.stft(x.squeeze(1), n_fft=1024, hop_length=256, win_length=1024, return_complex=True) comp = torch.view_as_real(comp) real, imag = comp[..., 0], comp[..., 1] - mags = torch.sqrt(real ** 2 + imag ** 2) + mags = torch.sqrt(real**2 + imag**2) phase = torch.atan2(imag, real) return mags, phase diff --git a/nemo/collections/tts/models/vits.py b/nemo/collections/tts/models/vits.py index 4a891fa8823e..3c53442a0863 100644 --- a/nemo/collections/tts/models/vits.py +++ b/nemo/collections/tts/models/vits.py @@ -18,9 +18,9 @@ import omegaconf import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import WandbLogger from torch.cuda.amp import autocast from torch.nn import functional as F diff --git a/nemo/collections/tts/models/waveglow.py b/nemo/collections/tts/models/waveglow.py index 728b5b94b084..04eec734b26e 100644 --- a/nemo/collections/tts/models/waveglow.py +++ b/nemo/collections/tts/models/waveglow.py @@ -15,8 +15,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig, open_dict -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.tts.losses.waveglowloss import WaveGlowLoss from nemo.collections.tts.models.base import GlowVocoder diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index c4ec09031cf9..1856dee0ce0f 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -23,10 +23,10 @@ import soundfile as sf import torch from einops import rearrange -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.loggers.logger import Logger -from pytorch_lightning.loggers.wandb import WandbLogger +from lightning.pytorch import Callback, LightningModule, Trainer +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers.logger import Logger +from lightning.pytorch.loggers.wandb import WandbLogger from torch import Tensor from nemo.collections.tts.parts.utils.helpers import create_plot @@ -194,7 +194,10 @@ def _log_audio(self, audio: AudioArtifact, log_dir: Path, step: int): if self.tensorboard_logger: self.tensorboard_logger.add_audio( - tag=audio.id, snd_tensor=audio.data, global_step=step, sample_rate=audio.sample_rate, + tag=audio.id, + snd_tensor=audio.data, + global_step=step, + sample_rate=audio.sample_rate, ) if self.wandb_logger: @@ -212,7 +215,10 @@ def _log_image(self, image: ImageArtifact, log_dir: Path, step: int): if self.tensorboard_logger: self.tensorboard_logger.add_image( - tag=image.id, img_tensor=image_plot, global_step=step, dataformats="HWC", + tag=image.id, + img_tensor=image_plot, + global_step=step, + dataformats="HWC", ) if self.wandb_logger: @@ -220,8 +226,7 @@ def _log_image(self, image: ImageArtifact, log_dir: Path, step: int): self.wandb_logger.log({image.id: wandb_image}) def _log_artifacts(self, audio_list: list, image_list: list, log_dir: Optional[Path] = None, global_step: int = 0): - """Log audio and image artifacts. - """ + """Log audio and image artifacts.""" if log_dir is not None: log_dir.mkdir(parents=True, exist_ok=True) @@ -232,8 +237,7 @@ def _log_artifacts(self, audio_list: list, image_list: list, log_dir: Optional[P self._log_image(image=image, log_dir=log_dir, step=global_step) def on_fit_start(self, trainer: Trainer, model: LightningModule): - """Log initial data artifacts. - """ + """Log initial data artifacts.""" audio_list = [] image_list = [] for batch_dict in self.data_loader: @@ -255,8 +259,7 @@ def on_fit_start(self, trainer: Trainer, model: LightningModule): self._log_artifacts(audio_list=audio_list, image_list=image_list, log_dir=log_dir) def on_train_epoch_end(self, trainer: Trainer, model: LightningModule): - """Log artifacts at the end of an epoch. - """ + """Log artifacts at the end of an epoch.""" epoch = 1 + model.current_epoch if (epoch not in self.log_epochs) and (epoch % self.epoch_frequency != 0): return @@ -306,7 +309,10 @@ def generate_artifacts( audio_gt_path = Path(f"{dataset_name}/{audio_id}_gt.wav") audio_gt_i = audio[i, : audio_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_gt_{audio_id}", data=audio_gt_i, filepath=audio_gt_path, sample_rate=model.sample_rate, + id=f"audio_gt_{audio_id}", + data=audio_gt_i, + filepath=audio_gt_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) return audio_artifacts, [] @@ -321,7 +327,10 @@ def generate_artifacts( audio_pred_path = Path(f"{dataset_name}/{audio_id}.wav") audio_pred_i = audio_pred[i, : audio_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_{audio_id}", data=audio_pred_i, filepath=audio_pred_path, sample_rate=model.sample_rate, + id=f"audio_{audio_id}", + data=audio_pred_i, + filepath=audio_pred_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) @@ -378,7 +387,10 @@ def _generate_audio( audio_pred_path = Path(f"{dataset_name}/{audio_id}_audio_out.wav") audio_pred_i = audio_pred[i, : audio_pred_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_out_{audio_id}", data=audio_pred_i, filepath=audio_pred_path, sample_rate=model.sample_rate, + id=f"audio_out_{audio_id}", + data=audio_pred_i, + filepath=audio_pred_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) @@ -388,7 +400,10 @@ def _generate_audio( audio_in_path = Path(f"{dataset_name}/{audio_id}_audio_in.wav") audio_in_i = audio[i, : audio_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_in_{audio_id}", data=audio_in_i, filepath=audio_in_path, sample_rate=model.sample_rate, + id=f"audio_in_{audio_id}", + data=audio_in_i, + filepath=audio_in_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) @@ -538,7 +553,11 @@ def _create_ground_truth_artifacts( spec_gt_path = Path(f"{dataset_name}/{audio_id}_spec_gt.png") spec_gt_i = spec[i, :, : spec_len[i]].cpu().numpy() spec_artifact = ImageArtifact( - id=f"spec_{audio_id}", data=spec_gt_i, filepath=spec_gt_path, x_axis="Audio Frames", y_axis="Channels", + id=f"spec_{audio_id}", + data=spec_gt_i, + filepath=spec_gt_path, + x_axis="Audio Frames", + y_axis="Channels", ) image_artifacts.append(spec_artifact) @@ -565,14 +584,22 @@ def _generate_predictions( with torch.no_grad(): # [B, C, T_spec] - mels_pred, mels_pred_len, *_ = model.forward(text=text, input_lens=text_lens, speaker=speaker,) + mels_pred, mels_pred_len, *_ = model.forward( + text=text, + input_lens=text_lens, + speaker=speaker, + ) if self.log_spectrogram: for i, (dataset_name, audio_id) in enumerate(zip(dataset_names, audio_ids)): spec_path = Path(f"{dataset_name}/{audio_id}_spec.png") spec_i = mels_pred[i, :, : mels_pred_len[i]].cpu().numpy() spec_artifact = ImageArtifact( - id=f"spec_{audio_id}", data=spec_i, filepath=spec_path, x_axis="Audio Frames", y_axis="Channels", + id=f"spec_{audio_id}", + data=spec_i, + filepath=spec_path, + x_axis="Audio Frames", + y_axis="Channels", ) image_artifacts.append(spec_artifact) diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index a4c65f9ed0e5..28be259502c5 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -48,8 +48,8 @@ import librosa import matplotlib.pylab as plt import numpy as np +import seaborn as sns import torch -from einops import rearrange from numba import jit, prange from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens @@ -63,7 +63,7 @@ HAVE_WANDB = False try: - from pytorch_lightning.utilities import rank_zero_only + from lightning.pytorch.utilities import rank_zero_only except ModuleNotFoundError: from functools import wraps @@ -468,6 +468,74 @@ def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vm return data +def plot_alignment_to_numpy_for_speechllm( + alignment, + title='', + info=None, + phoneme_seq=None, + vmin=None, + vmax=None, + phoneme_ver=0, + phone_offset=2, + h_offset=True, +): + alignment = np.clip(alignment, a_min=0, a_max=None) + fig, ax = plt.subplots(figsize=(8, 6)) + im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax) + ax.set_title(title) + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + + if phoneme_seq is not None: + if phoneme_ver == 0: + # for debugging of phonemes and durs in maps. Not used by def in training code + ax.set_yticks(np.arange(len(phoneme_seq))) + ax.set_yticklabels(phoneme_seq) + ax.hlines(np.arange(len(phoneme_seq)), xmin=0.0, xmax=max(ax.get_xticks())) + elif phoneme_ver == 1: + yticks = ax.get_yticks() + new_yticks = [] + for tick in yticks: + if tick < 0 or tick > alignment.shape[0]: + continue + new_yticks.append(tick) + new_yticks += phoneme_seq + ax.set_yticks(new_yticks) + elif phoneme_ver == 2: + phones = phoneme_seq[phone_offset:] + ax.set_yticks(np.arange(len(phones))) + ax.set_yticklabels(phones) + ax.hlines(np.arange(0.5, len(phones) - 0.5, 1.0), xmin=0.0, xmax=alignment.shape[1] - 0.5, colors="black") + + if h_offset: + xticks = ax.get_xticks() + new_xticks = [] + for tick in xticks: + new_xticks.append(f"{tick+phoneme_seq[1]:.0f}") + ax.set_xticklabels(new_xticks) + + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def plot_codec_to_numpy(codes, title=''): + fig, ax = plt.subplots(figsize=(10, 3)) + sns.heatmap(codes, ax=ax) + + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + def plot_pitch_to_numpy(pitch, ylim_range=None): fig, ax = plt.subplots(figsize=(12, 3)) plt.plot(pitch) diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py index 5f1185c2c399..96806f633a54 100644 --- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py +++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py @@ -67,8 +67,7 @@ def get_audio_filepaths(manifest_entry: Dict[str, Any], audio_dir: Path) -> Tupl def normalize_volume(audio: np.array, volume_level: float = 0.95) -> np.array: - """Apply peak normalization to the input audio. - """ + """Apply peak normalization to the input audio.""" if not (0.0 <= volume_level <= 1.0): raise ValueError(f"Volume must be in range [0.0, 1.0], received {volume_level}") @@ -88,10 +87,11 @@ class BetaBinomialInterpolator: The implementation is taken from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py """ - def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500): + def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500, scaling_factor: float = 1.0): self.round_mel_len_to = round_mel_len_to self.round_text_len_to = round_text_len_to - self.bank = functools.lru_cache(maxsize=cache_size)(beta_binomial_prior_distribution) + cached_func = lambda x, y: beta_binomial_prior_distribution(x, y, scaling_factor=scaling_factor) + self.bank = functools.lru_cache(maxsize=cache_size)(cached_func) @staticmethod def round(val, to): @@ -315,7 +315,11 @@ def load_audio( def sample_audio( - manifest_entry: Dict[str, Any], audio_dir: Path, sample_rate: int, n_samples: int, volume_norm: bool = False, + manifest_entry: Dict[str, Any], + audio_dir: Path, + sample_rate: int, + n_samples: int, + volume_norm: bool = False, ) -> Tuple[np.ndarray, Path, Path]: """ Randomly sample an audio segment from a manifest entry. diff --git a/nemo/collections/vision/models/megatron_vit_classification_models.py b/nemo/collections/vision/models/megatron_vit_classification_models.py index 5cffdd6d12a3..c4024a5a47a7 100644 --- a/nemo/collections/vision/models/megatron_vit_classification_models.py +++ b/nemo/collections/vision/models/megatron_vit_classification_models.py @@ -17,9 +17,9 @@ from typing import Any, Optional import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel diff --git a/nemo/collections/vlm/mllama/data/lazy.py b/nemo/collections/vlm/mllama/data/lazy.py index 30b8b2ea9d9c..5069f8593377 100644 --- a/nemo/collections/vlm/mllama/data/lazy.py +++ b/nemo/collections/vlm/mllama/data/lazy.py @@ -18,10 +18,10 @@ import re from typing import Any, Dict, List, Optional, Sequence -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn.functional as F -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, default_collate diff --git a/nemo/collections/vlm/mllama/data/mock.py b/nemo/collections/vlm/mllama/data/mock.py index bb3afe83ea46..a88838b0025f 100644 --- a/nemo/collections/vlm/mllama/data/mock.py +++ b/nemo/collections/vlm/mllama/data/mock.py @@ -14,10 +14,10 @@ from typing import Dict, List, Optional, Tuple +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset diff --git a/nemo/collections/vlm/mllama/model/base.py b/nemo/collections/vlm/mllama/model/base.py index f03af078987d..7dd84fefbb18 100644 --- a/nemo/collections/vlm/mllama/model/base.py +++ b/nemo/collections/vlm/mllama/model/base.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed from einops import rearrange diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py index f662546d21ae..f023cc7bf943 100644 --- a/nemo/collections/vlm/mllama/model/vision.py +++ b/nemo/collections/vlm/mllama/model/vision.py @@ -59,6 +59,9 @@ def to_2tuple(x): + """ + Convert an input to a 2-tuple. + """ if isinstance(x, collections.abc.Iterable): return x return (x, x) @@ -71,9 +74,16 @@ def _stack_images( max_num_images: int, ) -> Tuple[torch.Tensor, List[int]]: """ - Takes a list of list of images and stacks them into a tensor. - This function is needed since images can be of completely - different resolutions and aspect ratios. + Stack a list of image lists into a tensor while accounting for varying resolutions and aspect ratios. + + Args: + images (List[List[PIL_Image.Image]]): List of image lists for stacking. + max_num_chunks (int): Maximum number of chunks per image. + image_res (int): Target resolution for each image. + max_num_images (int): Maximum number of images to stack. + + Returns: + Tuple[torch.Tensor, List[int]]: Tensor of stacked images and a list of chunk counts for each image. """ out_images, out_num_chunks = [], [] for imgs_sample in images: @@ -97,7 +107,17 @@ def build_encoder_attention_mask( x: torch.Tensor, ar_ids: torch.Tensor, ntok: int, num_chunks: int, supported_aspect_ratios: List[List[int]] ): """ - Build vision encoder attention mask that omits padding tiles and tokens. + Build attention masks for a vision encoder to handle padding and token alignment. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, sequence_length). + ar_ids (torch.Tensor): Aspect ratio IDs for masking. + ntok (int): Number of tokens. + num_chunks (int): Number of chunks in the data. + supported_aspect_ratios (List[List[int]]): List of supported aspect ratios. + + Returns: + torch.Tensor: Tensor containing the attention mask. """ masks = [] for ar_id in ar_ids: @@ -113,6 +133,9 @@ def build_encoder_attention_mask( def apply_scaling(freqs: torch.Tensor): + """ + Scale frequency values based on predefined thresholds and a smoothing factor. + """ # Values obtained from grid search scale_factor = 8 low_freq_factor = 1 @@ -137,6 +160,9 @@ def apply_scaling(freqs: torch.Tensor): # Use this spec for an implementation using modules in TE def get_image_transformer_layer_spec() -> ModuleSpec: + """ + Create a specification for an image transformer layer. + """ image_transformer_submodules = TransformerLayerSubmodules( input_layernorm=TENorm, self_attention=ModuleSpec( @@ -175,6 +201,10 @@ def forward_with_return_intermediate( packed_seq_params: PackedSeqParams = None, return_intermediate: List[int] = None, ): + """ + Perform a forward pass through the transformer layers with optional intermediate outputs. + Override regular MCore transformer layer forward pass. + """ # hidden_states (float): [s, b, h] # attention_mask (bool): [1, 1, s, s] @@ -278,16 +308,22 @@ def forward_with_return_intermediate( class ColumnParallelConv2dPatch(MegatronModule): - """Conv2D Patching layer with model parallelism. - Column parallel over unfolded input. - Arguments: - in_channels: Input channels. - out_channels: Output channels. - kernel_size: Size of convolution kernel. - stride (default 1): Stride for convolution. - bias (default False): Use bias in Conv2d. - Input: (bsz, in_channels, width, height) - Output: (bsz, num_tokens, out_channels) + """ + Conv2D Patching layer with model parallelism. Applies convolution in a column-parallel fashion. + + Args: + config (TransformerConfig): Configuration object for the layer. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (Union[int, Tuple[int, int]]): Size of the convolution kernel. + stride (Union[int, Tuple[int, int]]): Stride of the convolution. + bias (Optional[bool], default=False): Whether to include a bias term. + + Input: + torch.Tensor: Input tensor of shape (batch_size, in_channels, width, height). + + Output: + torch.Tensor: Output tensor of shape (batch_size, num_tokens, out_channels). """ def __init__( @@ -316,6 +352,7 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward.""" x = self._unfold(x) x = x.permute(0, 2, 1) x = F.linear(x, self._linear.weight) @@ -324,6 +361,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PrecomputedTilePositionEmbedding(torch.nn.Module): + """ + Module to compute positional embeddings for tiles with optional gating. + + Args: + config (TransformerConfig): Configuration object. + gated (bool, default=False): Whether to apply gating to the embeddings. + """ + def __init__( self, config: TransformerConfig, @@ -340,6 +385,7 @@ def __init__( self.gate = nn.Parameter(torch.zeros(1)) def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """Forward.""" embeddings = self.embedding(aspect_ratio_ids) embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) @@ -351,7 +397,15 @@ def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) - class SelfAttentionNoBias(SelfAttention): - """Self-attention layer class without bias""" + """ + Self-attention layer implementation without bias. + + Args: + config (TransformerConfig): Configuration for the transformer. + submodules (SelfAttentionSubmodules): Submodules required for self-attention. + layer_number (int): The layer number in the transformer stack. + attn_mask_type (AttnMaskType): Type of attention mask to apply. + """ def __init__( self, @@ -396,6 +450,16 @@ def __init__( class ImageTransformerLayer(TransformerLayer): + """ + Transformer layer adapted for processing image data with optional gating. + + Args: + config (TransformerConfig): Transformer configuration object. + submodules (TransformerLayerSubmodules): Submodules to use in the layer. + layer_number (int, default=1): Layer number in the transformer. + hidden_dropout (float, optional): Dropout rate for hidden layers. + """ + def __init__( self, config: TransformerConfig, @@ -423,9 +487,11 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, ): + """Forward.""" # hidden_states: [s, b, h] # Residual connection. @@ -485,6 +551,19 @@ def forward( class VisionEncoder(MegatronModule): + """ + Vision encoder module for processing image inputs with patch-based embeddings. + + Args: + config ('CrossAttentionVisionConfig'): Configuration object for the encoder. + image_size (int, default=560): Input image size. + patch_size (int, default=14): Size of patches extracted from the image. + in_channels (int, default=3): Number of input channels. + pre_process (bool, default=True): Whether to preprocess input. + post_process (bool, default=True): Whether to postprocess output. + return_intermediate (Optional[bool]): Whether to return intermediate layers. + """ + def __init__( self, config: 'CrossAttentionVisionConfig', @@ -556,7 +635,7 @@ def __init__( self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1)) def apply_positional_embedding(self, x, aspect_ratio_ids): - # apply regular position embedding + """Apply regular position embedding and tile positonal embedding.""" bsz, num_chunks, num_tokens, dim = x.shape x = x.view(bsz * num_chunks, num_tokens, dim) x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh()) @@ -567,6 +646,7 @@ def apply_positional_embedding(self, x, aspect_ratio_ids): return x def apply_class_embedding(self, x): + """Concat class embedding tokens.""" x = torch.cat( [ self.class_embedding.to(x.dtype) @@ -578,6 +658,7 @@ def apply_class_embedding(self, x): return x def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor: + """Forward.""" if images.ndim == 5: num_concurrent_media = 1 bsz, num_chunks, nch, w, h = images.shape @@ -617,7 +698,8 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor: return_intermediate=self.return_intermediate, ) - # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size] -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size] + # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size] + # -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size] x, int_x = x.transpose(0, 1).contiguous(), int_x.transpose(0, 1).contiguous() x = self.ln_post(x) x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) diff --git a/nemo/collections/vlm/neva/data/api.py b/nemo/collections/vlm/neva/data/api.py index c2e51e033d8a..15ba45c82fd9 100644 --- a/nemo/collections/vlm/neva/data/api.py +++ b/nemo/collections/vlm/neva/data/api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.vlm.neva.data.lazy import NevaLazyDataModule from nemo.collections.vlm.neva.data.mock import MockDataModule diff --git a/nemo/collections/vlm/neva/data/lazy.py b/nemo/collections/vlm/neva/data/lazy.py index 57aa5b408835..fddaca14faeb 100644 --- a/nemo/collections/vlm/neva/data/lazy.py +++ b/nemo/collections/vlm/neva/data/lazy.py @@ -20,12 +20,12 @@ from typing import Any, Dict, List, Optional, Sequence import decord +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch import torch.nn.functional as F +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from PIL import Image -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset, default_collate from transformers import CLIPImageProcessor, SiglipImageProcessor diff --git a/nemo/collections/vlm/neva/data/mock.py b/nemo/collections/vlm/neva/data/mock.py index ac4bc56a068c..ede06e9f5778 100644 --- a/nemo/collections/vlm/neva/data/mock.py +++ b/nemo/collections/vlm/neva/data/mock.py @@ -14,10 +14,10 @@ from typing import Dict, List, Optional +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset diff --git a/nemo/collections/vlm/neva/model/api.py b/nemo/collections/vlm/neva/model/api.py index 62374d536712..19e94c70381e 100644 --- a/nemo/collections/vlm/neva/model/api.py +++ b/nemo/collections/vlm/neva/model/api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.vlm.neva.model import Llava1_5Config7B, Llava1_5Config13B, LlavaModel diff --git a/nemo/collections/vlm/neva/model/base.py b/nemo/collections/vlm/neva/model/base.py index 260b7e7e0f4a..d4e578218ed2 100644 --- a/nemo/collections/vlm/neva/model/base.py +++ b/nemo/collections/vlm/neva/model/base.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Union -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed import torch.nn.functional as F diff --git a/nemo/collections/vlm/recipes/mllama_11b.py b/nemo/collections/vlm/recipes/mllama_11b.py index 697be9990faf..e4842ae63d52 100644 --- a/nemo/collections/vlm/recipes/mllama_11b.py +++ b/nemo/collections/vlm/recipes/mllama_11b.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl diff --git a/nemo/collections/vlm/recipes/mllama_90b.py b/nemo/collections/vlm/recipes/mllama_90b.py index 8822aa9b189f..28a6ff7ff9a6 100644 --- a/nemo/collections/vlm/recipes/mllama_90b.py +++ b/nemo/collections/vlm/recipes/mllama_90b.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl diff --git a/nemo/core/classes/__init__.py b/nemo/core/classes/__init__.py index 3a6db2602648..e773972c6d7b 100644 --- a/nemo/core/classes/__init__.py +++ b/nemo/core/classes/__init__.py @@ -14,8 +14,8 @@ import hydra +import lightning.pytorch import omegaconf -import pytorch_lightning from nemo.core.classes.common import ( FileIO, diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index aab09d42d907..ba284e7c28cd 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -15,7 +15,7 @@ from typing import Dict, List, Optional, Union import torch -from pytorch_lightning.core.module import _jit_is_scripting +from lightning.pytorch.core.module import _jit_is_scripting from nemo.core.classes import typecheck from nemo.core.neural_types import NeuralType diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index a15f769e9d88..308f5cbb8bee 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -35,9 +35,9 @@ HAVE_MEGATRON_CORE = False +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.utilities import model_summary, rank_zero_only from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities import model_summary, rank_zero_only from nemo import package_info from nemo.core import optim @@ -79,7 +79,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): """ if trainer is not None and not isinstance(trainer, Trainer): raise ValueError( - f"trainer constructor argument must be either None or pytorch_lightning.Trainer. But got {type(trainer)} instead." + f"trainer constructor argument must be either None or lightning.pytorch.Trainer. But got {type(trainer)} instead." ) super().__init__() diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index cd9971a9c383..2c4c826d1daf 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -23,9 +23,9 @@ from typing import Callable, Generator, Optional, Set, Union import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, OmegaConf from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.core import classes as nemo_classes # to avoid circular import do not import ModelPT directly from nemo.utils import logging, model_utils diff --git a/nemo/core/utils/k2_guard.py b/nemo/core/utils/k2_guard.py index a9f64ce39c6b..b0e86d319ec0 100644 --- a/nemo/core/utils/k2_guard.py +++ b/nemo/core/utils/k2_guard.py @@ -21,8 +21,9 @@ import textwrap +from lightning.pytorch.utilities.imports import package_available from packaging.version import Version -from pytorch_lightning.utilities.imports import package_available + from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE __K2_MINIMUM_MAJOR_VERSION = 1 diff --git a/nemo/deploy/deploy_base.py b/nemo/deploy/deploy_base.py index 63746199bac6..41e0e7ddbdc9 100644 --- a/nemo/deploy/deploy_base.py +++ b/nemo/deploy/deploy_base.py @@ -18,7 +18,7 @@ use_pytorch_lightning = True try: - from pytorch_lightning import Trainer + from lightning.pytorch import Trainer except Exception: use_pytorch_lightning = False diff --git a/nemo/deploy/nlp/megatronllm_deployable.py b/nemo/deploy/nlp/megatronllm_deployable.py index 64cf6114ceba..0ce5991cdc95 100644 --- a/nemo/deploy/nlp/megatronllm_deployable.py +++ b/nemo/deploy/nlp/megatronllm_deployable.py @@ -20,7 +20,7 @@ import numpy as np import torch import wrapt -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.text_generation_utils import ( diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py index 7e873db6b5b1..e1d21bb54b76 100644 --- a/nemo/deploy/nlp/query_llm.py +++ b/nemo/deploy/nlp/query_llm.py @@ -174,6 +174,7 @@ def query_llm( end_strings=None, init_timeout=60.0, openai_format_response: bool = False, + output_generation_logits: bool = False, ): """ Query the Triton server synchronously and return a list of responses. @@ -190,6 +191,8 @@ def query_llm( no_repeat_ngram_size (int): no repeat ngram size. task_id (str): downstream task id if virtual tokens are used. init_timeout (flat): timeout for the connection. + openai_format_response: return response similar to OpenAI API format + output_generation_logits: return generation logits from model on PyTriton """ prompts = str_list2numpy(prompts) @@ -248,6 +251,9 @@ def query_llm( if end_strings is not None: inputs["end_strings"] = str_list2numpy(end_strings) + if output_generation_logits is not None: + inputs["output_generation_logits"] = np.full(prompts.shape, output_generation_logits, dtype=np.bool_) + with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client: result_dict = client.infer_batch(**inputs) output_type = client.model_config.outputs[0].dtype @@ -269,6 +275,9 @@ def query_llm( "model": self.model_name, "choices": [{"text": str(sentences)}], } + # Convert gneration logits to a list to make it json serializable and add it to openai_response dict + if output_generation_logits: + openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"].tolist() return openai_response else: return sentences diff --git a/nemo/deploy/service/rest_model_api.py b/nemo/deploy/service/rest_model_api.py index fbc774883faa..64afea167295 100644 --- a/nemo/deploy/service/rest_model_api.py +++ b/nemo/deploy/service/rest_model_api.py @@ -8,8 +8,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import json import os from pathlib import Path import requests @@ -19,6 +17,7 @@ from pydantic_settings import BaseSettings from nemo.deploy.nlp import NemoQueryLLM +from nemo.utils import logging class TritonSettings(BaseSettings): @@ -29,14 +28,13 @@ class TritonSettings(BaseSettings): def __init__(self): super(TritonSettings, self).__init__() try: - with open(os.path.join(Path.cwd(), 'nemo/deploy/service/config.json')) as config: - config_json = json.load(config) - self._triton_service_port = config_json["triton_service_port"] - self._triton_service_ip = config_json["triton_service_ip"] - self._triton_request_timeout = config_json["triton_request_timeout"] - self._openai_format_response = config_json["openai_format_response"] + self._triton_service_port = int(os.environ.get('TRITON_PORT', 8080)) + self._triton_service_ip = os.environ.get('TRITON_HTTP_ADDRESS', '0.0.0.0') + self._triton_request_timeout = int(os.environ.get('TRITON_REQUEST_TIMEOUT', 60)) + self._openai_format_response = os.environ.get('OPENAI_FORMAT_RESPONSE', 'False').lower() == 'true' + self._output_generation_logits = os.environ.get('OUTPUT_GENERATION_LOGITS', 'False').lower() == 'true' except Exception as error: - print("An exception occurred:", error) + logging.error("An exception occurred trying to retrieve set args in TritonSettings class. Error:", error) return @property @@ -54,11 +52,17 @@ def triton_request_timeout(self): @property def openai_format_response(self): """ - Retuns the response from Triton server in OpenAI compatible formar if set to True, - default set in config.json is false. + Retuns the response from Triton server in OpenAI compatible format if set to True. """ return self._openai_format_response + @property + def output_generation_logits(self): + """ + Retuns the generation logits along with text in Triton server output if set to True. + """ + return self._output_generation_logits + app = FastAPI() triton_settings = TritonSettings() @@ -70,19 +74,27 @@ class CompletionRequest(BaseModel): max_tokens: int = 512 temperature: float = 1.0 top_p: float = 0.0 - n: int = 1 + top_k: int = 1 stream: bool = False stop: str | None = None frequency_penalty: float = 1.0 -@app.get("/triton_health") +@app.get("/v1/health") +def health_check(): + return {"status": "ok"} + + +@app.get("/v1/triton_health") async def check_triton_health(): """ This method exposes endpoint "/triton_health" which can be used to verify if Triton server is accessible while running the REST or FastAPI application. - Verify by running: curl http://service_http_address:service_port/triton_health and the returned status should inform if the server is accessible. + Verify by running: curl http://service_http_address:service_port/v1/triton_health and the returned status should inform if the server is accessible. """ - triton_url = f"triton_settings.triton_service_ip:str(triton_settings.triton_service_port)/v2/health/ready" + triton_url = ( + f"http://{triton_settings.triton_service_ip}:{str(triton_settings.triton_service_port)}/v2/health/ready" + ) + logging.info(f"Attempting to connect to Triton server at: {triton_url}") try: response = requests.get(triton_url, timeout=5) if response.status_code == 200: @@ -101,11 +113,13 @@ def completions_v1(request: CompletionRequest): output = nq.query_llm( prompts=[request.prompt], max_output_len=request.max_tokens, - top_k=request.n, + # when these below params are passed as None + top_k=request.top_k, top_p=request.top_p, temperature=request.temperature, init_timeout=triton_settings.triton_request_timeout, openai_format_response=triton_settings.openai_format_response, + output_generation_logits=triton_settings.output_generation_logits, ) if triton_settings.openai_format_response: return output @@ -114,5 +128,5 @@ def completions_v1(request: CompletionRequest): "output": output[0][0], } except Exception as error: - print("An exception occurred:", error) + logging.error("An exception occurred with the post request to /v1/completions/ endpoint:", error) return {"error": "An exception occurred"} diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 08b0b822cad4..a1e6cb0e03c4 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -180,6 +180,8 @@ def export( reduce_fusion: bool = True, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None, + gather_context_logits: Optional[bool] = False, + gather_generation_logits: Optional[bool] = False, ): """ Exports nemo checkpoints to TensorRT-LLM. @@ -218,6 +220,8 @@ def export( reduce_fusion (bool): enables fusing extra kernels after custom TRT-LLM allReduce fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type. fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type. + gather_context_logits (Optional[bool]): if True, enables gather_context_logits while building trtllm engine. Default: False + gather_generation_logits (Optional[bool]): if True, enables gather_generation_logits while building trtllm engine. Default: False """ if n_gpus is not None: warnings.warn( @@ -495,6 +499,8 @@ def get_transformer_config(nemo_model_config): multiple_profiles=multiple_profiles, gpt_attention_plugin=gpt_attention_plugin, gemm_plugin=gemm_plugin, + gather_context_logits=gather_context_logits, + gather_generation_logits=gather_generation_logits, ) tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model") @@ -688,6 +694,7 @@ def forward( prompt_embeddings_checkpoint_path: str = None, streaming: bool = False, output_log_probs: bool = False, + output_generation_logits: bool = False, **sampling_kwargs, ): """ @@ -706,6 +713,7 @@ def forward( task_ids (List(str)): list of the task ids for the prompt tables. prompt_embeddings_table (List(float)): prompt embeddings table. prompt_embeddings_checkpoint_path (str): path for the nemo checkpoint for the prompt embedding table. + output_generation_logits (bool): if True returns generation_logits in the outout of generate method. sampling_kwargs: Additional kwargs to set in the SamplingConfig. """ @@ -784,6 +792,7 @@ def forward( no_repeat_ngram_size=no_repeat_ngram_size, output_log_probs=output_log_probs, multiprocessed_env=multiprocessed_env, + output_generation_logits=output_generation_logits, **sampling_kwargs, ) else: @@ -862,16 +871,21 @@ def get_triton_input(self): Tensor(name="no_repeat_ngram_size", shape=(-1,), dtype=np.single, optional=True), Tensor(name="task_id", shape=(-1,), dtype=bytes, optional=True), Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True), + Tensor(name="output_generation_logits", shape=(-1,), dtype=np.bool_, optional=False), ) return inputs @property def get_triton_output(self): - outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) + outputs = ( + Tensor(name="outputs", shape=(-1,), dtype=bytes), + Tensor(name="generation_logits", shape=(-1,), dtype=np.single), + ) return outputs @batch def triton_infer_fn(self, **inputs: np.ndarray): + output_dict = {} try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} if "max_output_len" in inputs: @@ -898,14 +912,20 @@ def triton_infer_fn(self, **inputs: np.ndarray): if "lora_uids" in inputs: lora_uids = np.char.decode(inputs.pop("lora_uids").astype("bytes"), encoding="utf-8") infer_input["lora_uids"] = lora_uids[0].tolist() + if "output_generation_logits" in inputs: + infer_input["output_generation_logits"] = inputs.pop("output_generation_logits")[0][0] - output_texts = self.forward(**infer_input) - output = cast_output(output_texts, np.bytes_) + if infer_input["output_generation_logits"]: + output_texts, generation_logits = self.forward(**infer_input) + output_dict["generation_logits"] = np.array(generation_logits.cpu().numpy()) + else: + output_texts = self.forward(**infer_input) + output_dict["outputs"] = cast_output(output_texts, np.bytes_) except Exception as error: err_msg = "An error occurred: {0}".format(str(error)) - output = cast_output([err_msg], np.bytes_) + output_dict["outputs"] = cast_output([err_msg], np.bytes_) - return {"outputs": output} + return output_dict @batch def triton_infer_fn_streaming(self, **inputs: np.ndarray): diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index db1aec0f5a55..b0e134ab0c35 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -161,7 +161,7 @@ def convert_model_to_trt_llm_ckpt( or nemo_model_config.get("layernorm_zero_centered_gamma", False), "tp_size": training_tp_size, "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu", "openai-gelu"] and (decoder_type == "gptnext" or is_mcore), "num_attention_heads": num_attention_heads, "num_kv_heads": num_kv_heads, @@ -336,7 +336,7 @@ def dist_model_to_trt_llm_ckpt( "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", "tp_size": tp_size, "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"], + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu", "openai-gelu"], "num_attention_heads": nemo_model_config["num_attention_heads"], "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), "convert_on_device": True, diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 4be2d42ebe4d..b2b761483700 100755 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -54,6 +54,8 @@ def build_and_save_engine( gpt_attention_plugin: str = "auto", gemm_plugin: str = "auto", reduce_fusion: bool = False, + gather_context_logits: bool = False, + gather_generation_logits: bool = False, ): architecture = "LLaMAForCausalLM" if model_config.architecture == "LlamaForCausalLM" else model_config.architecture try: @@ -96,8 +98,8 @@ def build_and_save_engine( 'max_num_tokens': max_num_tokens, 'opt_num_tokens': opt_num_tokens, 'max_prompt_embedding_table_size': max_prompt_embedding_table_size, - 'gather_context_logits': False, - 'gather_generation_logits': False, + 'gather_context_logits': gather_context_logits, + 'gather_generation_logits': gather_generation_logits, 'strongly_typed': False, 'builder_opt': None, 'use_refit': use_refit, @@ -118,14 +120,6 @@ def build_and_save_engine( build_config.lora_config = lora_config model = model_cls.from_config(model_config) - if not model_config.bias and model_config.architecture == 'GPTForCausalLM': - # NOTE: GPT models in megatron-core that set bias=False sets the bias false globally - # whereas bias=False in TRTLLM GPT models sets it false everywhere except - # LayerNorm. This change makes TRTLLM's implementation match megatron-core. - for name, module in model.named_modules(): - if isinstance(module, tensorrt_llm.layers.normalization.LayerNorm): - module.bias = None - module.register_parameter('bias', None) model = optimize_model( model, use_parallel_embedding=model_config.use_parallel_embedding, diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index bd7b8abd5f9e..ef67c918290f 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -647,6 +647,7 @@ def generate( streaming: bool = False, output_log_probs=False, multiprocessed_env=False, + output_generation_logits=False, **sampling_kwargs, ) -> Optional[List[List[str]]]: """Generate the output sequence from the input sequence. @@ -692,6 +693,7 @@ def generate( multiprocessed_env=multiprocessed_env, **sampling_kwargs, ) + assert outputs is not None if tensorrt_llm.mpi_rank() != 0: return None @@ -705,8 +707,8 @@ def generate( for b in range(output_ids.shape[0]) ] - if output_log_probs: - return output_lines_list, log_probs + if output_generation_logits: + return output_lines_list, outputs['generation_logits'] return output_lines_list diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 91d3b3f936d0..e01a2d5e5765 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -14,8 +14,8 @@ from typing import Union -from lightning_fabric.plugins.environments import slurm -from pytorch_lightning import plugins as _pl_plugins +from lightning.fabric.plugins.environments import slurm +from lightning.pytorch import plugins as _pl_plugins # This is here to import it once, which improves the speed of launch when in debug-mode from nemo.utils.import_utils import safe_import diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index 1bee71e26e17..ea6fa4c4d226 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: - from lightning_fabric.utilities.types import Optimizable + from lightning.fabric.utilities.types import Optimizable from megatron.core.model_parallel_config import ModelParallelConfig diff --git a/nemo/lightning/base.py b/nemo/lightning/base.py index b6ba14726818..3b0b1c0c7234 100644 --- a/nemo/lightning/base.py +++ b/nemo/lightning/base.py @@ -19,7 +19,7 @@ import torch import torch.distributed -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from torch import nn diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index 9cf686464417..9cb685a096fa 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -19,7 +19,7 @@ from typing import List, Literal, Optional import torch -from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper from torch.utils.data import DataLoader, Dataset diff --git a/nemo/lightning/fabric/conversion.py b/nemo/lightning/fabric/conversion.py index 9ad713ec5261..d1c7affe3f40 100644 --- a/nemo/lightning/fabric/conversion.py +++ b/nemo/lightning/fabric/conversion.py @@ -15,10 +15,10 @@ from functools import singledispatch from typing import Any, TypeVar -from lightning_fabric import plugins as fl_plugins -from lightning_fabric import strategies as fl_strategies -from pytorch_lightning import plugins as pl_plugins -from pytorch_lightning import strategies as pl_strategies +from lightning.fabric import plugins as fl_plugins +from lightning.fabric import strategies as fl_strategies +from lightning.pytorch import plugins as pl_plugins +from lightning.pytorch import strategies as pl_strategies T = TypeVar('T') FabricT = TypeVar('FabricT') @@ -39,8 +39,8 @@ def to_fabric(obj: Any) -> Any: NotImplementedError: If no converter is registered for the object's type. Example: - >>> from pytorch_lightning.strategies import Strategy as PLStrategy - >>> from lightning_fabric.strategies import Strategy as FabricStrategy + >>> from lightning.pytorch.strategies import Strategy as PLStrategy + >>> from lightning.fabric.strategies import Strategy as FabricStrategy >>> from nemo.lightning.fabric.conversion import to_fabric >>> >>> # Define a custom PyTorch Lightning strategy @@ -70,7 +70,7 @@ def to_fabric(obj: Any) -> Any: f"No Fabric converter registered for {type(obj).__name__}. " f"To register a new conversion, use the @to_fabric.register decorator:\n\n" f"from nemo.lightning.fabric.conversion import to_fabric\n" - f"from lightning_fabric import strategies as fl_strategies\n\n" + f"from lightning.fabric import strategies as fl_strategies\n\n" f"@to_fabric.register({type(obj).__name__})\n" f"def _{type(obj).__name__.lower()}_converter(obj: {type(obj).__name__}) -> fl_strategies.Strategy:\n" f" return fl_strategies.SomeStrategy(\n" diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py index 60eb518a1e42..7d604de749d6 100644 --- a/nemo/lightning/fabric/fabric.py +++ b/nemo/lightning/fabric/fabric.py @@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Optional, Protocol, Sequence, Type, TypeVar, Union, runtime_checkable import fiddle as fdl -import lightning_fabric as lb -import pytorch_lightning as pl +import lightning.fabric as lb +import lightning.pytorch as pl from torch import nn from typing_extensions import Self, override diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index 723b48b6b357..58bf5f5ca9f9 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Generator, Literal, TypeVar import torch -from lightning_fabric.plugins.precision import MixedPrecision +from lightning.fabric.plugins.precision import MixedPrecision from torch import nn from torch.optim import Optimizer diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py index 575f69a58caf..30a03504060f 100644 --- a/nemo/lightning/fabric/strategies.py +++ b/nemo/lightning/fabric/strategies.py @@ -29,21 +29,21 @@ ) import torch -from lightning_fabric.accelerators import CPUAccelerator -from lightning_fabric.accelerators.accelerator import Accelerator -from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout -from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO -from lightning_fabric.plugins.precision import Precision -from lightning_fabric.strategies import DDPStrategy -from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading -from lightning_fabric.utilities.types import _PATH, _Stateful +from lightning.fabric.accelerators import CPUAccelerator +from lightning.fabric.accelerators.accelerator import Accelerator +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout +from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO +from lightning.fabric.plugins.precision import Precision +from lightning.fabric.strategies import DDPStrategy +from lightning.fabric.strategies.strategy import _validate_keys_for_strict_loading +from lightning.fabric.utilities.types import _PATH, _Stateful +from lightning.pytorch import LightningDataModule +from lightning.pytorch.loops.fetchers import _DataFetcher +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from lightning.pytorch.utilities.combined_loader import CombinedLoader from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning import LightningDataModule -from pytorch_lightning.loops.fetchers import _DataFetcher -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO -from pytorch_lightning.utilities.combined_loader import CombinedLoader from torch import Tensor, nn from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.nn import Module diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py index be9372f2e79b..869ec6e613cb 100644 --- a/nemo/lightning/io/api.py +++ b/nemo/lightning/io/api.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Type, overload import fiddle as fdl -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, load from nemo.lightning.io.pl import TrainerContext diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index e699f15565bd..a38be6ee8f0a 100644 --- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -18,9 +18,9 @@ from pathlib import Path, PosixPath, PurePath, WindowsPath from typing import Generic, Optional, Tuple, TypeVar -import pytorch_lightning as pl +import lightning.pytorch as pl from filelock import FileLock, Timeout -from pytorch_lightning.trainer.states import TrainerFn +from lightning.pytorch.trainer.states import TrainerFn from nemo.lightning.ckpt_utils import ckpt_to_context_subdir diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index 10ed52b136c2..788697887e39 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -17,12 +17,12 @@ from pathlib import Path from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import CheckpointIO -from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO -from lightning_fabric.utilities.cloud_io import get_filesystem -from lightning_fabric.utilities.types import _PATH +from lightning.fabric.plugins import CheckpointIO +from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.types import _PATH from megatron.core.dist_checkpointing.serialization import ( get_default_load_sharded_strategy, get_default_save_sharded_strategy, diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 6a3138b1da29..0f84f3be0a23 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -42,13 +42,12 @@ import torch import torch.distributed +from lightning.pytorch.utilities import move_data_to_device from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel as McoreDDP from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.transformer_config import TransformerConfig -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import move_data_to_device from torch import Tensor, nn from typing_extensions import override @@ -58,7 +57,7 @@ STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] if TYPE_CHECKING: - import pytorch_lightning as pl + import lightning.pytorch as pl @runtime_checkable @@ -836,7 +835,7 @@ def add(self, *callbacks) -> "CallbackConnector": """ _pl_callback = None try: - import pytorch_lightning as pl + import lightning.pytorch as pl _pl_callback = pl.Callback except ImportError: diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index a901a3a8842a..79f622ebc6a8 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -18,10 +18,10 @@ from pathlib import Path from typing import List, Optional, Union -import lightning_fabric as fl -import pytorch_lightning as pl -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint -from pytorch_lightning.loggers import Logger, TensorBoardLogger, WandbLogger +import lightning.fabric as fl +import lightning.pytorch as pl +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint +from lightning.pytorch.loggers import Logger, TensorBoardLogger, WandbLogger from nemo.lightning.io.mixin import IOMixin from nemo.lightning.pytorch.callbacks import ModelCheckpoint diff --git a/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py b/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py index 391666fb8f32..320140d76f3a 100644 --- a/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py +++ b/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py @@ -15,8 +15,8 @@ from functools import cache import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.utils import check_param_hashes_across_dp_replicas -from pytorch_lightning.callbacks.callback import Callback from nemo.lightning import io from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/debugging.py b/nemo/lightning/pytorch/callbacks/debugging.py index 5f6e722ef89b..135e8e486837 100644 --- a/nemo/lightning/pytorch/callbacks/debugging.py +++ b/nemo/lightning/pytorch/callbacks/debugging.py @@ -14,9 +14,9 @@ from typing import Callable, Dict, List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/garbage_collection.py b/nemo/lightning/pytorch/callbacks/garbage_collection.py index ba4d378ee893..90e122f6d3e4 100644 --- a/nemo/lightning/pytorch/callbacks/garbage_collection.py +++ b/nemo/lightning/pytorch/callbacks/garbage_collection.py @@ -15,7 +15,7 @@ import gc from typing import Any -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py b/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py index fc4312e2ff84..58173642527a 100644 --- a/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py +++ b/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py @@ -13,12 +13,12 @@ # limitations under the License. from dataclasses import asdict, dataclass, fields -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.callbacks.callback import Callback from megatron.core import ModelParallelConfig from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import TransformerLayerTPOverlapCfg from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy, ParallelismConfig diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py index 5b2ee1d46e11..2813bd141a7a 100644 --- a/nemo/lightning/pytorch/callbacks/memory_profiler.py +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -15,7 +15,7 @@ import os import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from torch.utils.viz._cycles import warn_tensor_cycles from nemo.lightning import io diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index b384976d82bd..ca4ca08cab08 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -19,12 +19,12 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union -import pytorch_lightning +import lightning.pytorch import torch from _weakref import proxy -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint -from pytorch_lightning.callbacks.model_checkpoint import _is_local_file_protocol -from pytorch_lightning.utilities import rank_zero_info +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint +from lightning.pytorch.callbacks.model_checkpoint import _is_local_file_protocol +from lightning.pytorch.utilities import rank_zero_info from nemo.lightning.ckpt_utils import ckpt_to_dir from nemo.lightning.io.pl import TrainerContext @@ -312,7 +312,7 @@ def _del_model_without_trainer(self, filepath: str) -> None: if torch.distributed.is_initialized(): torch.distributed.barrier() - def _ema_callback(self, trainer: 'pytorch_lightning.Trainer'): + def _ema_callback(self, trainer: 'lightning.pytorch.Trainer'): from nemo.collections.common.callbacks import EMA ema_callback = None @@ -393,7 +393,7 @@ def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barri except: return - def file_exists(self, filepath: str, trainer: "pytorch_lightning.Trainer", check_dist_ckpt: bool = True) -> bool: + def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool: """Checks if a file or a file without a suffix (distributed checkpoint) exists.""" exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath))) return trainer.strategy.broadcast(exists) @@ -432,7 +432,7 @@ def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str, linkpath = ckpt_to_dir(linkpath) super()._link_checkpoint(trainer, filepath, linkpath) - def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None: from nemo.utils.get_rank import is_global_rank_zero # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. @@ -499,7 +499,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) finalize_fn() def _get_finalize_save_checkpoint_callback( - self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int + self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int ): """Creates a callback that can be used to finalize async (and sync) ckpt saves.""" @@ -534,7 +534,7 @@ def _cb(): return _cb - def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: + def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None: """Performs checkpoint removal. With async save, `self._remove_checkpoint` is called before the checkpoint diff --git a/nemo/lightning/pytorch/callbacks/model_transform.py b/nemo/lightning/pytorch/callbacks/model_transform.py index 64602b501ac3..b3c3310aa30f 100644 --- a/nemo/lightning/pytorch/callbacks/model_transform.py +++ b/nemo/lightning/pytorch/callbacks/model_transform.py @@ -15,7 +15,7 @@ from functools import wraps from typing import Any, Callable, Optional, TypeVar -import pytorch_lightning as pl +import lightning.pytorch as pl from torch import nn from nemo.utils import logging @@ -85,7 +85,7 @@ def _maybe_apply_transform(self, trainer): def apply_transform(self, trainer): self.model_transform(trainer.model) - from pytorch_lightning.utilities import model_summary + from lightning.pytorch.utilities import model_summary logging.info( f"After applying model_transform:\n" f"{model_summary.summarize(trainer.lightning_module, max_depth=1)}" diff --git a/nemo/lightning/pytorch/callbacks/moe_token_drop.py b/nemo/lightning/pytorch/callbacks/moe_token_drop.py index 10483dca5096..b0c7ff7999eb 100644 --- a/nemo/lightning/pytorch/callbacks/moe_token_drop.py +++ b/nemo/lightning/pytorch/callbacks/moe_token_drop.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.callbacks.callback import Callback from megatron.core import ModelParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py index 2a5707d3166c..f350eae40730 100644 --- a/nemo/lightning/pytorch/callbacks/nsys.py +++ b/nemo/lightning/pytorch/callbacks/nsys.py @@ -15,7 +15,7 @@ from typing import List, Optional import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo.utils import logging from nemo.utils.get_rank import get_rank diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py index 5336615a4a38..3790b7788419 100644 --- a/nemo/lightning/pytorch/callbacks/peft.py +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -18,12 +18,12 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn -from lightning_fabric.utilities.types import _PATH -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO -from pytorch_lightning.trainer.states import TrainerFn +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from lightning.pytorch.trainer.states import TrainerFn from typing_extensions import override from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME diff --git a/nemo/lightning/pytorch/callbacks/preemption.py b/nemo/lightning/pytorch/callbacks/preemption.py index 69ac378ed698..98b59a9da0d0 100644 --- a/nemo/lightning/pytorch/callbacks/preemption.py +++ b/nemo/lightning/pytorch/callbacks/preemption.py @@ -18,8 +18,8 @@ from typing import Optional import torch -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.trainer.trainer import Trainer from nemo.lightning.io.mixin import IOMixin from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/progress_bar.py b/nemo/lightning/pytorch/callbacks/progress_bar.py index 6912c3fc57d4..f3c3c4555bac 100644 --- a/nemo/lightning/pytorch/callbacks/progress_bar.py +++ b/nemo/lightning/pytorch/callbacks/progress_bar.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.callbacks.progress import TQDMProgressBar -from pytorch_lightning.callbacks.progress.tqdm_progress import _update_n +from lightning.pytorch.callbacks.progress import TQDMProgressBar +from lightning.pytorch.callbacks.progress.tqdm_progress import _update_n class MegatronProgressBar(TQDMProgressBar): diff --git a/nemo/lightning/pytorch/callbacks/progress_printer.py b/nemo/lightning/pytorch/callbacks/progress_printer.py index d32f7d70cbdd..12d05ed2950c 100644 --- a/nemo/lightning/pytorch/callbacks/progress_printer.py +++ b/nemo/lightning/pytorch/callbacks/progress_printer.py @@ -15,9 +15,9 @@ from collections import defaultdict from typing import Any +from lightning.pytorch.callbacks.progress import ProgressBar +from lightning.pytorch.utilities.types import STEP_OUTPUT from megatron.core.num_microbatches_calculator import get_num_microbatches -from pytorch_lightning.callbacks.progress import ProgressBar -from pytorch_lightning.utilities.types import STEP_OUTPUT from typing_extensions import override diff --git a/nemo/lightning/pytorch/optim/base.py b/nemo/lightning/pytorch/optim/base.py index 1d476142941a..fec3b7c118a4 100644 --- a/nemo/lightning/pytorch/optim/base.py +++ b/nemo/lightning/pytorch/optim/base.py @@ -17,8 +17,8 @@ from copy import deepcopy from typing import List, Optional -import pytorch_lightning as L -from pytorch_lightning.utilities.types import OptimizerLRScheduler +import lightning.pytorch as L +from lightning.pytorch.utilities.types import OptimizerLRScheduler from torch.optim import Optimizer from nemo.lightning.io.mixin import IOMixin diff --git a/nemo/lightning/pytorch/optim/megatron.py b/nemo/lightning/pytorch/optim/megatron.py index 7ac413d4544f..9f9d2029be9e 100644 --- a/nemo/lightning/pytorch/optim/megatron.py +++ b/nemo/lightning/pytorch/optim/megatron.py @@ -15,7 +15,7 @@ import inspect from typing import Callable, List, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl from megatron.core.distributed import finalize_model_grads from megatron.core.optimizer import OptimizerConfig from megatron.core.utils import get_model_config diff --git a/nemo/lightning/pytorch/optim/pytorch.py b/nemo/lightning/pytorch/optim/pytorch.py index 9d773917e4f4..ccd03f563ef8 100644 --- a/nemo/lightning/pytorch/optim/pytorch.py +++ b/nemo/lightning/pytorch/optim/pytorch.py @@ -14,8 +14,8 @@ from typing import Callable, List, Optional -import pytorch_lightning as pl -import pytorch_lightning as L +import lightning.pytorch as pl +import lightning.pytorch as L from torch.optim import Optimizer from torch.optim.optimizer import ParamsT diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 024e2577c868..479e442d5ccb 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -16,7 +16,7 @@ import logging from typing import List, Literal, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl from torch.utils.data import DataLoader from nemo.lightning.megatron_parallel import MegatronStep diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 5c318b59e54a..830978ba11e7 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -16,9 +16,8 @@ from dataclasses import dataclass, fields from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union -import pytorch_lightning as pl import torch -from pytorch_lightning.plugins.precision import Precision +from lightning.pytorch.plugins.precision import Precision from torch.nn import Module from torch.optim import Optimizer diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 83d5781c0dde..4c5a165c2d8d 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -17,14 +17,14 @@ from pathlib import Path from typing import Any, Dict, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import CheckpointIO -from lightning_fabric.strategies.fsdp import _get_sharded_state_dict_context +from lightning.fabric.plugins import CheckpointIO +from lightning.fabric.strategies.fsdp import _get_sharded_state_dict_context +from lightning.pytorch.strategies.fsdp import FSDPStrategy as PLFSDPStrategy +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.types import STEP_OUTPUT from megatron.core.transformer.transformer_layer import TransformerLayer -from pytorch_lightning.strategies.fsdp import FSDPStrategy as PLFSDPStrategy -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.distributed.checkpoint.state_dict import ( # get_state_dict, StateDictOptions, get_optimizer_state_dict, diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index c62a90313b45..870cc0aaaddd 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -35,20 +35,20 @@ cast, ) -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.distributed -from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment -from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device +from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment +from lightning.fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop +from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher +from lightning.pytorch.overrides.distributed import _sync_module_states +from lightning.pytorch.strategies.ddp import DDPStrategy +from lightning.pytorch.trainer.states import RunningStage, TrainerFn +from lightning.pytorch.utilities.types import STEP_OUTPUT from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop -from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher -from pytorch_lightning.overrides.distributed import _sync_module_states -from pytorch_lightning.strategies.ddp import DDPStrategy -from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities.types import STEP_OUTPUT from torch import nn from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.nn.parallel import DistributedDataParallel diff --git a/nemo/lightning/pytorch/strategies/utils.py b/nemo/lightning/pytorch/strategies/utils.py index 43a5a9243aa5..4f5a78419d6d 100644 --- a/nemo/lightning/pytorch/strategies/utils.py +++ b/nemo/lightning/pytorch/strategies/utils.py @@ -17,15 +17,14 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import ClusterEnvironment +from lightning.fabric.plugins import ClusterEnvironment +from lightning.pytorch.callbacks import TQDMProgressBar from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedBase, ShardedObject, ShardedTensor from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor from megatron.core.transformer.utils import _get_extra_state_offsets -from pytorch_lightning.callbacks import TQDMProgressBar -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index c97c59ef524d..701c1cde4eaf 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -16,9 +16,9 @@ from copy import deepcopy import fiddle as fdl -import pytorch_lightning as pl -from pytorch_lightning.loops import _TrainingEpochLoop -from pytorch_lightning.loops.fetchers import _DataFetcher +import lightning.pytorch as pl +from lightning.pytorch.loops import _TrainingEpochLoop +from lightning.pytorch.loops.fetchers import _DataFetcher from typing_extensions import Self from nemo.lightning.fabric.conversion import to_fabric diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index 412ca8665b84..7b534646731c 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -18,8 +18,8 @@ from pathlib import Path, PosixPath, WindowsPath from typing import Optional, Union -import lightning_fabric as fl -import pytorch_lightning as pl +import lightning.fabric as fl +import lightning.pytorch as pl from nemo.lightning import io from nemo.lightning.base import NEMO_MODELS_CACHE diff --git a/nemo/lightning/run/plugins.py b/nemo/lightning/run/plugins.py index 9d2936e567ec..4130272bfc90 100644 --- a/nemo/lightning/run/plugins.py +++ b/nemo/lightning/run/plugins.py @@ -20,9 +20,9 @@ import nemo_run as run import yaml +from lightning.pytorch import Callback +from lightning.pytorch.loggers import WandbLogger from nemo_run.core.serialization.yaml import YamlSerializer -from pytorch_lightning import Callback -from pytorch_lightning.loggers import WandbLogger from nemo.lightning.pytorch.callbacks import NsysCallback, PreemptionCallback from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy @@ -262,6 +262,8 @@ class PerfEnvPlugin(run.Plugin): enable_vboost: bool = False def get_vboost_srun_cmd(self, nodes, job_dir): + "Create the vboost `sudo nvidia-smi boost-slider --vboost 1` command" + import shlex vboost_cmd = " ".join( @@ -281,12 +283,13 @@ def get_vboost_srun_cmd(self, nodes, job_dir): return vboost_cmd def setup(self, task: run.Partial | run.Script, executor: run.Executor): + """Enable the performance environment settings""" if task.trainer.strategy.__fn_or_cls__ == MegatronStrategy: # Force program order kernel launch for TP, CP overlap tp_size = task.trainer.strategy.tensor_model_parallel_size cp_size = task.trainer.strategy.context_parallel_size - if tp_size > 1 and cp_size > 1: + if tp_size > 1 or cp_size > 1: executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" # Set LayerNorm SM margin to support the overlap with LayerNorm kernel diff --git a/nemo/utils/callbacks/cuda_graph.py b/nemo/utils/callbacks/cuda_graph.py index c78196934108..b44006828963 100644 --- a/nemo/utils/callbacks/cuda_graph.py +++ b/nemo/utils/callbacks/cuda_graph.py @@ -37,15 +37,15 @@ from types import MethodType from typing import Any, Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.loops.optimization.automatic import ClosureResult -from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection, _ResultMetric -from pytorch_lightning.utilities import CombinedLoader, rank_zero_info -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import STEP_OUTPUT +from lightning.pytorch import LightningModule +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loops.optimization.automatic import ClosureResult +from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection, _ResultMetric +from lightning.pytorch.utilities import CombinedLoader, rank_zero_info +from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch.nn.parallel import DistributedDataParallel __all__ = ["CUDAGraphCallback"] @@ -431,8 +431,8 @@ def on_save_checkpoint( Called when saving a checkpoint to give you a chance to store anything else you might want to save. Args: - trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. - pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance. + trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance. + pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance. checkpoint: the checkpoint dictionary that will be saved. """ # Since we've add bound method to optimizer and lr_scheduler, it can lead to more diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 091075488878..b78ec9b4ac0f 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -19,12 +19,12 @@ from time import time from typing import Any, Dict, Optional, Union -import pytorch_lightning as pl -from lightning_fabric.plugins import CheckpointIO -from lightning_fabric.utilities.cloud_io import get_filesystem -from lightning_fabric.utilities.types import _PATH -from pytorch_lightning import Callback -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO +import lightning.pytorch as pl +from lightning.fabric.plugins import CheckpointIO +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch import Callback +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO from nemo.utils import logging diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index dc1da9ce1875..8fe3beaaa985 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -19,14 +19,13 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -import pytorch_lightning +import lightning.pytorch import torch from _weakref import proxy - -from lightning_fabric.utilities.cloud_io import get_filesystem -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol -from pytorch_lightning.trainer import call -from pytorch_lightning.utilities import rank_zero_info +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol +from lightning.pytorch.trainer import call +from lightning.pytorch.utilities import rank_zero_info from nemo.collections.common.callbacks import EMA from nemo.utils import logging @@ -357,7 +356,7 @@ def _del_model_without_trainer(self, filepath: str) -> None: except: logging.info(f"Tried to remove checkpoint: {filepath} but failed.") - def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]: + def _ema_callback(self, trainer: 'lightning.pytorch.Trainer') -> Optional[EMA]: ema_callback = None for callback in trainer.callbacks: if isinstance(callback, EMA): @@ -506,12 +505,12 @@ def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barri except: return - def file_exists(self, filepath: str, trainer: "pytorch_lightning.Trainer", check_dist_ckpt: bool = True) -> bool: + def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool: """Checks if a file or a file without a suffix (distributed checkpoint) exists.""" exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath))) return trainer.strategy.broadcast(exists) - def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None: # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) @@ -552,7 +551,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self._drop_optimizer_states(trainer, filepath, storage_options) def _get_finalize_save_checkpoint_callback( - self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int + self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int ): """Creates a callback that can be used to finalize async (and sync) ckpt saves.""" @@ -585,7 +584,7 @@ def _cb(): return _cb - def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: + def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None: """Performs checkpoint removal or deferred removal. With async save, `self._remove_checkpoint` is called before the checkpoint diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index e9b5f95022f3..178fe94cee7c 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -16,7 +16,7 @@ import sys import torch -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from nemo.utils import logging @@ -24,7 +24,7 @@ class PreemptionCallback(Callback): """ PreemptionCallback class creates a callback that checks for preemption during training at the end of every step. - Upon preemption the callback provides a function to gracefully exit the training immediately and also saves the current state in a checkpoint as *last.ckpt. + Upon preemption the callback provides a function to gracefully exit the training immediately and also saves the current state in a checkpoint as *last.ckpt. (to be able to start from the same step without wasting any compute while resuming the next time). PreemptionCallback is always enabled by default via the arg create_preemption_callback under ExpManagerConfig. To disable please pass @@ -47,7 +47,7 @@ def interrupted(self): def on_train_start(self, trainer, pl_module): """ - Defines custom handlers at the beginning of training to be executed when the + Defines custom handlers at the beginning of training to be executed when the preemption signal is received. """ diff --git a/nemo/utils/callbacks/s3_checkpoint_io.py b/nemo/utils/callbacks/s3_checkpoint_io.py index 7a9f984fee1b..4a48198311a2 100644 --- a/nemo/utils/callbacks/s3_checkpoint_io.py +++ b/nemo/utils/callbacks/s3_checkpoint_io.py @@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, Optional, Union import torch -from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO +from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO from nemo.utils import logging from nemo.utils.s3_utils import ( diff --git a/nemo/utils/cloud.py b/nemo/utils/cloud.py index 7245567d636c..d565028bdf8c 100644 --- a/nemo/utils/cloud.py +++ b/nemo/utils/cloud.py @@ -17,8 +17,8 @@ from time import sleep import wget -from pytorch_lightning.plugins.environments import LightningEnvironment -from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry +from lightning.pytorch.plugins.environments import LightningEnvironment +from lightning.pytorch.strategies import DDPStrategy, StrategyRegistry from nemo.utils import logging @@ -105,7 +105,10 @@ def initialize_sagemaker() -> None: """ StrategyRegistry.register( - name='smddp', strategy=SageMakerDDPStrategy, process_group_backend="smddp", find_unused_parameters=False, + name='smddp', + strategy=SageMakerDDPStrategy, + process_group_backend="smddp", + find_unused_parameters=False, ) def _install_system_libraries() -> None: diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index b512bc57cbab..04c43c46d247 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -26,18 +26,18 @@ from shutil import copy, move from typing import Any, Collection, Dict, List, Optional, Tuple, Union -import pytorch_lightning +import lightning.pytorch import torch from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd +from lightning.pytorch.callbacks import Callback, ModelCheckpoint +from lightning.pytorch.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.callbacks.timer import Interval, Timer +from lightning.pytorch.loggers import MLFlowLogger, NeptuneLogger, TensorBoardLogger, WandbLogger +from lightning.pytorch.loops import _TrainingEpochLoop +from lightning.pytorch.strategies.ddp import DDPStrategy +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks.timer import Interval, Timer -from pytorch_lightning.loggers import MLFlowLogger, NeptuneLogger, TensorBoardLogger, WandbLogger -from pytorch_lightning.loops import _TrainingEpochLoop -from pytorch_lightning.strategies.ddp import DDPStrategy -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.common.callbacks import EMA from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION @@ -343,7 +343,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) self._on_batch_end("validation_step_timing in s", trainer, pl_module) -def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]: +def exp_manager(trainer: 'lightning.pytorch.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]: """ exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir, @@ -362,7 +362,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo resume_if_exists is set to True, creating the version folders is ignored. Args: - trainer (pytorch_lightning.Trainer): The lightning trainer. + trainer (lightning.pytorch.Trainer): The lightning trainer. cfg (DictConfig, dict): Can have the following keys: - explicit_log_dir (str, Path): Can be used to override exp_dir/name/version folder creation. Defaults to @@ -680,7 +680,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo return log_dir -def error_checks(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None): +def error_checks(trainer: 'lightning.pytorch.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None): """ Checks that the passed trainer is compliant with NeMo and exp_manager's passed configuration. Checks that: - Throws error when hydra has changed the working directory. This causes issues with lightning's DDP @@ -728,7 +728,7 @@ def _filter_out_unfinished_checkpoints(checkpoint_paths: Collection[Union[Path, def check_resume( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', log_dir: str, resume_if_exists: bool = False, resume_past_end: bool = False, @@ -886,7 +886,7 @@ def check_resume( def check_explicit_log_dir( - trainer: 'pytorch_lightning.Trainer', explicit_log_dir: Union[Path, str], exp_dir: str, name: str, version: str + trainer: 'lightning.pytorch.Trainer', explicit_log_dir: Union[Path, str], exp_dir: str, name: str, version: str ) -> Tuple[Path, str, str, str]: """Checks that the passed arguments are compatible with explicit_log_dir. @@ -917,7 +917,7 @@ def check_explicit_log_dir( def get_log_dir( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', exp_dir: str = None, name: str = None, version: str = None, @@ -1025,7 +1025,7 @@ def get_git_diff(): def configure_loggers( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', exp_dir: [Path, str], log_dir: [Path, str], name: str, @@ -1136,7 +1136,7 @@ def resume_start(self, checkpoint_path=None) -> None: def configure_checkpointing( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', log_dir: Path, name: str, resume: bool, @@ -1257,12 +1257,12 @@ def _check_time_remaining(self, trainer: "pl.Trainer") -> None: monitor_candidates = checkpoint_callback._monitor_candidates(trainer) checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) # Throw this exception to signal to Lightning to terminate gracefully. - from pytorch_lightning.utilities.exceptions import _TunerExitException + from lightning.pytorch.utilities.exceptions import _TunerExitException raise _TunerExitException() -def configure_no_restart_validation_training_loop(trainer: pytorch_lightning.Trainer) -> None: +def configure_no_restart_validation_training_loop(trainer: lightning.pytorch.Trainer) -> None: if type(trainer.fit_loop.epoch_loop) != _TrainingEpochLoop: warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning) return diff --git a/nemo/utils/lightning_logger_patch.py b/nemo/utils/lightning_logger_patch.py index 1b21ce3b1ae5..1528146c64b5 100644 --- a/nemo/utils/lightning_logger_patch.py +++ b/nemo/utils/lightning_logger_patch.py @@ -15,7 +15,7 @@ import logging as _logging from logging.handlers import MemoryHandler -import pytorch_lightning as pl +import lightning.pytorch as pl HANDLERS = {} PATCHED = False diff --git a/nemo/utils/loggers/clearml_logger.py b/nemo/utils/loggers/clearml_logger.py index 4e2063705b4f..c7c3945ad853 100644 --- a/nemo/utils/loggers/clearml_logger.py +++ b/nemo/utils/loggers/clearml_logger.py @@ -19,11 +19,11 @@ from typing import Any, List, Literal, Mapping, Optional, Union import pandas as pd +from lightning.pytorch.callbacks import Checkpoint +from lightning.pytorch.loggers import Logger +from lightning.pytorch.utilities.parsing import AttributeDict from lightning_utilities.core.apply_func import apply_to_collection from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning.callbacks import Checkpoint -from pytorch_lightning.loggers import Logger -from pytorch_lightning.utilities.parsing import AttributeDict from torch import Tensor from nemo.utils import logging diff --git a/nemo/utils/loggers/dllogger.py b/nemo/utils/loggers/dllogger.py index cdeef63b75f7..871d7ee3f7a2 100644 --- a/nemo/utils/loggers/dllogger.py +++ b/nemo/utils/loggers/dllogger.py @@ -17,11 +17,11 @@ from pathlib import Path from typing import Optional +from lightning.pytorch.loggers import Logger +from lightning.pytorch.utilities import rank_zero_only +from lightning.pytorch.utilities.parsing import AttributeDict from lightning_utilities.core.apply_func import apply_to_collection from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning.loggers import Logger -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.parsing import AttributeDict from nemo.utils import logging @@ -34,7 +34,7 @@ HAVE_DLLOGGER = False try: - from lightning_fabric.utilities.logger import _convert_params, _flatten_dict, _sanitize_callable_params + from lightning.fabric.utilities.logger import _convert_params, _flatten_dict, _sanitize_callable_params PL_LOGGER_UTILITIES = True except (ImportError, ModuleNotFoundError): diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index e8020f244821..adca2283f577 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -1,8 +1,8 @@ cloudpickle fiddle hydra-core>1.3,<=1.3.2 +lightning>2.2.1 omegaconf<=2.3 -pytorch-lightning>2.2.1 torchmetrics>=0.11.0 transformers>=4.45.0 wandb diff --git a/requirements/requirements_tts.txt b/requirements/requirements_tts.txt index 0d499feb3b1f..6d20e0f2250f 100644 --- a/requirements/requirements_tts.txt +++ b/requirements/requirements_tts.txt @@ -11,3 +11,5 @@ nltk pandas pypinyin pypinyin-dict +seaborn + diff --git a/scripts/checkpoint_averaging/average_model_checkpoints.py b/scripts/checkpoint_averaging/average_model_checkpoints.py index 06c522f1e192..ce88bba9716b 100644 --- a/scripts/checkpoint_averaging/average_model_checkpoints.py +++ b/scripts/checkpoint_averaging/average_model_checkpoints.py @@ -60,7 +60,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf, open_dict diff --git a/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py b/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py index 59f02a117da4..7b964fd7bade 100755 --- a/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py +++ b/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py @@ -35,8 +35,8 @@ import sys import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector from nemo.core import ModelPT @@ -60,7 +60,10 @@ def main(): help='A list of Python file names to "from FILE import *" (Needed when some classes were defined in __main__ of a script)', ) parser.add_argument( - '--class_path', type=str, default='', help='A path to class "module.submodule.class" (if given)', + '--class_path', + type=str, + default='', + help='A path to class "module.submodule.class" (if given)', ) args = parser.parse_args() diff --git a/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py b/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py index b87f7e028cdb..b35fb201865e 100644 --- a/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py @@ -25,9 +25,9 @@ from collections import OrderedDict import torch +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -158,7 +158,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) diff --git a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py index ec048e4b6f19..335989309791 100644 --- a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoModelForCausalLM from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -128,7 +128,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py b/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py index e970ea29fca2..0ec5cc1e474b 100644 --- a/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py @@ -26,7 +26,7 @@ import torch import torch.nn.functional as F -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoTokenizer, BertConfig, BertModel from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel @@ -207,10 +207,16 @@ def convert_config(ref_config, hf_state_dict): def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, required=True, help="Path to .nemo file", + "--input_name_or_path", + type=str, + required=True, + help="Path to .nemo file", ) parser.add_argument( - "--output_path", type=str, required=True, help="Output HF model path", + "--output_path", + type=str, + required=True, + help="Output HF model path", ) args = parser.parse_args() diff --git a/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py b/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py index 363e4de09ef7..2545181ce968 100644 --- a/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py @@ -25,8 +25,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModel, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -126,7 +126,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) @@ -211,7 +211,11 @@ def convert(args): qkv_bias = torch.cat((qkv_bias, q[i * heads_per_group : (i + 1) * heads_per_group, :])) qkv_bias = torch.cat((qkv_bias, k[i : i + 1, :])) qkv_bias = torch.cat((qkv_bias, v[i : i + 1, :])) - qkv_bias = qkv_bias.reshape([head_size * (head_num + 2 * num_query_groups),]) + qkv_bias = qkv_bias.reshape( + [ + head_size * (head_num + 2 * num_query_groups), + ] + ) if mcore_gpt: qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight' diff --git a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py index 5a8e52ee8be5..241e4254a9be 100644 --- a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -126,7 +126,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups # 32 / 2 = 16 qkv_total_dim = head_num + 2 * num_query_groups # 32 + 2 * 2 = 36 diff --git a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py index 2b8156ad4b26..c47444534604 100644 --- a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py @@ -38,9 +38,9 @@ from argparse import ArgumentParser import torch +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPModel from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel diff --git a/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py b/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py index ae8885f4de93..8a880a290484 100644 --- a/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py @@ -32,7 +32,7 @@ import time from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import yaml from omegaconf import OmegaConf @@ -83,11 +83,11 @@ def get_new_key(old_key): def load_falcon_config(args) -> FalconConfig: - """ Helper utility to load FalconConfig. + """Helper utility to load FalconConfig. Legacy Falcon-7B and Falcon-40B are not compatible with `transformers.FalconConfig` and `transformers.FalconModel`. need to manually set the config values - and force to `falcon` model type. + and force to `falcon` model type. """ config = FalconConfig.from_pretrained(args.input_name_or_path) if config.model_type == 'RefinedWeb': diff --git a/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py b/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py index da8f15b92649..cc1d99b6d1c6 100644 --- a/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoModelForCausalLM from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py index 35039f8d02e9..61443a3bcb28 100644 --- a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py +++ b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py @@ -17,8 +17,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py index 4eb8cb6330ca..44de38497b44 100644 --- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py @@ -27,8 +27,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py index 42d3e77ce4c8..75bd0d0ab6ed 100644 --- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py +++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py @@ -28,8 +28,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py index f7096996e5b1..4a8a409a88fd 100644 --- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py +++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py @@ -27,8 +27,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py index a3c40676a980..87b7151aa961 100644 --- a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py @@ -17,8 +17,8 @@ from collections import OrderedDict import torch +from lightning.pytorch import Trainer from omegaconf import open_dict -from pytorch_lightning import Trainer from transformers import AutoModelForCausalLM, LlamaTokenizer, LlamaTokenizerFast, convert_slow_tokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -26,7 +26,7 @@ from nemo.utils import logging """ -Script to convert a llama2 checkpoint in nemo (mcore path) into a HuggingFace checkpoint. +Script to convert a llama checkpoint in nemo (mcore path) into a HuggingFace checkpoint. This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder. 1) Generate only HF weights from a nemo file: @@ -37,13 +37,21 @@ 2) Generate the full HF model folder + python convert_llama_nemo_to_hf.py \ + --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ + --output_path /path/to/pytorch_model.bin \ + --hf_input_path /path/to/input_hf_folder \ + --hf_output_path /path/to/output_hf_folder + +3) Generate the full HF model folder with a custom tokenizer + python convert_llama_nemo_to_hf.py \ --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ --output_path /path/to/pytorch_model.bin \ --hf_input_path /path/to/input_hf_folder \ --hf_output_path /path/to/output_hf_folder \ - --input_tokenizer /path/to/tokenizer \ - --hf_output_tokenizer /path/to/output_tokenizer \ + --input_tokenizer /path/to/custom_nemo_tokenizer.model \ + --hf_output_tokenizer /path/to/output_tokenizer Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Llama2 70b). However this option makes the conversion script significantly slower. @@ -143,7 +151,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups @@ -246,21 +254,25 @@ def replace_hf_weights_and_tokenizer( nemo_exported = torch.load(weights_file) if tokenizer_path: - tokenizer = LlamaTokenizer.from_pretrained( - tokenizer_path, - local_files_only=True, - legacy=False, - ) - tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer) - fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer) - tokenizer_length = len(fast_tokenizer) - model.resize_token_embeddings(tokenizer_length) + try: + tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_path, + local_files_only=True, + legacy=False, + ) + tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer) + fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer) + tokenizer_length = len(fast_tokenizer) + model.resize_token_embeddings(tokenizer_length) + except: + tokenizer = None + logging.warning("Could not load custom tokenizer, proceeding with default tokenizer") model.load_state_dict(nemo_exported) model.save_pretrained(output_hf_path) logging.info(f"Full HF model saved to {output_hf_path}") - if tokenizer_path: + if tokenizer_path and (tokenizer is not None): fast_tokenizer.save_pretrained(output_hf_tokenizer) tokenizer.save_pretrained(output_hf_tokenizer) logging.info(f"Tokenizer saved to {output_hf_tokenizer}") diff --git a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py index 4bceb250999f..3cf5bbd4acf9 100644 --- a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py @@ -29,9 +29,9 @@ import torch import torch.nn +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py index b8c30a1b929d..1f0a31076f8e 100644 --- a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py @@ -25,7 +25,7 @@ import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -134,7 +134,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = model.cfg.get('kv_channels', hidden_size // head_num) + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py index 36e4c0c2c3ea..a75c6876e70a 100644 --- a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py @@ -30,9 +30,9 @@ import megatron.core.parallel_state as parallel_state import torch import torch.nn +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py index 2bac2eaad616..eb934803f164 100644 --- a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py @@ -26,7 +26,7 @@ import megatron.core.parallel_state as parallel_state import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -137,7 +137,7 @@ def convert(in_file, precision=None) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py index e7d81f709092..d4a450a8e046 100644 --- a/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py @@ -56,7 +56,7 @@ import argparse import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import yaml from omegaconf import OmegaConf @@ -68,7 +68,11 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, default=None, required=True, help="Path to Huggingface MPT checkpoints", + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to Huggingface MPT checkpoints", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") parser.add_argument( diff --git a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py index fc0f660cbd42..392e3628ccdb 100644 --- a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py @@ -19,7 +19,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import LlamaTokenizer, PreTrainedTokenizerFast from transformers.convert_slow_tokenizer import LlamaConverter diff --git a/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py b/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py index 223c7af50843..b472a7e5c6f3 100644 --- a/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py @@ -25,8 +25,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import Qwen2ForCausalLM, Qwen2Tokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py index 6080499ffdf8..968caade917c 100644 --- a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import Qwen2ForCausalLM, Qwen2Tokenizer, Qwen2TokenizerFast, convert_slow_tokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -142,7 +142,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py b/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py index fc898c797a9e..862777cf52a8 100644 --- a/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py @@ -28,9 +28,9 @@ import torch import torch.nn +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -168,7 +168,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) diff --git a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py index 4b65533b74ec..c418a714be0a 100644 --- a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py @@ -25,7 +25,7 @@ import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -141,7 +141,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py b/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py index e600c65e6de1..6b9f30ab427b 100644 --- a/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py @@ -52,7 +52,7 @@ import os from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import yaml from omegaconf import OmegaConf diff --git a/scripts/checkpoint_converters/quantize_model_to_nf4.py b/scripts/checkpoint_converters/quantize_model_to_nf4.py index db3a48aaa16d..8fbaeb875f7a 100644 --- a/scripts/checkpoint_converters/quantize_model_to_nf4.py +++ b/scripts/checkpoint_converters/quantize_model_to_nf4.py @@ -16,7 +16,7 @@ from typing import List import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from torch import nn from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel diff --git a/scripts/confidence_ensembles/build_ensemble.py b/scripts/confidence_ensembles/build_ensemble.py index 4c05e2e4ff3f..dfb3793b42f4 100644 --- a/scripts/confidence_ensembles/build_ensemble.py +++ b/scripts/confidence_ensembles/build_ensemble.py @@ -80,8 +80,8 @@ from typing import Dict, List, Optional, Tuple import joblib +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl from omegaconf import MISSING, DictConfig, OmegaConf from sklearn.linear_model import LogisticRegression from sklearn.metrics import confusion_matrix @@ -215,7 +215,12 @@ class BuildEnsembleConfig: preserve_frame_confidence=True, exclude_blank=True, aggregation="mean", - method_cfg=ConfidenceMethodConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",), + method_cfg=ConfidenceMethodConfig( + name="entropy", + entropy_type="renyi", + alpha=0.25, + entropy_norm="lin", + ), ) ) temperature: float = 1.0 @@ -499,7 +504,12 @@ def find_best_confidence( dev_features = np.array(list(zip(*cur_dev_confidences))) dev_labels = np.array(dev_labels) pipe, score = train_model_selection( - training_features, training_labels, dev_features, dev_labels, tune_lr, tune_lr_config, + training_features, + training_labels, + dev_features, + dev_labels, + tune_lr, + tune_lr_config, ) if max_score < score: max_score = score @@ -513,7 +523,7 @@ def find_best_confidence( @hydra_runner(config_name="BuildEnsembleConfig", schema=BuildEnsembleConfig) def main(cfg: BuildEnsembleConfig): # silencing all messages from nemo/ptl to avoid dumping tons of configs to the stdout - logging.getLogger('pytorch_lightning').setLevel(logging.CRITICAL) + logging.getLogger('lightning.pytorch').setLevel(logging.CRITICAL) logging.getLogger('nemo_logger').setLevel(logging.CRITICAL) LOG.info(f'Build ensemble config:\n{OmegaConf.to_yaml(cfg)}') diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py index e3394726fa1c..154ffc90dc9c 100755 --- a/scripts/deploy/nlp/deploy_triton.py +++ b/scripts/deploy/nlp/deploy_triton.py @@ -419,13 +419,14 @@ def nemo_deploy(argv): LOGGER.info("Triton deploy function will be called.") nm.deploy() + nm.run() except Exception as error: LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) return try: LOGGER.info("Model serving on Triton is will be started.") - if args.start_rest_service == "True": + if args.start_rest_service: try: LOGGER.info("REST service will be started.") uvicorn.run( diff --git a/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py b/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py index 57d9964cad3d..a80d9d2639e3 100644 --- a/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py +++ b/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py @@ -16,7 +16,7 @@ from typing import Any, Dict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference diff --git a/scripts/export.py b/scripts/export.py index acfd3e3e3450..6e0b9b72e15b 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -30,8 +30,8 @@ import sys import torch +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer import nemo from nemo.core import ModelPT diff --git a/scripts/nemo_legacy_import/nlp_checkpoint_port.py b/scripts/nemo_legacy_import/nlp_checkpoint_port.py index b7541ffdb8cd..058f9e072f5f 100644 --- a/scripts/nemo_legacy_import/nlp_checkpoint_port.py +++ b/scripts/nemo_legacy_import/nlp_checkpoint_port.py @@ -30,7 +30,7 @@ import logging import sys -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf, open_dict from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector diff --git a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py index 334b3415a93b..3e96186552a5 100644 --- a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py @@ -14,7 +14,7 @@ import os -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import ( MegatronGPTPromptLearningModel, diff --git a/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py b/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py index 6a94e8f501bb..2361e000ef7e 100644 --- a/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py +++ b/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py @@ -53,8 +53,8 @@ from argparse import ArgumentParser import torch +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from transformers import AutoTokenizer, T5ForConditionalGeneration from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model diff --git a/scripts/nlp_language_modeling/merge_lora_weights/merge.py b/scripts/nlp_language_modeling/merge_lora_weights/merge.py index 55d50502705c..3a6d110997ba 100644 --- a/scripts/nlp_language_modeling/merge_lora_weights/merge.py +++ b/scripts/nlp_language_modeling/merge_lora_weights/merge.py @@ -33,8 +33,8 @@ from typing import Any, Dict, List import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py index ee32f69bf734..dd7c1a3656be 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py @@ -15,8 +15,8 @@ import os import torch +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer @@ -66,7 +66,10 @@ def main(cfg) -> None: save_restore_connector.model_extracted_dir = model_path model_cfg = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector, + model_path, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, ) with open_dict(model_cfg): @@ -76,7 +79,10 @@ def main(cfg) -> None: model_cfg.activations_checkpoint_method = None model = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, + model_path, + trainer=trainer, + save_restore_connector=save_restore_connector, + override_config_path=model_cfg, ) # check whether the DDP is initialized diff --git a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py index 9c42ef6cca5b..7208867ff938 100644 --- a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py +++ b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py @@ -18,7 +18,7 @@ from pathlib import Path from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import MISSING, OmegaConf from sklearn.model_selection import ParameterGrid diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index 3d5eb5a4dbb1..8d215cbc14eb 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -20,7 +20,7 @@ from typing import Iterable, Literal import click -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from lhotse import compute_num_samples from omegaconf import OmegaConf diff --git a/tests/collections/asr/confidence/test_asr_confidence.py b/tests/collections/asr/confidence/test_asr_confidence.py index 015264a9debe..89beb61f50bf 100644 --- a/tests/collections/asr/confidence/test_asr_confidence.py +++ b/tests/collections/asr/confidence/test_asr_confidence.py @@ -19,8 +19,8 @@ import numpy as np import pytest +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.asr.models import ASRModel, EncDecCTCModelBPE, EncDecRNNTBPEModel from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig diff --git a/tests/collections/asr/test_asr_context_biasing.py b/tests/collections/asr/test_asr_context_biasing.py index 0fa76fdfb95d..b23b12655a8d 100644 --- a/tests/collections/asr/test_asr_context_biasing.py +++ b/tests/collections/asr/test_asr_context_biasing.py @@ -19,7 +19,7 @@ import numpy as np import pytest import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.asr.models import EncDecCTCModelBPE from nemo.collections.asr.parts import context_biasing @@ -105,25 +105,43 @@ def test_merge_alignment_with_ws_hyps(self, conformer_ctc_bpe_model): # ctc argmax predictions preds = np.array([120, 29, blank_idx, blank_idx]) pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps( - preds, asr_model, ws_results, decoder_type="ctc", blank_idx=blank_idx, + preds, + asr_model, + ws_results, + decoder_type="ctc", + blank_idx=blank_idx, ) assert raw_text == "gp" assert pred_text == "gpu" # rnnt token predictions preds = rnnt_utils.Hypothesis( - y_sequence=torch.tensor([120, 29]), score=0.0, timestep=torch.tensor([0, 1, 2, 3]), + y_sequence=torch.tensor([120, 29]), + score=0.0, + timestep=torch.tensor([0, 1, 2, 3]), ) pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps( - preds, asr_model, ws_results, decoder_type="rnnt", blank_idx=blank_idx, + preds, + asr_model, + ws_results, + decoder_type="rnnt", + blank_idx=blank_idx, ) assert raw_text == "gp" assert pred_text == "gpu" # rnnt empty token predictions - preds = rnnt_utils.Hypothesis(y_sequence=[], score=0.0, timestep=[],) + preds = rnnt_utils.Hypothesis( + y_sequence=[], + score=0.0, + timestep=[], + ) pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps( - preds, asr_model, ws_results, decoder_type="rnnt", blank_idx=blank_idx, + preds, + asr_model, + ws_results, + decoder_type="rnnt", + blank_idx=blank_idx, ) assert raw_text == "" assert pred_text == "gpu" diff --git a/tests/collections/asr/test_asr_interctc_models.py b/tests/collections/asr/test_asr_interctc_models.py index 8d5e4b0b689c..a8d7101033ab 100644 --- a/tests/collections/asr/test_asr_interctc_models.py +++ b/tests/collections/asr/test_asr_interctc_models.py @@ -13,8 +13,8 @@ # limitations under the License. from typing import Dict +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from omegaconf import DictConfig, ListConfig @@ -68,7 +68,8 @@ def squeezeformer_encoder_config() -> Dict: class TestInterCTCLoss: @pytest.mark.unit @pytest.mark.parametrize( - "model_class", [EncDecCTCModel, EncDecHybridRNNTCTCModel], + "model_class", + [EncDecCTCModel, EncDecHybridRNNTCTCModel], ) @pytest.mark.parametrize( "encoder_config", @@ -241,10 +242,12 @@ def __getitem__(self, idx): trainer.fit( asr_model, train_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), val_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) required_metrics = ['final_loss'] if len(loss_weights) > 0 else [] @@ -264,7 +267,8 @@ def __getitem__(self, idx): trainer.test( asr_model, dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) required_metrics = [f'inter_ctc_loss_l{idx}' for idx in apply_at_layers] diff --git a/tests/collections/asr/test_asr_local_attn.py b/tests/collections/asr/test_asr_local_attn.py index 257dc0949af3..3013c0efbddf 100644 --- a/tests/collections/asr/test_asr_local_attn.py +++ b/tests/collections/asr/test_asr_local_attn.py @@ -15,8 +15,8 @@ import shutil import tempfile +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from omegaconf import DictConfig @@ -89,10 +89,12 @@ def test_change_save_restore(self): @pytest.mark.unit @pytest.mark.parametrize( - "global_tokens", [0, 1, 4], + "global_tokens", + [0, 1, 4], ) @pytest.mark.parametrize( - "global_tokens_spacing", [1, 4], + "global_tokens_spacing", + [1, 4], ) def test_train(self, global_tokens, global_tokens_spacing): preprocessor_config = {'_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor'} @@ -178,15 +180,18 @@ def __getitem__(self, idx): trainer.fit( asr_model, train_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), val_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) trainer.test( asr_model, dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index 98f733f1c568..18ee04e371e2 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -15,13 +15,13 @@ import os.path from typing import Any, Dict, Union +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.types import STEP_OUTPUT from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Callback, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT from nemo.collections.common.callbacks import EMA from nemo.collections.common.callbacks.ema import EMAOptimizer @@ -349,7 +349,12 @@ class TestEMATrain: @pytest.mark.parametrize("validate_original_weights", [True, False]) @pytest.mark.run_only_on('GPU') def test_ema_run_cuda( - self, test_data_dir, precision, accumulate_grad_batches, validate_original_weights, tmpdir, + self, + test_data_dir, + precision, + accumulate_grad_batches, + validate_original_weights, + tmpdir, ): self.run_training_test( accumulate_grad_batches=accumulate_grad_batches, diff --git a/tests/collections/llm/common.py b/tests/collections/llm/common.py index 95b8bc0de584..c17243936bd1 100644 --- a/tests/collections/llm/common.py +++ b/tests/collections/llm/common.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from nemo import lightning as nl diff --git a/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py b/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py index d7ecaafaaf8c..55bea59d6274 100644 --- a/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py +++ b/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py @@ -16,9 +16,11 @@ ## There are no guarantees that this script is up-to-date with latest NeMo. import argparse + import torch +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import TensorBoardLogger + from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.api import train diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py index 09a52668e3ee..0415569304ac 100644 --- a/tests/collections/llm/lora_mistralai.py +++ b/tests/collections/llm/lora_mistralai.py @@ -14,7 +14,7 @@ import argparse -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from megatron.core.optimizer import OptimizerConfig diff --git a/tests/collections/llm/megatron_gpt_pretraining.py b/tests/collections/llm/megatron_gpt_pretraining.py index a73b2a694c76..9722ba9d6c68 100644 --- a/tests/collections/llm/megatron_gpt_pretraining.py +++ b/tests/collections/llm/megatron_gpt_pretraining.py @@ -18,8 +18,8 @@ import argparse import torch +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import TensorBoardLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/collections/llm/megatron_t5_finetuning.py b/tests/collections/llm/megatron_t5_finetuning.py index e8f4947c9674..976ad5c48053 100644 --- a/tests/collections/llm/megatron_t5_finetuning.py +++ b/tests/collections/llm/megatron_t5_finetuning.py @@ -18,8 +18,8 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm @@ -103,7 +103,7 @@ def get_args(): optimizer='adam', lr=2.0e-5, use_distributed_optimizer=False, - bf16=False, + bf16=True, weight_decay=0.1, ) opt = MegatronOptimizerModule( @@ -124,7 +124,7 @@ def get_args(): log_every_n_steps=1, limit_val_batches=2, val_check_interval=50, - plugins=nl.MegatronMixedPrecision(precision="32"), + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), ) if args.wandb_project is not None: diff --git a/tests/collections/llm/megatron_t5_pretraining.py b/tests/collections/llm/megatron_t5_pretraining.py index a5460be3d154..ad63ae88fb73 100644 --- a/tests/collections/llm/megatron_t5_pretraining.py +++ b/tests/collections/llm/megatron_t5_pretraining.py @@ -18,8 +18,8 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py index a5c2aa96fc03..92cffc2a35bb 100644 --- a/tests/collections/llm/test_mnist_model_nemo2.py +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -23,16 +23,16 @@ from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, TypeVar, Union +import lightning.pytorch as pl import megatron.core.num_microbatches_calculator import pytest -import pytorch_lightning as pl import torch import torch.distributed +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core import ModelParallelConfig, parallel_state from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.enums import ModelType from megatron.core.transformer.module import MegatronModule -from pytorch_lightning.loggers import TensorBoardLogger from torch import Tensor, nn from torch.utils.data import DataLoader from torchvision import transforms diff --git a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py index 8a6c1f993d28..da45d0e1fc38 100644 --- a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py +++ b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py @@ -23,16 +23,16 @@ from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, TypeVar, Union +import lightning.pytorch as pl import megatron.core.num_microbatches_calculator import pytest -import pytorch_lightning as pl import torch import torch.distributed +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core import ModelParallelConfig, parallel_state from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.enums import ModelType from megatron.core.transformer.module import MegatronModule -from pytorch_lightning.loggers import TensorBoardLogger from torch import Tensor, nn from torch.optim import Adam from torch.utils.data import DataLoader diff --git a/tests/collections/multimodal/test_speechllm_models.py b/tests/collections/multimodal/test_speechllm_models.py index 8698fed205ea..09149064b657 100644 --- a/tests/collections/multimodal/test_speechllm_models.py +++ b/tests/collections/multimodal/test_speechllm_models.py @@ -16,13 +16,13 @@ import tempfile from pathlib import Path +import lightning.pytorch as pl import numpy as np import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from megatron.core import parallel_state from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.multimodal.speech_llm.models import modular_models from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import shift_tokens_by_multi_audios diff --git a/tests/collections/nlp/test_falcon_model.py b/tests/collections/nlp/test_falcon_model.py index 23430ad36300..62a4591092a9 100644 --- a/tests/collections/nlp/test_falcon_model.py +++ b/tests/collections/nlp/test_falcon_model.py @@ -14,8 +14,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/tests/collections/nlp/test_flash_attention.py b/tests/collections/nlp/test_flash_attention.py index f5585ddc1636..c8309b34b433 100644 --- a/tests/collections/nlp/test_flash_attention.py +++ b/tests/collections/nlp/test_flash_attention.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import ModelParallelConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.modules.common.megatron.attention import CoreAttention from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo diff --git a/tests/collections/nlp/test_gpt_eval.py b/tests/collections/nlp/test_gpt_eval.py index fb3f9fda5ac3..020185ec7385 100644 --- a/tests/collections/nlp/test_gpt_eval.py +++ b/tests/collections/nlp/test_gpt_eval.py @@ -16,7 +16,7 @@ import numpy as np import pytest -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam diff --git a/tests/collections/nlp/test_gpt_model.py b/tests/collections/nlp/test_gpt_model.py index 7b6c02f948a4..334167f3dcf8 100644 --- a/tests/collections/nlp/test_gpt_model.py +++ b/tests/collections/nlp/test_gpt_model.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index b404764e7eed..6da0f8c93cc0 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -14,9 +14,9 @@ import os import tempfile +import lightning.pytorch as pl import onnx import pytest -import pytorch_lightning as pl import torch import wget from omegaconf import DictConfig, OmegaConf diff --git a/tests/collections/nlp/test_pretrained_models_performance.py b/tests/collections/nlp/test_pretrained_models_performance.py index 82ff6ed103f1..b51f00681f57 100644 --- a/tests/collections/nlp/test_pretrained_models_performance.py +++ b/tests/collections/nlp/test_pretrained_models_performance.py @@ -17,8 +17,8 @@ from shutil import rmtree from unittest import TestCase +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl from omegaconf import OmegaConf import nemo.collections.nlp.models as models diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index c7efb5f57f4c..763dfaaf3c51 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy diff --git a/tests/collections/nlp/test_retrieval_module.py b/tests/collections/nlp/test_retrieval_module.py index 426e393c85bf..381d009f0e02 100644 --- a/tests/collections/nlp/test_retrieval_module.py +++ b/tests/collections/nlp/test_retrieval_module.py @@ -16,7 +16,7 @@ import pytest import torch from einops import rearrange -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.modules.common.megatron.attention import ParallelChunkedCrossAttention from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType @@ -73,7 +73,13 @@ def setup_class(cls): MB_SIZE = 4 GB_SIZE = 8 SEED = 1234 - trainer = Trainer(strategy=NLPDDPStrategy(), devices=GPUS, accelerator='gpu', num_nodes=1, logger=None,) + trainer = Trainer( + strategy=NLPDDPStrategy(), + devices=GPUS, + accelerator='gpu', + num_nodes=1, + logger=None, + ) initialize_model_parallel_for_nemo( world_size=trainer.world_size, @@ -134,7 +140,9 @@ def test_cross_attn(self, model_parallel_config): dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks) context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)') enc_dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding, + source_mask=dec_attn_mask, + target_mask=context_attn_mask, + attn_mask_type=AttnMaskType.padding, ) enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :] diff --git a/tests/collections/nlp/test_retrieval_module_inference.py b/tests/collections/nlp/test_retrieval_module_inference.py index ccb426ce4ab1..a7da05340708 100644 --- a/tests/collections/nlp/test_retrieval_module_inference.py +++ b/tests/collections/nlp/test_retrieval_module_inference.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from einops import rearrange -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.modules.common.megatron.attention import ParallelChunkedCrossAttention from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType @@ -73,7 +73,13 @@ def setup_class(cls): MB_SIZE = 4 GB_SIZE = 8 SEED = 1234 - trainer = Trainer(strategy=NLPDDPStrategy(), devices=GPUS, accelerator='gpu', num_nodes=1, logger=None,) + trainer = Trainer( + strategy=NLPDDPStrategy(), + devices=GPUS, + accelerator='gpu', + num_nodes=1, + logger=None, + ) initialize_model_parallel_for_nemo( world_size=trainer.world_size, @@ -176,15 +182,33 @@ def test_retrieval_encoder_inference(self, model_parallel_config): neighbors=neighbors, ) assert (encoder.encoder_output - hidden_emb[:, :64]).abs().max().item() < 1e-5 - assert (out_gt[:, 0,] - out_2[:, 0]).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + 0, + ] + - out_2[:, 0] + ).abs().max().item() < 1e-2 out_test = encoder( retrieved_emb[:, :1], context_mask[:, :1], context_attn_mask=hidden_mask[:, :64], encoder_output=hidden_emb[:, :64], ) - assert (out_gt[:, 0,] - out_test[:, 0]).abs().max().item() < 1e-2 - assert (out_gt[:, 0,] - out_2[:, 0]).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + 0, + ] + - out_test[:, 0] + ).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + 0, + ] + - out_2[:, 0] + ).abs().max().item() < 1e-2 for i in range(64, 127): out_3 = encoder( @@ -207,7 +231,13 @@ def test_retrieval_encoder_inference(self, model_parallel_config): neighbors=neighbors, ) assert (encoder.encoder_output - hidden_emb[:, 64:128]).abs().max().item() < 1e-5 - assert (out_gt[:, :2,] - out_3).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + :2, + ] + - out_3 + ).abs().max().item() < 1e-2 # test inference for i in range(128, 191): out_4 = encoder( @@ -231,7 +261,13 @@ def test_retrieval_encoder_inference(self, model_parallel_config): ) assert (encoder.encoder_output - hidden_emb[:, 128:192]).abs().max().item() < 1e-5 - assert (out_gt[:, :3,] - out_4).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + :3, + ] + - out_4 + ).abs().max().item() < 1e-2 out_2 = encoder( retrieved_emb[:, :2], @@ -263,7 +299,13 @@ def test_retrieval_encoder_inference(self, model_parallel_config): neighbors=neighbors, ) assert (encoder.encoder_output - hidden_emb[:, 128:192]).abs().max().item() < 1e-5 - assert (out_gt[:, :3,] - out_4).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + :3, + ] + - out_4 + ).abs().max().item() < 1e-2 @pytest.mark.unit def test_cross_attn_inference(self, model_parallel_config): @@ -309,7 +351,9 @@ def get_attn_mask_3d(hidden_mask, context_mask, chunks): dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks) context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)') enc_dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding, + source_mask=dec_attn_mask, + target_mask=context_attn_mask, + attn_mask_type=AttnMaskType.padding, ) enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :] return enc_dec_attn_mask_3d diff --git a/tests/collections/nlp/test_retro_model.py b/tests/collections/nlp/test_retro_model.py index b96016c8d7ec..e91590915ba5 100644 --- a/tests/collections/nlp/test_retro_model.py +++ b/tests/collections/nlp/test_retro_model.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids diff --git a/tests/core/test_config_utils.py b/tests/core/test_config_utils.py index bb0a0f177dfb..9716fc160629 100644 --- a/tests/core/test_config_utils.py +++ b/tests/core/test_config_utils.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from typing import Any +import lightning.pytorch as ptl import pytest -import pytorch_lightning as ptl -from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.callbacks.early_stopping import EarlyStopping from nemo.core.config.pytorch_lightning import TrainerConfig from nemo.utils import config_utils @@ -126,7 +126,9 @@ def test_ptl_config(self): assert dataclass_subset is None @pytest.mark.unit - def test_early_stopping_config(self,): + def test_early_stopping_config( + self, + ): result = config_utils.assert_dataclass_signature_match(EarlyStopping, EarlyStoppingParams) signatures_match, cls_subset, dataclass_subset = result diff --git a/tests/core/test_dist_ckpt.py b/tests/core/test_dist_ckpt.py index 0a483c0f58ab..6c066d1856a2 100644 --- a/tests/core/test_dist_ckpt.py +++ b/tests/core/test_dist_ckpt.py @@ -17,11 +17,11 @@ from pathlib import Path from typing import Any, Dict +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch -from lightning_fabric.plugins import TorchCheckpointIO -from pytorch_lightning.demos.boring_classes import BoringModel +from lightning.fabric.plugins import TorchCheckpointIO +from lightning.pytorch.demos.boring_classes import BoringModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy from nemo.utils.callbacks.dist_ckpt_io import ( diff --git a/tests/core/test_exp_manager.py b/tests/core/test_exp_manager.py index d4b1d37c1938..32d401b2051f 100644 --- a/tests/core/test_exp_manager.py +++ b/tests/core/test_exp_manager.py @@ -18,13 +18,13 @@ from pathlib import Path from typing import Any +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch import Callback +from lightning.pytorch.loops import _TrainingEpochLoop from omegaconf import OmegaConf from omegaconf.errors import OmegaConfBaseException -from pytorch_lightning import Callback -from pytorch_lightning.loops import _TrainingEpochLoop from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy from nemo.constants import NEMO_ENV_VARNAME_VERSION diff --git a/tests/core/test_fault_tolerance.py b/tests/core/test_fault_tolerance.py index 5b4e0ecba4aa..f916a7b44454 100644 --- a/tests/core/test_fault_tolerance.py +++ b/tests/core/test_fault_tolerance.py @@ -13,8 +13,8 @@ # limitations under the License. import os +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl from nemo.utils.exp_manager import exp_manager diff --git a/tests/core/test_optimizers_schedulers.py b/tests/core/test_optimizers_schedulers.py index 5e5d1ee20c83..419db309a918 100644 --- a/tests/core/test_optimizers_schedulers.py +++ b/tests/core/test_optimizers_schedulers.py @@ -15,12 +15,12 @@ import math import random +import lightning.pytorch as pl import omegaconf import pytest -import pytorch_lightning as pl import torch import torch.optim -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.utilities import rank_zero_only from nemo.core import config, optim from nemo.core.optim.lr_scheduler import AVAILABLE_SCHEDULERS @@ -936,7 +936,13 @@ def train( enable_progress_bar=False, ) max_steps = optim.lr_scheduler.compute_max_steps( - max_epochs, accumulate_grad_batches, limit_train_batches, devices, dataset_len, batch_size, drop_last, + max_epochs, + accumulate_grad_batches, + limit_train_batches, + devices, + dataset_len, + batch_size, + drop_last, ) model = ExampleModel(batch_size, dataset_len, drop_last, max_steps) trainer.callbacks.append(Callback()) @@ -991,7 +997,13 @@ def train( dataset_len = random.randint(20, devices * 500) batch_size = random.randint(math.ceil(5.0 / devices), min(dataset_len // devices, 128)) train( - max_epochs, accumulate_grad_batches, limit_train_batches, devices, batch_size, dataset_len, drop_last, + max_epochs, + accumulate_grad_batches, + limit_train_batches, + devices, + batch_size, + dataset_len, + drop_last, ) @pytest.mark.unit diff --git a/tests/core/test_straggler_det.py b/tests/core/test_straggler_det.py index ee5222854889..1f938214d792 100644 --- a/tests/core/test_straggler_det.py +++ b/tests/core/test_straggler_det.py @@ -14,8 +14,8 @@ import sys +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from omegaconf import OmegaConf diff --git a/tests/core_ptl/check_for_ranks.py b/tests/core_ptl/check_for_ranks.py index a1eae66790c4..dfbc05166c5a 100644 --- a/tests/core_ptl/check_for_ranks.py +++ b/tests/core_ptl/check_for_ranks.py @@ -16,9 +16,9 @@ import shutil import torch +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.utilities import rank_zero_only from nemo.core import ModelPT from nemo.utils import logging diff --git a/tests/core_ptl/check_manual_upload_to_hf_hub.py b/tests/core_ptl/check_manual_upload_to_hf_hub.py index f411ee72332c..912eabb805bf 100644 --- a/tests/core_ptl/check_manual_upload_to_hf_hub.py +++ b/tests/core_ptl/check_manual_upload_to_hf_hub.py @@ -14,7 +14,7 @@ import shutil from huggingface_hub import HfApi -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.utilities import rank_zero_only from nemo.core import ModelPT from nemo.utils import AppState, logging @@ -40,7 +40,9 @@ def load_model_from_unpacked_hf_dir(repo_id): def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): # Upload the model to HF Hub model.push_to_hf_hub( - repo_id=repo_id, pack_nemo_file=True, token=token, + repo_id=repo_id, + pack_nemo_file=True, + token=token, ) @@ -48,7 +50,9 @@ def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): # Upload the model to HF Hub model.push_to_hf_hub( - repo_id=repo_id, pack_nemo_file=True, token=token, + repo_id=repo_id, + pack_nemo_file=True, + token=token, ) @@ -56,7 +60,9 @@ def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): def upload_model_as_unpacked_files(model: ModelPT, repo_id, token): # Upload the model to HF Hub model.push_to_hf_hub( - repo_id=repo_id, pack_nemo_file=False, token=token, + repo_id=repo_id, + pack_nemo_file=False, + token=token, ) diff --git a/tests/core_ptl/test_ptl_stateless_timer.py b/tests/core_ptl/test_ptl_stateless_timer.py index 25f354a23c0d..5cfbbda39bbf 100644 --- a/tests/core_ptl/test_ptl_stateless_timer.py +++ b/tests/core_ptl/test_ptl_stateless_timer.py @@ -17,8 +17,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.core import ModelPT from nemo.utils import logging diff --git a/tests/lightning/_fabric/test_conversion.py b/tests/lightning/_fabric/test_conversion.py index e690557ec2eb..e97e766c86a7 100644 --- a/tests/lightning/_fabric/test_conversion.py +++ b/tests/lightning/_fabric/test_conversion.py @@ -13,10 +13,10 @@ # limitations under the License. import pytest -from lightning_fabric import plugins as fl_plugins -from lightning_fabric import strategies as fl_strategies -from pytorch_lightning import plugins as pl_plugins -from pytorch_lightning import strategies as pl_strategies +from lightning.fabric import plugins as fl_plugins +from lightning.fabric import strategies as fl_strategies +from lightning.pytorch import plugins as pl_plugins +from lightning.pytorch import strategies as pl_strategies from nemo import lightning as nl from nemo.lightning.fabric.conversion import to_fabric diff --git a/tests/lightning/_io/test_api.py b/tests/lightning/_io/test_api.py index a4d458cef17b..e0aaac1a6aa2 100644 --- a/tests/lightning/_io/test_api.py +++ b/tests/lightning/_io/test_api.py @@ -19,7 +19,7 @@ import fiddle as fdl import pytest import yaml -from pytorch_lightning.loggers import TensorBoardLogger +from lightning.pytorch.loggers import TensorBoardLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/lightning/pytorch/callbacks/test_model_checkpoint.py b/tests/lightning/pytorch/callbacks/test_model_checkpoint.py index 802f2b28c25c..edaa8a6f4ec9 100644 --- a/tests/lightning/pytorch/callbacks/test_model_checkpoint.py +++ b/tests/lightning/pytorch/callbacks/test_model_checkpoint.py @@ -17,12 +17,12 @@ from pathlib import Path from typing import Iterator, Optional, Sequence, Tuple +import lightning.pytorch as pl import megatron import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from megatron.core import ModelParallelConfig, parallel_state -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch import Tensor import nemo.lightning as nl diff --git a/tests/lightning/pytorch/callbacks/test_model_transform.py b/tests/lightning/pytorch/callbacks/test_model_transform.py index c59a82895125..cfae55cf99a9 100644 --- a/tests/lightning/pytorch/callbacks/test_model_transform.py +++ b/tests/lightning/pytorch/callbacks/test_model_transform.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl from torch import nn from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform diff --git a/tests/lightning/pytorch/callbacks/test_peft.py b/tests/lightning/pytorch/callbacks/test_peft.py index 49a6aa0784aa..fb6728acee8f 100644 --- a/tests/lightning/pytorch/callbacks/test_peft.py +++ b/tests/lightning/pytorch/callbacks/test_peft.py @@ -15,7 +15,7 @@ from unittest.mock import MagicMock, call, patch import torch.nn as nn -from pytorch_lightning.trainer.states import TrainerFn +from lightning.pytorch.trainer.states import TrainerFn from nemo.collections.llm import fn from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO diff --git a/tests/lightning/pytorch/callbacks/test_preemption.py b/tests/lightning/pytorch/callbacks/test_preemption.py index 4152f7fcce59..802d898c5a2b 100644 --- a/tests/lightning/pytorch/callbacks/test_preemption.py +++ b/tests/lightning/pytorch/callbacks/test_preemption.py @@ -17,7 +17,7 @@ import pytest import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback diff --git a/tests/lightning/test_dist_ckpt.py b/tests/lightning/test_dist_ckpt.py index 886b1085ed55..107d15061792 100644 --- a/tests/lightning/test_dist_ckpt.py +++ b/tests/lightning/test_dist_ckpt.py @@ -21,8 +21,8 @@ def set_env(): from pathlib import Path +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch import nemo.lightning as nl diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index a5a5ec32c886..8a63a92f0ee6 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -19,8 +19,8 @@ from unittest.mock import patch import pytest -from pytorch_lightning.callbacks import ModelCheckpoint as PTLModelCheckpoint -from pytorch_lightning.loggers import WandbLogger +from lightning.pytorch.callbacks import ModelCheckpoint as PTLModelCheckpoint +from lightning.pytorch.loggers import WandbLogger from nemo import lightning as nl from nemo.constants import NEMO_ENV_VARNAME_VERSION diff --git a/tests/lightning/test_precision_plugin.py b/tests/lightning/test_precision_plugin.py index 44ffa5939fab..960e658187c5 100644 --- a/tests/lightning/test_precision_plugin.py +++ b/tests/lightning/test_precision_plugin.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from megatron.core.optimizer import OptimizerConfig diff --git a/tests/lightning/test_state_restoration.py b/tests/lightning/test_state_restoration.py index ccc0eed64d56..59c5cc2234f7 100644 --- a/tests/lightning/test_state_restoration.py +++ b/tests/lightning/test_state_restoration.py @@ -17,8 +17,8 @@ import pytest import torch +from lightning.pytorch.callbacks import Callback from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/utils/test_trainer_utils.py b/tests/utils/test_trainer_utils.py index 55eee92a523c..251e59d4b648 100644 --- a/tests/utils/test_trainer_utils.py +++ b/tests/utils/test_trainer_utils.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.strategies import DDPStrategy from omegaconf import OmegaConf -from pytorch_lightning.strategies import DDPStrategy from nemo.utils.trainer_utils import resolve_trainer_cfg @@ -25,7 +25,7 @@ def test_resolve_trainer_cfg_strategy(): assert ans["strategy"] == "ddp" cfg = OmegaConf.create( - {"strategy": {"_target_": "pytorch_lightning.strategies.DDPStrategy", "gradient_as_bucket_view": True}} + {"strategy": {"_target_": "lightning.pytorch.strategies.DDPStrategy", "gradient_as_bucket_view": True}} ) ans = resolve_trainer_cfg(cfg) assert isinstance(ans, dict) diff --git a/tutorials/01_NeMo_Models.ipynb b/tutorials/01_NeMo_Models.ipynb index 4255a6656b8a..eb76e00cd981 100644 --- a/tutorials/01_NeMo_Models.ipynb +++ b/tutorials/01_NeMo_Models.ipynb @@ -984,7 +984,7 @@ "id": "0TsfmCYthMux" }, "source": [ - "import pytorch_lightning as ptl\n", + "import lightning.pytorch as ptl\n", "from nemo.core import ModelPT\n", "from omegaconf import OmegaConf" ], diff --git a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb index a02ee4f99714..6ad3307da496 100644 --- a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb +++ b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb @@ -1292,7 +1292,7 @@ }, "source": [ "import torch\n", - "import pytorch_lightning as ptl\n", + "import lightning.pytorch as ptl\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", @@ -2088,7 +2088,7 @@ }, "source": [ "import torch\n", - "import pytorch_lightning as ptl\n", + "import lightning.pytorch as ptl\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", diff --git a/tutorials/asr/ASR_TTS_Tutorial.ipynb b/tutorials/asr/ASR_TTS_Tutorial.ipynb index 709f96d14ba5..544255f76d06 100644 --- a/tutorials/asr/ASR_TTS_Tutorial.ipynb +++ b/tutorials/asr/ASR_TTS_Tutorial.ipynb @@ -172,7 +172,7 @@ "import tempfile\n", "\n", "from omegaconf import OmegaConf\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import torch\n", "from tqdm.auto import tqdm\n", "import wget\n", diff --git a/tutorials/asr/ASR_with_NeMo.ipynb b/tutorials/asr/ASR_with_NeMo.ipynb index bd95c7194655..bb62e2f5eb9d 100644 --- a/tutorials/asr/ASR_with_NeMo.ipynb +++ b/tutorials/asr/ASR_with_NeMo.ipynb @@ -619,7 +619,7 @@ "id": "GUfR6tAK0k2u" }, "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=50)" ], "execution_count": null, diff --git a/tutorials/asr/ASR_with_Subword_Tokenization.ipynb b/tutorials/asr/ASR_with_Subword_Tokenization.ipynb index ff15a5f75532..7a69735ae542 100644 --- a/tutorials/asr/ASR_with_Subword_Tokenization.ipynb +++ b/tutorials/asr/ASR_with_Subword_Tokenization.ipynb @@ -765,7 +765,7 @@ "id": "3rslHEKeq9qy" }, "source": [ - "import pytorch_lightning as pl\r\n", + "import lightning.pytorch as pl\r\n", "trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=50)" ], "execution_count": null, diff --git a/tutorials/asr/ASR_with_Transducers.ipynb b/tutorials/asr/ASR_with_Transducers.ipynb index d20042b9b970..95eecbfb8916 100644 --- a/tutorials/asr/ASR_with_Transducers.ipynb +++ b/tutorials/asr/ASR_with_Transducers.ipynb @@ -754,7 +754,7 @@ "outputs": [], "source": [ "import torch\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", diff --git a/tutorials/asr/Confidence_Ensembles.ipynb b/tutorials/asr/Confidence_Ensembles.ipynb index 734ddc9a0604..5a999df304b0 100644 --- a/tutorials/asr/Confidence_Ensembles.ipynb +++ b/tutorials/asr/Confidence_Ensembles.ipynb @@ -214,7 +214,7 @@ "# check out https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb\n", "# to learn more about finetuning NeMo ASR models\n", "from omegaconf import open_dict, OmegaConf\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE\n", "import nemo.utils.exp_manager as exp_manager\n", diff --git a/tutorials/asr/Multilang_ASR.ipynb b/tutorials/asr/Multilang_ASR.ipynb index 612271a8baab..800f8a2d2ded 100644 --- a/tutorials/asr/Multilang_ASR.ipynb +++ b/tutorials/asr/Multilang_ASR.ipynb @@ -1527,7 +1527,7 @@ "outputs": [], "source": [ "import torch\n", - "import pytorch_lightning as ptl" + "import lightning.pytorch as ptl" ] }, { diff --git a/tutorials/asr/Self_Supervised_Pre_Training.ipynb b/tutorials/asr/Self_Supervised_Pre_Training.ipynb index c2e1e7362b3e..0506bafb56e3 100644 --- a/tutorials/asr/Self_Supervised_Pre_Training.ipynb +++ b/tutorials/asr/Self_Supervised_Pre_Training.ipynb @@ -433,7 +433,7 @@ }, "outputs": [], "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf\n", "\n", "from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel\n", diff --git a/tutorials/asr/Speech_Commands.ipynb b/tutorials/asr/Speech_Commands.ipynb index 438533f0f03a..c8a54e5135b2 100644 --- a/tutorials/asr/Speech_Commands.ipynb +++ b/tutorials/asr/Speech_Commands.ipynb @@ -408,7 +408,7 @@ }, "source": [ "import torch\n", - "import pytorch_lightning as pl" + "import lightning.pytorch as pl" ], "execution_count": null, "outputs": [] diff --git a/tutorials/asr/Transducers_with_HF_Datasets.ipynb b/tutorials/asr/Transducers_with_HF_Datasets.ipynb index a47cd00a0b9a..82f17fe8c1ac 100644 --- a/tutorials/asr/Transducers_with_HF_Datasets.ipynb +++ b/tutorials/asr/Transducers_with_HF_Datasets.ipynb @@ -554,7 +554,7 @@ "outputs": [], "source": [ "import torch\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", diff --git a/tutorials/asr/Voice_Activity_Detection.ipynb b/tutorials/asr/Voice_Activity_Detection.ipynb index 123a03efc28e..fb3cef1b44ea 100644 --- a/tutorials/asr/Voice_Activity_Detection.ipynb +++ b/tutorials/asr/Voice_Activity_Detection.ipynb @@ -425,7 +425,7 @@ "outputs": [], "source": [ "import torch\n", - "import pytorch_lightning as pl" + "import lightning.pytorch as pl" ] }, { diff --git a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb index c9c547a8383e..c3334a59b0d2 100644 --- a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb +++ b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb @@ -260,7 +260,7 @@ "source": [ "import torch\n", "from omegaconf import OmegaConf, open_dict\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "import nemo.collections.asr as nemo_asr" ], diff --git a/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb index cb364ab7396d..0d35feb11a9a 100644 --- a/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb +++ b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb @@ -908,7 +908,7 @@ "\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", - "import pytorch_lightning as L\n", + "import lightning.pytorch as L\n", "\n", "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", "\n", diff --git a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb index 5c697840ba09..faef27d18abf 100644 --- a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb +++ b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb @@ -91,7 +91,7 @@ "import IPython.display as ipd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import soundfile as sf\n", "\n", "from omegaconf import OmegaConf, open_dict\n", diff --git a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb index ff6970d98522..e8b734537a41 100644 --- a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb +++ b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb @@ -93,7 +93,7 @@ "import IPython.display as ipd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import soundfile as sf\n", "from pathlib import Path\n", "from omegaconf import OmegaConf, open_dict\n", @@ -981,4 +981,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tutorials/llm/llama-3/README.rst b/tutorials/llm/llama-3/README.rst index 3bb1a0896b82..1d12b8847c0d 100755 --- a/tutorials/llm/llama-3/README.rst +++ b/tutorials/llm/llama-3/README.rst @@ -2,7 +2,7 @@ Getting Started with Llama 3 and Llama 3.1 ========================================== -This repository contains jupyter notebook tutorials using NeMo Framework for Llama-3 and Llama-3.1 models by Meta. +This repository contains Jupyter Notebook tutorials using the NeMo Framework for Llama-3 and Llama-3.1 models by Meta. .. list-table:: :widths: 100 25 100 @@ -16,7 +16,7 @@ This repository contains jupyter notebook tutorials using NeMo Framework for Lla - Perform LoRA PEFT on Llama 3 8B Instruct using a dataset for bio-medical domain question answering. Deploy multiple LoRA adapters with NVIDIA NIM. * - `Llama 3.1 Law-Domain LoRA Fine-Tuning and Deployment with NeMo Framework and NVIDIA NIM <./sdg-law-title-generation>`_ - `Law StackExchange `_ - - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a pre-requisite, follow the tutorial for `data curation using NeMo Curator `__. + - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a prerequisite, follow the tutorial for `data curation using NeMo Curator `_. * - `Llama 3.1 Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ - `WikiText-103-v1 `_ - Perform pruning and distillation on Llama 3.1 8B using the WikiText-103-v1 dataset with NeMo Framework. diff --git a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb index 1f84dd2719e6..8548c0cfb1d0 100644 --- a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb @@ -9,7 +9,7 @@ "\n", "The dataset has to be preprocessed using the [preprocess_data_for_megatron.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/preprocess_data_for_megatron.py) script included in the NeMo Framework. This step will also tokenize data using the `meta-llama/Meta-Llama-3.1-8B` tokenizer model to convert the data into a memory map format.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your train, test and validation data files." + "> `NOTE:` In the block of code below, pass the paths to your train, test, and validation data files." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb index 8d08793bbe9a..7d58ac4779aa 100644 --- a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb @@ -6,15 +6,15 @@ "metadata": {}, "source": [ "\n", - "### Step 2: Finetune the teacher on the dataset\n", + "### Step 2: Fine-tune the teacher on the dataset\n", "\n", - "NeMo framework includes a standard python script [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py) for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", + "NeMo Framework includes a standard Python script, [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py), for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", "\n", - "We finetune the unpruned model on our dataset to correct the distribution shift across the original dataset the model was trained on. Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that, without correcting for the distribution shift, the teacher provides suboptimal guidance on the dataset when being distilled.\n", + "We fine-tune the unpruned model on our dataset to correct the distribution shift from the original dataset the model was trained on. According to the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that without correcting for this distribution shift, the teacher provides suboptimal guidance on the dataset during distillation.\n", "\n", "For this demonstration, this training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher .nemo model." + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as the path to the teacher .nemo model." ] }, { @@ -124,8 +124,8 @@ "id": "3040a993-8423-475f-8bc6-d1dd1ce16a83", "metadata": {}, "source": [ - "This will create a finetuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", - "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the finetuned teacher model." + "This will create a fine-tuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", + "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the fine-tuned teacher model." ] } ], diff --git a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb index a195c2f3a405..d64f8c15bd00 100644 --- a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb @@ -5,8 +5,8 @@ "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", "metadata": {}, "source": [ - "### Step 3: Prune the finetuned-teacher model to create a student\n", - "In this step, we will explore two methods to prune the finetuned teacher model. Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore.\n", + "### Step 3: Prune the fine-tuned teacher model to create a student\n", + "In this step, we will explore two methods to prune the fine-tuned teacher model. Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore.\n", "\n", "In the first method, depth-pruning, we trim the layers of the model." ] @@ -21,7 +21,7 @@ "\n", "Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), removing contiguous layers from the second last block (layers 16 to 31 continuously) yields the best overall results. \n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model." + "> `NOTE:` In the block of code below, pass the paths to your fine-tuned teacher .nemo model." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb index 7d91d36cbb32..5c4a47872afb 100644 --- a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb @@ -5,8 +5,8 @@ "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", "metadata": {}, "source": [ - "### Step 3: Prune the finetuned-teacher model to create a student\n", - "In the second method, we will width-prune. In width-pruning, we trim the neurons, attention heads and embedding channels. \n", + "### Step 3: Step 3: Prune the fine-tuned teacher model to create a student\n", + "In the second method, we will width-prune. In width-pruning, we trim the neurons, attention heads, and embedding channels.\n", "\n", "Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore." ] @@ -20,15 +20,15 @@ "source": [ "#### Step 3.b.: Using width-pruning\n", "To width-prune the model, we do the following:\n", - "- prune (trim) the MLP intermediate dimension from 14336 to 9216.\n", - "- prune the hidden size from 4096 to 3072.\n", - "- and retrain the attention headcount and number of layers\n", + "- Prune (trim) the MLP intermediate dimension from 14336 to 9216.\n", + "- Prune the hidden size from 4096 to 3072.\n", + "- Retrain the attention headcount and number of layers\n", "\n", - "For width-pruning we will use the [megatron_gpt_prune.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_prune.py) script in the NeMo Framework. To see the detailed list of parameters for width-pruning, you can view the [megatron_gpt_prune.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml) file.\n", + "For width-pruning, we will use the [megatron_gpt_prune.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_prune.py) script in the NeMo Framework. To see the detailed list of parameters for width-pruning, you can view the [megatron_gpt_prune.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml) file.\n", "\n", "We use the above parameters to get a competitive model for this demonstration. You can use other strategies or parameters from the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) or the [tech report](https://arxiv.org/pdf/2408.11796) for your experiments. \n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model.\n", + "> `NOTE:` In the block of code below, pass the paths to your fine-tuned teacher .nemo model.\n", "\n", "> `TIP:` You can increase the ``batch_size`` (upto 1024) to speed up the width-pruning script execution." ] diff --git a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb index ccbe1cbf394b..488225837731 100644 --- a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb @@ -6,9 +6,9 @@ "metadata": {}, "source": [ "### Step 4: Distill knowledge from teacher into student\n", - "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). In this notebook, we will explore distillation with the depth-pruned model as the `STUDENT` model. \n", + "Distillation of a model with NeMo Framework is also possible using a Python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). In this notebook, we will explore distillation with the depth-pruned model as the `STUDENT` model.\n", "\n", - "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + "For this demonstration, the `TEACHER` would be the fine-tuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." ] }, { @@ -19,7 +19,7 @@ "#### Step 4.a.: Using depth-pruned student\n", "While distilling knowledge from the teacher to depth-pruned model, the `STUDENT` model would be `4b_depth_pruned_model.nemo` as produced by the [depth-pruning](./03_a_depth_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as path to the teacher and student .nemo models." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb index 48e81c96cdcf..95110dd19dd9 100644 --- a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb @@ -6,10 +6,10 @@ "metadata": {}, "source": [ "### Step 4: Distill knowledge from teacher into student\n", - "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", + "Distillation of a model with NeMo Framework is also possible using a Python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", "In this notebook, we will explore distillation with the width-pruned model as the `STUDENT` model.\n", "\n", - "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + "For this demonstration, the `TEACHER` would be the fine-tuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." ] }, { @@ -20,7 +20,7 @@ "#### Step 4.b.: Using width-pruned student\n", "While distilling knowledge from the teacher to width-pruned model, the `STUDENT` model would be `4b_width_pruned_model.nemo` as produced by the [width-pruning](./03_b_width_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as path to the teacher and student .nemo models." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb index 0264cc288957..dcb483c55ab6 100644 --- a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb @@ -8,7 +8,8 @@ "### Step 5: Display the validation loss\n", "\n", "Now that the results are in, let's visualize the validation loss of the two distilled models using the `tensorboard` library. \n", - "> `NOTE:` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." + "\n", + "> `NOTE:` This notebook demonstrates the use of the teacher fine-tuning, pruning, and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." ] }, { @@ -16,8 +17,8 @@ "id": "b5822d62-8131-4046-8c22-0bf0fce81df7", "metadata": {}, "source": [ - "#### Validation Loss using depth-pruned model as student in distillation script\n", - "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the depth-pruned student." + "#### Validation Loss Using Depth-Pruned Model as Student in Distillation Script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script, where we distill the knowledge from the fine-tuned teacher model to the depth-pruned student." ] }, { @@ -35,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "db6fcf26-8ae8-40e1-875a-0a10bf85be81", "metadata": { "tags": [] @@ -44,7 +45,7 @@ { "data": { "text/html": [ - "
Validation Loss over 30 Training Steps with Depth-Pruned model as Student
" + "
Validation Loss over 30 Training Steps with Depth-Pruned Model as Student
" ], "text/plain": [ "" @@ -68,7 +69,7 @@ ], "source": [ "from IPython.display import Image, display, HTML\n", - "title = \"Validation Loss over 30 Training Steps with Depth-Pruned model as Student\"\n", + "title = \"Validation Loss over 30 Training Steps with Depth-Pruned Model as Student\"\n", "display(HTML(f\"
{title}
\"))\n", "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png\", width=400))" ] @@ -78,8 +79,8 @@ "id": "f10041ae-6533-47de-9f76-f97d4469c27a", "metadata": {}, "source": [ - "#### Validation Loss using width-pruned model as student in distillation script\n", - "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the width-pruned student." + "#### Validation Loss Using Width-Pruned Model as Student in Distillation Script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script, where we distill the knowledge from the fine-tuned teacher model to the width-pruned student." ] }, { @@ -97,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "ecd79583-f662-40c6-a690-9f4bb847de4e", "metadata": { "tags": [] @@ -106,7 +107,7 @@ { "data": { "text/html": [ - "
Validation Loss over 30 Training Steps with Width-Pruned model as Student
" + "
Validation Loss over 30 Training Steps with Width-Pruned Model as Student
" ], "text/plain": [ "" @@ -130,18 +131,10 @@ ], "source": [ "from IPython.display import Image, display, HTML\n", - "title = \"Validation Loss over 30 Training Steps with Width-Pruned model as Student\"\n", + "title = \"Validation Loss over 30 Training Steps with Width-Pruned Model as Student\"\n", "display(HTML(f\"
{title}
\"))\n", "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png\", width=400))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ab6ed6f-8bc3-4188-919f-7cee842635ed", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/tutorials/llm/llama-3/pruning-distillation/README.rst b/tutorials/llm/llama-3/pruning-distillation/README.rst index 34febcffa366..45cb119ffcd8 100644 --- a/tutorials/llm/llama-3/pruning-distillation/README.rst +++ b/tutorials/llm/llama-3/pruning-distillation/README.rst @@ -1,13 +1,13 @@ Llama 3.1 Pruning and Distillation with NeMo Framework ======================================================================================= -`Llama 3.1 `_ are open-source large language models by Meta that deliver state-of-the-art performance on popular industry benchmarks. They have been pretrained on over 15 trillion tokens, and support a 128K token context length. They are available in three sizes, 8B, 70B, and 405B, and each size has two variants—base pretrained and instruction tuned. +`Llama 3.1 `_ models, developed by Meta, are open-source large language models that deliver state-of-the-art performance on popular industry benchmarks. Pretrained on over 15 trillion tokens, they support a 128K token context length. These models are available in three sizes: 8B, 70B, and 405B. Each size offers two variants: base pretrained and instruction tuned. -`NVIDIA NeMo Framework `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 to fit your use case. +`NVIDIA NeMo Framework `_ provides tools to perform teacher fine-tuning, pruning, and distillation on Llama 3.1 to fit your use case. `NVIDIA TensorRT Model Optimizer `_ is a library (referred to as **Model Optimizer**, or **ModelOpt**) comprising state-of-the-art model optimization techniques including `quantization `_, `sparsity `_, `distillation `_, and `pruning `_ to compress models. -`LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 as described in the `tech report `_. +`LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher fine-tuning, pruning, and distillation on Llama 3.1 as described in the `tech report `_. `How to Prune and Distill Llama-3.1 8B to an NVIDIA Llama-3.1-Minitron 4B Model `_ provides practical and effective structured compression best practices for LLMs that combine depth, width, attention, and MLP pruning with knowledge distillation-based retraining. These strategies are presented in the `Compact Language Models via Pruning and Knowledge Distillation `_ paper. @@ -16,30 +16,33 @@ Llama 3.1 Pruning and Distillation with NeMo Framework Objectives ---------- -This tutorial shows how to perform depth-pruning, teacher finetuning and distillation on **Llama 3.1 8B** using the `WikiText-103-v1 `_ dataset with NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. For this demonstration, we will perform teacher correction by running a light finetuning procedure on the ``Meta Llama 3.1 8B`` teacher model to generate a finetuned teacher model ``megatron_llama_ft.nemo`` needed for optimal distillation. This finetuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will be exploring both pruning techniques which will yield ``4b_depth_pruned_model.nemo`` and ``4b_width_pruned_model.nemo`` respectively. These models will serve as a starting point for distillation to create the final distilled 4B models. +This tutorial demonstrates how to perform depth-pruning, width-pruning, teacher fine-tuning, and distillation on **Llama 3.1 8B** using the `WikiText-103-v1 _ dataset with the NeMo Framework. The WikiText-103-v1 `_ language modeling dataset comprises over 100 million tokens extracted from verified Good and Featured articles on Wikipedia. + +For this demonstration, we will perform teacher correction by running a light fine-tuning procedure on the ``Meta LLama 3.1 8B`` teacher model to generate a fine-tuned teacher model, ``megatron_llama_ft.nemo``, needed for optimal distillation. This fine-tuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will explore both techniques, yielding ``4b_depth_pruned_model.nemo`` and ``4b_width_pruned_model.nemo``, respectively. These models will serve as starting points for distillation to create the final distilled 4B models. + We are using models utilizing the ``meta-llama/Meta-Llama-3.1-8B`` tokenizer for this demonstration. -``NOTE:`` A subset of functions is being demonstrated in the notebooks. Some features like Neural Architecture Search (NAS) are unavailable but will be supported in future releases. +``NOTE:`` A subset of functions is being demonstrated in the notebooks. Some features like Neural Architecture Search (NAS) are unavailable, but will be supported in future releases. Requirements ------------- * System Configuration - * Access to at least 8 NVIDIA GPU with an individual memory of at least 80GB, for example: 8 x H100-80GB or 8 x A100-80GB. + * Access to at least 8 NVIDIA GPUs, each with a memory of at least 80GB (e.g., 8 x H100-80GB or 8 x A100-80GB). * A Docker-enabled environment, with `NVIDIA Container Runtime `_ installed, which will make the container GPU-aware. -* `Authenticate with NVIDIA NGC `_, and download `NGC CLI Tool `_. You will use this tool to download the model and customize it with NeMo Framework. +* `Authenticate with NVIDIA NGC `_ and download `NGC CLI Tool `_. You will use this tool to download the model and customize it with NeMo Framework. * Get your Hugging Face `access token `_, which will be used to obtain the tokenizer required during training. -``NOTE:`` The default configuration in the notebook runs on 8 x 80GB NVIDIA GPUs but you can potentially reduce Tensor Parallel size ``(TENSOR_PARALLEL_SIZE)`` along with the Micro-Batchsize ``(MICRO_BATCH_SIZE)`` in the teacher finetuning and distillation scripts to accommodate lower resource availability. +``NOTE:`` The default configuration in the notebook runs on 8 x 80GB NVIDIA GPUs. However, you can potentially reduce the Tensor Parallel size ``(TENSOR_PARALLEL_SIZE)`` along with the Micro-Batchsize ``(MICRO_BATCH_SIZE)`` in the teacher fine-tuning and distillation scripts to accommodate lower resource availability. -Create a pruned and distilled model with NeMo Framework +Create a Pruned and Distilled Model with NeMo Framework ------------------------------------------------------------------------------ -For pruning and distilling the model, you will use the NeMo Framework which is available as a `docker container `_. +For pruning and distilling the model, you will use the NeMo Framework, which is available as a `Docker container `_. -``NOTE:`` These notebooks use `NVIDIA TensorRT Model Optimizer `_ under the hood for pruning and distillation. +``NOTE:`` These notebooks use the `NVIDIA TensorRT Model Optimizer `_ under the hood for pruning and distillation. 1. Download the `Llama 3.1 8B .nemo `_ from NVIDIA NGC using the `NGC CLI `_. Generate the ``NGC_API_KEY`` following these `instructions `_. The following command saves the ``.nemo`` format model in a folder named ``llama-3_1-8b-nemo_v1.0`` in the current directory. You can specify another path using the ``-d`` option in the CLI tool. @@ -75,7 +78,7 @@ For pruning and distilling the model, you will use the NeMo Framework which is a 4. Then, navigate to `this notebook <./introduction.ipynb>`_ to get started. -This directory contains a list of notebooks which will go over all the steps to create a distilled 4B model. +This directory contains a list of notebooks that cover all the steps to create a distilled 4B model. :: @@ -91,7 +94,7 @@ This directory contains a list of notebooks which will go over all the steps to Results ------------------------------------------------------------------------------ -``NOTE:`` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation scripts. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. +``NOTE:`` This notebook demonstrates the use of the teacher fine-tuning, pruning, and the distillation scripts. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. Here are the validation loss plots over 30 steps of running the training step in the distillation script (at the end of the `notebook <./05_display_results.ipynb>`_). @@ -100,11 +103,11 @@ Here are the validation loss plots over 30 steps of running the training step in :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the depth-pruned model as the student :align: center - Figure 1: Validation Loss Plot when using the depth-pruned model as the student + Figure 1: Validation Loss Plot When Using the Depth-Pruned Model as the Student .. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png :width: 400px :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the width-pruned model as the student :align: center - Figure 2: Validation Loss Plot when using the width-pruned model as the student \ No newline at end of file + Figure 2: Validation Loss Plot When Using the Width-Pruned Model as the Student \ No newline at end of file diff --git a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb index 1a3efc9f5f1e..71a5a6cfb03c 100644 --- a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb @@ -7,7 +7,7 @@ "tags": [] }, "source": [ - "# Pruning and Distillation of Llama 3.1 model with NeMo Framework" + "# Efficient Model Reduction with Pruning and Distillation of Llama 3.1 Using NeMo Framework" ] }, { @@ -15,15 +15,15 @@ "id": "03fd1cf4-c67a-4b8d-a5e5-46531be0f991", "metadata": {}, "source": [ - "This demonstration showcases performing pruning and distillation on **Llama 3.1-8B** with the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset using NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset is a collection of over 100 million tokens extracted from the set of verified 'Good' and 'Featured' articles on Wikipedia. \n", + "This tutorial demonstrates how to perform depth-pruning, teacher fine-tuning, and distillation on **Llama 3.1-8B** using the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset with NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset comprises over 100 million tokens extracted from verified Good and Featured articles on Wikipedia.\n", "\n", - "For this demonstration, we will perform a light finetuning procedure on the `Meta Llama 3.1 8B` teacher model to generate a finetuned teacher model. This finetuned teacher model will then be trimmed. There are two methods to prune a model: depth-pruning and width-pruning. This workflow will showcase both methods which will yield `4b_depth_pruned_model.nemo` and `4b_width_pruned_model.nemo` respectively, that will serve as a starting point for distillation to the final 4B models. \n", + "For this demonstration, we will perform teacher correction by running a light fine-tuning procedure on the `Meta Llama 3.1 8B` teacher model to generate a fine-tuned teacher model, `megatron_llama_ft.nemo`, needed for optimal distillation. This fine-tuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will explore both techniques, yielding `4b_depth_pruned_model.nemo` and `4b_width_pruned_model.nemo`, respectively. These models will serve as starting points for distillation to create the final distilled 4B models.\n", "\n", "> We are using models utilizing the `meta-llama/Meta-Llama-3.1-8B` tokenizer for this demonstration.\n", "\n", "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) which has all the required dependencies. \n", "\n", - "**Instructions are available in the associated tutorial README to download the model and the container.**" + "**Instructions for downloading the model and the container are available in the [README](./README.rst).**" ] }, { @@ -49,8 +49,8 @@ "source": [ "---\n", "## Prerequisites\n", - "Ensure you have the following -\n", - "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo FW container." + "Ensure you meet the prerequisites listed in this section.\n", + "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo Framework container." ] }, { @@ -149,12 +149,12 @@ }, "source": [ "---\n", - "## Step-by-step instructions\n", + "## Step-by-Step Instructions\n", "\n", "This workflow is structured into seven notebooks:\n", "1. [Prepare the dataset](./01_data_preparation.ipynb)\n", - "2. [Finetune the teacher on the dataset](./02_teacher_finetuning.ipynb)\n", - "3. Prune the finetuned-teacher model to create a student \n", + "2. [Fine-tune the teacher on the dataset](./02_teacher_finetuning.ipynb)\n", + "3. Prune the fine-tuned teacher model to create a student \n", " - 3.a. [Using depth-pruning](./03_a_depth_pruning.ipynb)\n", " - 3.b. [Using width-pruning](./03_b_width_pruning.ipynb)\n", "4. Distill knowledge from teacher into student\n", @@ -162,7 +162,7 @@ " - 4.b. [Using width-pruned student](./04_b_distilling_width_pruned_student.ipynb)\n", "5. [Display the validation loss](./05_display_results.ipynb)\n", "\n", - "> `NOTE:` We are exploring two methods to prune the finetuned teacher model: [depth-pruning](./03_a_depth_pruning.ipynb) and [width-pruning](./03_b_width_pruning.ipynb). Per the [tech report](https://arxiv.org/pdf/2408.11796), we can observe that width-pruning generally outperforms depth-pruning so users can choose to perform either [depth-pruning](./03_a_depth_pruning.ipynb) or [width-pruning](./03_b_width_pruning.ipynb) or both methods." + "> `NOTE:` We are exploring two methods to prune the fine-tuned teacher model: [depth-pruning](./03_a_depth_pruning.ipynb) and [width-pruning](./03_b_width_pruning.ipynb). Per the [tech report](https://arxiv.org/pdf/2408.11796), we can observe that width-pruning generally outperforms depth-pruning so users can choose to perform either [depth-pruning](./03_a_depth_pruning.ipynb) or [width-pruning](./03_b_width_pruning.ipynb) or both methods." ] } ], diff --git a/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb b/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb index 6204bf2516bb..b028b2d5c190 100644 --- a/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb +++ b/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb @@ -249,7 +249,7 @@ "\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf\n", "import pandas as pd" ] diff --git a/tutorials/nlp/Punctuation_and_Capitalization.ipynb b/tutorials/nlp/Punctuation_and_Capitalization.ipynb index f88c33fada34..cbdab3941b6f 100644 --- a/tutorials/nlp/Punctuation_and_Capitalization.ipynb +++ b/tutorials/nlp/Punctuation_and_Capitalization.ipynb @@ -72,7 +72,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb index 2afbb19c0e66..51d3a66c91fc 100644 --- a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb +++ b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb @@ -74,7 +74,7 @@ "import os\n", "import wget\n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb index d6b1e98b428e..3c9e427e7e09 100644 --- a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb +++ b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb @@ -71,7 +71,7 @@ "import os\n", "import wget\n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb b/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb index fdcff979ea46..0ed846881d02 100644 --- a/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb @@ -58,7 +58,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Token_Classification-BioMegatron.ipynb b/tutorials/nlp/Token_Classification-BioMegatron.ipynb index 85cb769b28c0..a59eae67dde1 100644 --- a/tutorials/nlp/Token_Classification-BioMegatron.ipynb +++ b/tutorials/nlp/Token_Classification-BioMegatron.ipynb @@ -45,7 +45,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb b/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb index 3ab98f6c19fd..4c34c293dcca 100644 --- a/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb +++ b/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb @@ -94,7 +94,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ], "execution_count": null, diff --git a/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb b/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb index 7f1baf536d87..b1eca63b8fd1 100644 --- a/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb +++ b/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb @@ -66,7 +66,7 @@ "from nemo.utils import logging\n", "from omegaconf import OmegaConf\n", "import pandas as pd\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import torch\n", "import wget " ] diff --git a/tutorials/nlp/lora.ipynb b/tutorials/nlp/lora.ipynb index c67fa6c2de15..0429dd7f053c 100644 --- a/tutorials/nlp/lora.ipynb +++ b/tutorials/nlp/lora.ipynb @@ -422,7 +422,7 @@ "source": [ "from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy\n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder\n", "\n", "# let's modify some trainer configs\n", diff --git a/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb b/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb index 7db905b6d225..c193e6600666 100644 --- a/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb +++ b/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb @@ -777,7 +777,7 @@ "metadata": {}, "outputs": [], "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from nemo.collections.asr.models import EncDecDiarLabelModel\n", "from nemo.utils.exp_manager import exp_manager\n", "\n", diff --git a/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb b/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb index 27a01b894eae..c4f7fbaca67e 100644 --- a/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb +++ b/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb @@ -438,7 +438,7 @@ "outputs": [], "source": [ "import torch\n", - "import pytorch_lightning as pl" + "import lightning.pytorch as pl" ] }, { diff --git a/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb b/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb index afd202f99d4a..8b0114690540 100644 --- a/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb +++ b/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb @@ -1636,7 +1636,7 @@ "outputId": "67209ee3-5161-40dc-a179-83d8219c3d71" }, "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import DictConfig\n", "import copy\n", "\n", diff --git a/tutorials/tts/Tacotron2_Training.ipynb b/tutorials/tts/Tacotron2_Training.ipynb index 79546bb79db9..edc814cf12ec 100644 --- a/tutorials/tts/Tacotron2_Training.ipynb +++ b/tutorials/tts/Tacotron2_Training.ipynb @@ -178,7 +178,7 @@ "Let's take a look at the tacotron2.py file\n", "\n", "```python\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "\n", "from nemo.collections.common.callbacks import LogEpochTimeCallback\n", "from nemo.collections.tts.models import Tacotron2Model\n", diff --git a/tutorials/tts/Vits_Training.ipynb b/tutorials/tts/Vits_Training.ipynb index 9d3919e8dc6a..060c6bda43bb 100644 --- a/tutorials/tts/Vits_Training.ipynb +++ b/tutorials/tts/Vits_Training.ipynb @@ -191,7 +191,7 @@ "Let's take a look at the vits.py file\n", "\n", "```python\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "\n", "from nemo.collections.tts.models.vits import VitsModel\n", "from nemo.core.config import hydra_runner\n",