diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 797b7888b01e..8100d95ae2a3 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -159,6 +159,21 @@ jobs: rm -f /home/TestData/nlp/megatron_ir/sbert/sbert.nemo rm -rf /home/TestData/nlp/megatron_ir/sbert/model_weights + L2_Community_LLM_Checkpoints_tests_Mamba2: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \ + --input_name_or_path /home/TestData/nlp/megatron_mamba/model_optim_rng.pt \ + --output_path /home/TestData/nlp/megatron_mamba/converted_mamba.nemo \ + --precision=bf16 \ + --mamba_ssm_ngroups 1 + AFTER_SCRIPT: | + rm -f /home/TestData/nlp/megatron_mamba/converted_mamba.nemo + rm -rf /home/TestData/nlp/megatron_mamba/model_weights + L2_Community_LLM_Checkpoints_tests_Llama: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -484,41 +499,41 @@ jobs: AFTER_SCRIPT: | rm -rf examples/asr/speech_finetuning_results - OPTIONAL_ASR_dev_run_Speech_To_Text_HF_Finetuning: - needs: [cicd-test-container-setup] - uses: ./.github/workflows/_test_template.yml - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - python examples/asr/speech_to_text_finetune.py \ - --config-path="conf/asr_finetune" --config-name="speech_to_text_hf_finetune" \ - ~model.train_ds.hf_data_cfg \ - model.train_ds.num_workers=1 \ - model.train_ds.batch_size=2 model.validation_ds.batch_size=2 \ - model.train_ds.streaming=true \ - +model.train_ds.hf_data_cfg.path="librispeech_asr" \ - +model.train_ds.hf_data_cfg.name=null \ - +model.train_ds.hf_data_cfg.split="test.clean" \ - +model.train_ds.hf_data_cfg.streaming=true \ - ~model.validation_ds.hf_data_cfg \ - model.validation_ds.streaming=true \ - +model.validation_ds.hf_data_cfg.path="librispeech_asr" \ - +model.validation_ds.hf_data_cfg.name=null \ - +model.validation_ds.hf_data_cfg.split="test.clean" \ - +model.validation_ds.hf_data_cfg.streaming=true \ - ~model.test_ds \ - init_from_nemo_model=/home/TestData/asr/stt_en_fastconformer_transducer_large.nemo \ - model.tokenizer.update_tokenizer=False \ - model.optim.sched.warmup_steps=0 \ - +model.optim.sched.max_steps=3 \ - trainer.max_epochs=null \ - trainer.devices=1 \ - trainer.accelerator="gpu" \ - +trainer.fast_dev_run=True \ - exp_manager.exp_dir=examples/asr/speech_finetuning_results - AFTER_SCRIPT: | - rm -rf examples/asr/speech_finetuning_results - IS_OPTIONAL: true + # OPTIONAL_ASR_dev_run_Speech_To_Text_HF_Finetuning: + # needs: [cicd-test-container-setup] + # uses: ./.github/workflows/_test_template.yml + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # python examples/asr/speech_to_text_finetune.py \ + # --config-path="conf/asr_finetune" --config-name="speech_to_text_hf_finetune" \ + # ~model.train_ds.hf_data_cfg \ + # model.train_ds.num_workers=1 \ + # model.train_ds.batch_size=2 model.validation_ds.batch_size=2 \ + # model.train_ds.streaming=true \ + # +model.train_ds.hf_data_cfg.path="librispeech_asr" \ + # +model.train_ds.hf_data_cfg.name=null \ + # +model.train_ds.hf_data_cfg.split="test.clean" \ + # +model.train_ds.hf_data_cfg.streaming=true \ + # ~model.validation_ds.hf_data_cfg \ + # model.validation_ds.streaming=true \ + # +model.validation_ds.hf_data_cfg.path="librispeech_asr" \ + # +model.validation_ds.hf_data_cfg.name=null \ + # +model.validation_ds.hf_data_cfg.split="test.clean" \ + # +model.validation_ds.hf_data_cfg.streaming=true \ + # ~model.test_ds \ + # init_from_nemo_model=/home/TestData/asr/stt_en_fastconformer_transducer_large.nemo \ + # model.tokenizer.update_tokenizer=False \ + # model.optim.sched.warmup_steps=0 \ + # +model.optim.sched.max_steps=3 \ + # trainer.max_epochs=null \ + # trainer.devices=1 \ + # trainer.accelerator="gpu" \ + # +trainer.fast_dev_run=True \ + # exp_manager.exp_dir=examples/asr/speech_finetuning_results + # AFTER_SCRIPT: | + # rm -rf examples/asr/speech_finetuning_results + # IS_OPTIONAL: true ASR_dev_run_Speech_to_Text_WPE_-_Conformer: needs: [cicd-test-container-setup] @@ -2046,7 +2061,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + python examples/nlp/language_modeling/megatron_bert_pretraining.py \ trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ @@ -2076,7 +2091,7 @@ jobs: model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + python examples/nlp/language_modeling/megatron_bert_pretraining.py \ trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ @@ -2113,7 +2128,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + python examples/nlp/language_modeling/megatron_bert_pretraining.py \ trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ @@ -2144,7 +2159,7 @@ jobs: model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + python examples/nlp/language_modeling/megatron_bert_pretraining.py \ trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ @@ -2184,7 +2199,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + python examples/nlp/language_modeling/megatron_bert_pretraining.py \ trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ @@ -2214,7 +2229,7 @@ jobs: model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + python examples/nlp/language_modeling/megatron_bert_pretraining.py \ trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ @@ -4738,6 +4753,22 @@ jobs: rm -rf examples/llm/gpt_pretrain_results rm -rf examples/llm/gpt_index_mappings + L2_NeMo_2_GPT_DDP_Param_Parity_check: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/lightning/test_ddp_parity_checker.py \ + --vocab-path=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + --merges-path=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + --data-path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document + + AFTER_SCRIPT: | + rm -rf examples/llm/gpt_pretrain_results + rm -rf examples/llm/gpt_index_mappings + Nemo_CICD_Test: needs: - gpu-test @@ -4745,6 +4776,7 @@ jobs: - L0_Unit_Tests_GPU #- OPTIONAL_L0_Unit_Tests_CPU - L2_Community_LLM_Checkpoints_tests_Bert + - L2_Community_LLM_Checkpoints_tests_Mamba2 - L2_Community_LLM_Checkpoints_tests_Llama - L2_Community_LLM_Checkpoints_tests_StarCoder - L2_Community_LLM_Checkpoints_tests_Falcon @@ -4843,6 +4875,7 @@ jobs: - Speech_Checkpoints_tests #- OPTIONAL_L2_Stable_Diffusion_Training - L2_NeMo_2_GPT_Pretraining_no_transformer_engine + - L2_NeMo_2_GPT_DDP_Param_Parity_check if: always() runs-on: ubuntu-latest steps: diff --git a/Dockerfile.ci b/Dockerfile.ci index 38b82a288a2b..275aaecb95f0 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -34,7 +34,7 @@ WORKDIR /workspace # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.15.0 -ARG MCORE_TAG=2fd6e2b74efca73a1f2d27b89bb5419384b4d3bf +ARG MCORE_TAG=34e607ef41cf1c0ed481a678df9c76952d0ec00c ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \ diff --git a/docs/source/asr/asr_language_modeling_and_customization.rst b/docs/source/asr/asr_language_modeling_and_customization.rst index d5a748e2379e..02fed8b89760 100644 --- a/docs/source/asr/asr_language_modeling_and_customization.rst +++ b/docs/source/asr/asr_language_modeling_and_customization.rst @@ -547,6 +547,69 @@ The following is the list of the arguments for the opengrm script: | force | bool | ``False`` | Whether to recompile and rewrite all files | +----------------------+--------+------------------+-----------------------------------------------------------------------------------------------------------------+ +.. _wfst-ctc-decoding: + +WFST CTC decoding +================= +Weighted Finite-State Transducers (WFST) are finite-state machines with input and output symbols on each transition and some weight element of a semiring. WFSTs can act as N-gram LMs in a special type of LM-forced beam search, called WFST decoding. + +.. note:: + + More precisely, WFST decoding is more of a greedy N-depth search with LM. + Thus, it is asymptotically worse than conventional beam search decoding algorithms, but faster. + +**WARNING** +At the moment, NeMo supports WFST decoding only for CTC models and word-based LMs. + +To run WFST decoding in NeMo, one needs to provide a NeMo ASR model and either an ARPA LM or a WFST LM (advanced). An ARPA LM can be built from source text with KenLM as follows: ``/lmplz -o --arpa --prune ``. + +The script to evaluate an ASR model with WFST decoding and N-gram models can be found at +`scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py +`__. + +This script has a large number of possible argument overrides, therefore it is advised to use ``python eval_wfst_decoding_ctc.py --help`` to see the full list of arguments. + +You may evaluate an ASR model as the following: + +.. code-block:: + + python eval_wfst_decoding_ctc.py nemo_model_file= \ + input_manifest= \ + arpa_model_file= \ + decoding_wfst_file= \ + beam_width=[] \ + lm_weight=[] \ + open_vocabulary_decoding= \ + decoding_mode= \ + decoding_search_type= \ + preds_output_folder= \ + probs_cache_file=null + +.. note:: + + Since WFST decoding is LM-forced (the search goes over the WIDEST graph), only word sequences accepted by the WFST can appear in the decoding results. + To circumvent this restriction, one can pass ``open_vocabulary_decoding=true`` (experimental feature). + + +Quick start example +------------------- + +.. code-block:: + + wget -O - https://www.openslr.org/resources/11/3-gram.pruned.1e-7.arpa.gz | \ + gunzip -c | tr '[:upper:]' '[:lower:]' > 3-gram.pruned.1e-7.arpa && \ + python eval_wfst_decoding_ctc.py nemo_model_file="stt_en_conformer_ctc_small_ls" \ + input_manifest="/Librispeech/test_other.json" \ + arpa_model_file="3-gram.pruned.1e-7.arpa" \ + decoding_wfst_file="3-gram.pruned.1e-7.fst" \ + beam_width=[8] \ + lm_weight=[0.5,0.6,0.7,0.8,0.9] + +.. note:: + + Building a decoding WFST is a long process, so it is better to provide a ``decoding_wfst_file`` path even if you don't have it. + This way, the decoding WFST will be buffered to the specified file path and there will be no need to re-build it on the next run. + *************************************************** Context-biasing (word boosting) without external LM diff --git a/docs/source/multimodal/text2img/sd.rst b/docs/source/multimodal/text2img/sd.rst index 6f5092f93f5f..549f13bbabf6 100644 --- a/docs/source/multimodal/text2img/sd.rst +++ b/docs/source/multimodal/text2img/sd.rst @@ -163,7 +163,7 @@ Optimization related configurations Training with precached latents ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Since the VAE and text encoder remain frozed during training, you can pre-calculate the image and caption latents offline, enhancing training throughput. To create a pre-cached dataset, see :doc:`Multimodal Dataset <./datasets>`. For training using this dataset, configure ``model.data`` section properly and set ``model.first_stage_key=image_encoded`` along with ``model.cond_stage_key=captions_encoded``. +Since the VAE and text encoder remain frozen during training, you can pre-calculate the image and caption latents offline, enhancing training throughput. To create a pre-cached dataset, see :doc:`Multimodal Dataset <./datasets>`. For training using this dataset, configure ``model.data`` section properly and set ``model.first_stage_key=image_encoded`` along with ``model.cond_stage_key=captions_encoded``. Reference ----------- diff --git a/docs/source/nlp/nemo_megatron/intro.rst b/docs/source/nlp/nemo_megatron/intro.rst index fab448f3d4f2..65aaee2add6a 100644 --- a/docs/source/nlp/nemo_megatron/intro.rst +++ b/docs/source/nlp/nemo_megatron/intro.rst @@ -20,6 +20,7 @@ To learn more about using NeMo to train Large Language Models at scale, please r peft/landing_page positional_embeddings mcore_customization + rampup_batch_size References @@ -28,4 +29,4 @@ References .. bibliography:: ../nlp_all.bib :style: plain :labelprefix: nlp-megatron - :keyprefix: nlp-megatron- \ No newline at end of file + :keyprefix: nlp-megatron- diff --git a/docs/source/nlp/nemo_megatron/rampup_batch_size.rst b/docs/source/nlp/nemo_megatron/rampup_batch_size.rst new file mode 100644 index 000000000000..1e396cbc7630 --- /dev/null +++ b/docs/source/nlp/nemo_megatron/rampup_batch_size.rst @@ -0,0 +1,62 @@ +.. _rampup_batch_size: + +Ramp Up Batch Size +------------------ + +Ramp up batch size is a feature that allows training to start with a smaller global batch size and linearly increase to a target global batch size over a given number of training samples with specified incremental steps. + +Usage +----- + +To enable global batch size rampup during training, set the rampup_batch_size parameter under the model section of training configuration. This parameter should be a list of three values: + +* ``start_batch_size``: The initial batch size. +* ``batch_size_increment``: The amount by which the batch size will increase at each step. +* ``rampup_samples``: The number of training samples over which the batch size will be ramped up. + +``model.global_batch_size=1024 model.rampup_batch_size=[256, 128, 50000000]`` + +In this example, the training will start with a batch size of 256, increment by 128, and reach the target global batch size of 1024 over 50,000,000 training samples. + +Ramp Up Stages and Training Interruption +---------------------------------------- + +Once the next rampup stage is reached (the point in training when the global batch size increases), NeMo will stop the training. It allows to rerun the training job with a larger number of GPUs or nodes for the next stage of ramp up batch size. + +Automatic Node Scheduling +------------------------- + +In the `NeMo-Framework-Launcher `_, when using rampup batch size, a node scheduler is created automatically. This scheduler allows the use smaller number of nodes for smaller batch size stages and scales up according to the ``training.trainer.num_nodes`` parameter. This parameter corresponds to the maximum number of nodes you want to use for the maximum global batch size. + +Example +------- + +Detailed example of ramp up batch size feature usage with GPT3 5B model and `NeMo-Framework-Launcher `_. In this example, the training started with a global batch size of 256, increased by 256 at each ramp up stage, and reached the target global batch size of 2048 over 10,000,000 training samples. + +Node schedule looks as follows: + ++--------------------+--------------------+ +| global_batch_size | num_nodes | ++====================+====================+ +| 256 | 8 | ++--------------------+--------------------+ +| 512 | 8 | ++--------------------+--------------------+ +| 768 | 8 | ++--------------------+--------------------+ +| 1024 | 8 | ++--------------------+--------------------+ +| 1280 | 10 | ++--------------------+--------------------+ +| 1536 | 12 | ++--------------------+--------------------+ +| 1792 | 14 | ++--------------------+--------------------+ +| 2048 | 16 | ++--------------------+--------------------+ + +Plot of ``global_batch_size`` increase during training: + +.. image:: https://github.com/NVIDIA/NeMo/releases/download/v2.0.0rc0/asset-post-rampup-batch-size-example.png + :alt: + :width: 1080px diff --git a/docs/source/performance/performance_summary.md b/docs/source/performance/performance_summary.md index c5bdda7b040d..eca42f2d0695 100644 --- a/docs/source/performance/performance_summary.md +++ b/docs/source/performance/performance_summary.md @@ -11,18 +11,18 @@ | Model | #-GPUs | GBS | MBS | Sequence Length| TP | PP | CP | VP | Tokens / sec / GPU | Model TFLOP / sec / GPU | ***Est. time to train in days (10T tokens, 1K GPUs)*** | | ----- | ------ | --- | --- | ---------------| -- | -- | -- | -- | ------------------ | ----------------------- | ------------------------------------------------------ | -| GPT3-5B | 64 | 2048 | 4 | 2048 | 1 | 1 | 1 | 1 | 22521 | 736 | ***5*** | -| GPT3-20B | 64 | 256 | 2 | 2048 | 2 | 1 | 1 | 1 | 5851 | 750 | ***19*** | -| GPT3-175B | 128 | 256 | 1 | 2048 | 4 | 8 | 1 | 6 | 726 | 782 | **156** | -| GPT3-175B | 512 | 2048 | 2 | 2048 | 4 | 8 | 1 | 6 | 782 | [842](https://mlcommons.org/benchmarks/training/) | **145** | -| LLAMA2-7B | 8 | 128 | 1 | 4096 | 1 | 1 | 1 | 1 | 16847 | 776 | ***7*** | -| LLAMA2-13B | 16 | 128 | 1 | 4096 | 1 | 4 | 1 | 10 | 8646 | 754 | ***13*** | -| LLAMA2-70B | 64 | 128 | 1 | 4096 | 4 | 4 | 1 | 20 | 1707 | 759 | ***66*** | -| Nemotron-8B | 64 | 256 | 4 | 4096 | 2 | 1 | 1 | 1 | 12701 | 653 | ***9*** | -| Nemotron-22B | 64 | 256 | 2 | 4096 | 2 | 4 | 1 | 10 | 4256 | 554 | ***27*** | -| Nemotron-340B | 128 | 32 | 1 | 4096 | 8 | 8 | 1 | 12 | 322 | 678 | ***351*** | -| LLAMA3-8B | 8 | 128 | 1 | 8192 | 1 | 1 | 2 | 1 | 12036 | 697 | ***9*** | -| LLAMA3-70B | 64 | 128 | 1 | 8192 | 4 | 4 | 2 | 5 | 1533 | 738 | ***74*** | +| GPT3-5B | 64 | 2048 | 4 | 2048 | 1 | 1 | 1 | 1 | 23574 | 770 | ***5*** | +| GPT3-20B | 64 | 256 | 2 | 2048 | 2 | 1 | 1 | 1 | 5894 | 755 | ***19*** | +| GPT3-175B | 128 | 256 | 1 | 2048 | 4 | 8 | 1 | 6 | 745 | 802 | **152** | +| GPT3-175B | 512 | 2048 | 2 | 2048 | 4 | 8 | 1 | 6 | 832 | [895](https://mlcommons.org/benchmarks/training/) | **136** | +| LLAMA2-7B | 8 | 128 | 1 | 4096 | 1 | 1 | 1 | 1 | 16634 | 767 | ***7*** | +| LLAMA2-13B | 16 | 128 | 1 | 4096 | 1 | 4 | 1 | 10 | 8715 | 760 | ***13*** | +| LLAMA2-70B | 64 | 128 | 1 | 4096 | 4 | 4 | 1 | 20 | 1717 | 763 | ***66*** | +| Nemotron-8B | 64 | 256 | 4 | 4096 | 2 | 1 | 1 | 1 | 12507 | 643 | ***9*** | +| Nemotron-22B | 64 | 256 | 2 | 4096 | 2 | 4 | 1 | 10 | 4289 | 559 | ***26*** | +| Nemotron-340B | 128 | 32 | 1 | 4096 | 8 | 8 | 1 | 12 | 328 | 691 | ***344*** | +| LLAMA3-8B | 8 | 128 | 1 | 8192 | 1 | 1 | 2 | 1 | 11883 | 688 | ***10*** | +| LLAMA3-70B | 64 | 128 | 1 | 8192 | 4 | 4 | 2 | 5 | 1549 | 746 | ***73*** | ### Finetuning @@ -34,9 +34,9 @@ | Model | Task | #-GPUs | GBS | MBS | Packed Sequence Length | TP | PP | Tokens / sec / GPU | Model TFLOP / sec / GPU | ***Est. time to finetune in mins (10M tokens)*** | | ----- | ---- | --- | --- | --- | --------------- | -- | -- | ------------------ | ----------------------- | -------------------------------------------------- | -| LLAMA2-7B | SFT | 8 | 32 | 1 | 4096 | 1 | 1 | 17120 | 682 | ***1.2*** | -| LLAMA2-13B | SFT | 8 | 32 | 1 | 4096 | 1 | 4 | 9741 | 754 | ***2.1*** | -| LLAMA2-70B | SFT | 16 | 32 | 1 | 4096 | 4 | 4 | 1833 | 756 | ***5.7*** | +| LLAMA2-7B | SFT | 8 | 32 | 1 | 4096 | 1 | 1 | 17617 | 702 | ***1.2*** | +| LLAMA2-13B | SFT | 8 | 32 | 1 | 4096 | 1 | 4 | 10176 | 787 | ***2.0*** | +| LLAMA2-70B | SFT | 16 | 32 | 1 | 4096 | 4 | 4 | 1812 | 747 | ***5.7*** | | LLAMA2-7B | LoRA | 8 | 32 | 1 | 4096 | 1 | 1 | 25206 | 673 | ***0.8*** | -| LLAMA2-13B | LoRA | 8 | 32 | 1 | 4096 | 1 | 1 | 14161 | 733 | ***1.5*** | -| LLAMA2-70B | LoRA | 8 | 32 | 1 | 4096 | 2 | 4 | 2557 | 705 | ***8.1*** | +| LLAMA2-13B | LoRA | 8 | 32 | 1 | 4096 | 1 | 1 | 14760 | 764 | ***1.4*** | +| LLAMA2-70B | LoRA | 8 | 32 | 1 | 4096 | 2 | 4 | 2621 | 722 | ***7.9*** | diff --git a/examples/asr/asr_adapters/train_asr_adapter.py b/examples/asr/asr_adapters/train_asr_adapter.py index 5a94e2bb332d..3f82ef8fe554 100644 --- a/examples/asr/asr_adapters/train_asr_adapter.py +++ b/examples/asr/asr_adapters/train_asr_adapter.py @@ -92,6 +92,7 @@ from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import clean_exp_ckpt, exp_manager +from nemo.utils.trainer_utils import resolve_trainer_cfg def update_model_config_to_support_adapter(model_cfg, current_cfg): @@ -154,7 +155,7 @@ def main(cfg): if cfg.model.pretrained_model is not None and cfg.model.nemo_model is not None: raise ValueError("Cannot set both `cfg.model.nemo_model` and `cfg.model.pretrained_model`. Select one only.") - trainer = pl.Trainer(**cfg.trainer) + trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer)) exp_log_dir = exp_manager(trainer, cfg.get("exp_manager", None)) if cfg.model.pretrained_model is not None: diff --git a/examples/asr/conf/asr_adapters/asr_adaptation.yaml b/examples/asr/conf/asr_adapters/asr_adaptation.yaml index 6ab3f12d6a1a..b9a2a003217e 100644 --- a/examples/asr/conf/asr_adapters/asr_adaptation.yaml +++ b/examples/asr/conf/asr_adapters/asr_adaptation.yaml @@ -181,7 +181,9 @@ trainer: max_steps: 1000 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: null precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. diff --git a/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml index 4afbc3b51c29..958e6d23375c 100644 --- a/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml +++ b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml @@ -181,7 +181,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: null precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. diff --git a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml index 6808f4941916..3b5717efddf9 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml @@ -80,7 +80,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml index 172d09ccd60b..e6d9b0b49c65 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml @@ -138,7 +138,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml index acb499f18ffb..4c80d2f2e9d4 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml @@ -171,7 +171,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml index 8dd978bb00e4..0796a60260a1 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml @@ -176,7 +176,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml index 9f199c2dd488..4edcc38396fa 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml @@ -227,7 +227,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml index c7f83216aa0b..97b64ef93402 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml @@ -233,7 +233,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml index 9b51edf614b8..d8808b83069c 100644 --- a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml @@ -195,7 +195,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml index 680d96e1afaf..90a77dee2913 100644 --- a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml +++ b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml @@ -248,7 +248,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml index 6f356ce91caa..daef1ed67a9f 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml @@ -244,7 +244,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml index 870bb0190c03..96aee4af1803 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml @@ -249,7 +249,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml index 3fc91cc1e436..4ba55e368bb9 100644 --- a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml @@ -223,7 +223,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml index e99ba69df57a..ed2ad8ca9c0d 100644 --- a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml +++ b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml @@ -228,7 +228,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml index 3e3d2bf6788e..773a500ef2db 100644 --- a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml @@ -168,7 +168,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml index 5f6c37288ae9..fec2a2839efa 100644 --- a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml +++ b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml @@ -222,7 +222,9 @@ trainer: max_steps: -1 # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto - strategy: ddp + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 precision: 32 # 16, 32, or bf16 diff --git a/examples/audio/audio_to_audio_train.py b/examples/audio/audio_to_audio_train.py index b197d2084144..cef46dcf20b6 100644 --- a/examples/audio/audio_to_audio_train.py +++ b/examples/audio/audio_to_audio_train.py @@ -34,6 +34,7 @@ from nemo.collections.audio.models.enhancement import ( EncMaskDecAudioToAudioModel, + FlowMatchingAudioToAudioModel, PredictiveAudioToAudioModel, SchroedingerBridgeAudioToAudioModel, ScoreBasedGenerativeAudioToAudioModel, @@ -50,6 +51,7 @@ class ModelType(str, Enum): Predictive = 'predictive' ScoreBased = 'score_based' SchroedingerBridge = 'schroedinger_bridge' + FlowMatching = 'flow_matching' def get_model_class(model_type: ModelType): @@ -62,6 +64,8 @@ def get_model_class(model_type: ModelType): return ScoreBasedGenerativeAudioToAudioModel elif model_type == ModelType.SchroedingerBridge: return SchroedingerBridgeAudioToAudioModel + elif model_type == ModelType.FlowMatching: + return FlowMatchingAudioToAudioModel else: raise ValueError(f'Unknown model type: {model_type}') diff --git a/examples/audio/conf/flow_matching_generative.yaml b/examples/audio/conf/flow_matching_generative.yaml new file mode 100644 index 000000000000..5f644f328e6d --- /dev/null +++ b/examples/audio/conf/flow_matching_generative.yaml @@ -0,0 +1,164 @@ +name: flow_matching_generative + +model: + type: flow_matching + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + p_cond: 0.9 # Proability of feeding the conditional input into the model. + normalize_input: true # normalize the input signal to 0dBFS + max_utts_evaluation_metrics: 500 + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 768 + random_offset: true + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + log_config: + log_tensorboard: true + log_wandb: false + max_utts: 8 + + encoder: + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + depth: 24 + ff_dropout: 0.1 + time_hidden_dim: 1024 + + flow: + _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow + sigma_start: 1.0 + sigma_end: 1e-4 + + sampler: + _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler + num_steps: 20 + time_min: 1e-8 + time_max: 1.0 + + loss: + _target_: nemo.collections.audio.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + estoi: # output ESTOI + _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility + fs: ${model.sample_rate} + extended: true + pesq: # output PESQ + _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality + fs: ${model.sample_rate} + mode: wb + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 5000 + warmup_ratio: null + min_lr: 0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.2 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? + + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_pesq + mode: max + save_top_k: 3 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: test + project: gense diff --git a/examples/audio/conf/flow_matching_generative_finetuning.yaml b/examples/audio/conf/flow_matching_generative_finetuning.yaml new file mode 100644 index 000000000000..c7ba19aee466 --- /dev/null +++ b/examples/audio/conf/flow_matching_generative_finetuning.yaml @@ -0,0 +1,167 @@ +name: flow_matching_generative_finetuning + +init_from_nemo_model: null +init_strict: false + +model: + type: flow_matching + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + p_cond: 0.9 # Proability of feeding the conditional input into the model. + normalize_input: true # normalize the input signal to 0dBFS + max_utts_evaluation_metrics: 500 + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 768 + random_offset: true + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + log_config: + log_tensorboard: true + log_wandb: false + max_utts: 8 + + encoder: + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + depth: 24 + ff_dropout: 0.1 + time_hidden_dim: 1024 + + flow: + _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow + sigma_start: 1.0 + sigma_end: 1e-4 + + sampler: + _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler + num_steps: 20 + time_min: 1e-8 + time_max: 1.0 + + loss: + _target_: nemo.collections.audio.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + estoi: # output ESTOI + _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility + fs: ${model.sample_rate} + extended: true + pesq: # output PESQ + _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality + fs: ${model.sample_rate} + mode: wb + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 5000 + warmup_ratio: null + min_lr: 0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.2 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? + + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_pesq + mode: max + save_top_k: 3 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: test + project: gense diff --git a/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml b/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml new file mode 100644 index 000000000000..7813a9473644 --- /dev/null +++ b/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml @@ -0,0 +1,171 @@ +name: flow_matching_generative_ssl_pretraining + +model: + type: flow_matching + sample_rate: 16000 + skip_nan_grad: true + num_outputs: 1 + p_cond: 0.9 # Proability of feeding the conditional input into the model. + normalize_input: true # normalize the input signal to 0dBFS + max_utts_evaluation_metrics: 125 + + train_ds: + shar_path: ??? + use_lhotse: true + truncate_duration: 4.09 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 512 + truncate_offset_type: random + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + input_key: clean_filepath + target_key: clean_filepath + random_offset: false + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + log_config: + log_tensorboard: true + log_wandb: false + max_utts: 8 + + encoder: + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + depth: 24 + ff_dropout: 0.1 + time_hidden_dim: 1024 + + flow: + _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow + sigma_start: 1.0 + sigma_end: 1e-4 + + sampler: + _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler + num_steps: 20 + time_min: 1e-8 + time_max: 1.0 + + ssl_pretrain_masking: + _target_: nemo.collections.audio.modules.ssl_pretrain_masking.SSLPretrainWithMaskedPatch + patch_size: 10 + mask_fraction: 0.7 + + loss: + _target_: nemo.collections.audio.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + estoi: # output ESTOI + _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility + fs: ${model.sample_rate} + extended: true + pesq: # output PESQ + _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality + fs: ${model.sample_rate} + mode: wb + + optim: + name: adam + lr: 5e-5 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 5000 + warmup_ratio: null + min_lr: 1e-5 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 10000 # needs to be set for shar datasets + limit_train_batches: 1000 # number of batches to train on in each pseudo-epoch + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + use_distributed_sampler: false # required for lhotse + accumulate_grad_batches: 1 + gradient_clip_val: 0.2 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? + + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_pesq + mode: max + save_top_k: 3 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/llm/run/llama3_pretraining.py b/examples/llm/run/llama3_pretraining.py new file mode 100644 index 000000000000..612b58e2169f --- /dev/null +++ b/examples/llm/run/llama3_pretraining.py @@ -0,0 +1,190 @@ +# This script is used for pretraining a Llama3 model, specifically for the 8b or 70b model variants, on local and slurm executors. +# It uses NeMo 2.0 recipes (https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/recipes/llama3_8b.py#L74) and NeMo-Run (https://github.com/NVIDIA/NeMo-Run) to configure and execute the runs. + +import argparse +from functools import partial +from typing import Any, Optional + +import nemo_run as run + +from nemo.collections import llm + + +def get_parser(): + parser = argparse.ArgumentParser(description="Llama3 Pretraining") + parser.add_argument( + "--size", + type=str, + default="8b", + help="Choose llama3 model size 70b/8b", + ) + parser.add_argument( + "--tag", + type=str, + help="Optional tag for your experiment title which will be appended after the model/exp name.", + required=False, + default="", + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dryrun and exit", + default=False, + ) + parser.add_argument( + "--slurm", + action="store_true", + help="Run on slurm using run.SlurmExecutor", + default=False, + ) + return parser + + +def slurm_executor( + user: str, + host: str, + remote_job_dir: str, + account: str, + partition: str, + nodes: int, + devices: int, + time: str = "01:00:00", + custom_mounts: Optional[list[str]] = None, + custom_env_vars: Optional[dict[str, str]] = None, + container_image: str = "nvcr.io/nvidia/nemo:dev", + retries: int = 0, +) -> run.SlurmExecutor: + if not (user and host and remote_job_dir and account and partition and nodes and devices): + raise RuntimeError( + "Please set user, host, remote_job_dir, account, partition, nodes and devices args for using this function." + ) + + mounts = [] + if custom_mounts: + mounts.extend(custom_mounts) + + env_vars = { + "TRANSFORMERS_OFFLINE": "1", + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + "NCCL_NVLS_ENABLE": "0", + "NVTE_DP_AMAX_REDUCE_INTERVAL": "0", + "NVTE_ASYNC_AMAX_REDUCTION": "1", + "NVTE_FUSED_ATTN": "0", + } + if custom_env_vars: + env_vars |= custom_env_vars + + executor = run.SlurmExecutor( + account=account, + partition=partition, + tunnel=run.SSHTunnel( + user=user, + host=host, + job_dir=remote_job_dir, + ), + nodes=nodes, + ntasks_per_node=devices, + gpus_per_node=devices, + mem="0", + exclusive=True, + gres="gpu:8", + packager=run.GitArchivePackager(subpath="examples/llm/run"), + ) + + executor.container_image = container_image + executor.container_mounts = mounts + executor.env_vars = env_vars + executor.retries = retries + executor.time = time + + return executor + + +def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor: + env_vars = { + "TRANSFORMERS_OFFLINE": "1", + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + "NCCL_NVLS_ENABLE": "0", + "NVTE_DP_AMAX_REDUCE_INTERVAL": "0", + "NVTE_ASYNC_AMAX_REDUCTION": "1", + "NVTE_FUSED_ATTN": "0", + } + + executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars) + + return executor + + +def main(): + args = get_parser().parse_args() + if args.tag and not args.tag.startswith("-"): + args.tag = "-" + args.tag + + MODEL_SIZE_MAPPING: dict[str, dict[str, Any]] = { + "8b": { + "exp_name": "llama3-8b", + "nemo": { + "pretrain": partial(llm.llama3_8b.pretrain_recipe, num_nodes=1, num_gpus_per_node=8), + }, + }, + "70b": { + "exp_name": "llama3-70b", + "nemo": { + "pretrain": partial(llm.llama3_70b.pretrain_recipe, num_nodes=128, num_gpus_per_node=8), + }, + }, + } + + exp_name = MODEL_SIZE_MAPPING[args.size]["exp_name"] + + # Uses configs from NeMo directly + pretrain = MODEL_SIZE_MAPPING[args.size]["nemo"]["pretrain"]( + name=exp_name, + ckpt_dir=f"/{exp_name}/checkpoints", + ) + + # Overwrite the dataloader in the recipe to use your custom dataloader. + # dataloader = set_your_custom_dataloader + # pretrain.data = dataloader + + pretrain.trainer.val_check_interval = 400 + pretrain.log.ckpt.save_top_k = -1 + pretrain.log.ckpt.every_n_train_steps = 400 + + pretrain.trainer.max_steps = 1000 + + executor: run.Executor + + if args.slurm: + # TODO: Set your custom parameters for the Slurm Executor. + executor = slurm_executor( + user="", + host="", + remote_job_dir="", + account="", + partition="", + nodes=pretrain.trainer.num_nodes, + devices=pretrain.trainer.devices, + ) + else: + executor = local_executor_torchrun(nodes=pretrain.trainer.num_nodes, devices=pretrain.trainer.devices) + + with run.Experiment(f"{exp_name}{args.tag}") as exp: + pretrain.log.dir = f"/{exp_name}/checkpoints" + + for i in range(1): + exp.add( + pretrain, + executor=executor, + name=exp_name, + tail_logs=True if isinstance(executor, run.LocalExecutor) else False, + ) + + if args.dryrun: + exp.dryrun() + else: + exp.run(sequential=True, detach=True) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/speech_llm/export/README.md b/examples/multimodal/speech_llm/export/README.md new file mode 100644 index 000000000000..05e44d112cce --- /dev/null +++ b/examples/multimodal/speech_llm/export/README.md @@ -0,0 +1,83 @@ +## Setup +In this part, we are going to export SALM model into TRTLLM. +First, let's download the [SALM nemo model](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/speechllm_fc_llama2_7b/) from NVIDIA ngc. + +```bash +wget --content-disposition 'https://api.ngc.nvidia.com/v2/models/org/nvidia/team/nemo/speechllm_fc_llama2_7b/1.23.1/files?redirect=true&path=speechllm_fc_llama2_7b.nemo' -O speechllm_fc_llama2_7b.nemo +``` + +Then, we need to extract the different parts of SALM. +```bash +output=$PWD/output +python3 extract_salm_weights.py --model_file_path=speechllm_fc_llama2_7b.nemo --output_dir=$output +``` +It takes a while to run the above command. + +Under the `output` dir, you'll see: +``` +output + |___speechllm_fc_llama2_7b_lora.nemo + |___speechllm_fc_llama2_7b_perception + | |____model_config.yaml + | |____model_weights.ckpt + |___speechllm_fc_llama2_7b_llm.nemo + |___ xxx.tokenizer.model +``` + +After we get the lora nemo model and llm nemo model, we can merge the lora part into the llm by: +```bash +python /opt/NeMo/scripts/nlp_language_modeling/merge_lora_weights/merge.py \ + trainer.accelerator=gpu \ + tensor_model_parallel_size=1 \ + pipeline_model_parallel_size=1 \ + gpt_model_file=output/speechllm_fc_llama2_7b_llm.nemo \ + lora_model_path=output/speechllm_fc_llama2_7b_lora.nemo \ + merged_model_path=speechllm_fc_llama2_7b_llm_merged.nemo +``` + +Now we are able to export the engine by: +```bash +python3 export_salm.py \ + model.perception_model_path=output/speechllm_fc_llama2_7b_perception \ + model.llm_model_path=output/speechllm_fc_llama2_7b_llm_merged.nemo +``` + +You should be able to get the generated engines under `./salm` folder. To run the engines, you may run: +```python +from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter + +output_dir = "/ws/salm" # the engine directory +trt_llm_exporter = TensorRTMMExporter(model_dir=output_dir, load_model=True, modality='audio') +input_text = "Q: what's the transcription of the audio? A:" +input_media = '/ws/data/test_audio.wav' +print(trt_llm_exporter.forward(input_text, input_media)) + +``` + +## Deploy +If you want to generate the engines and deploy them with Triton Inference Server, you may also run: + +```bash +python3 NeMo/scripts/deploy/multimodal/deploy_triton.py \ + --modality="audio" \ + --visual_checkpoint=NeMo/examples/multimodal/speech_llm/export/output/speechllm_fc_llama2_7b_perception \ + --llm_checkpoint=NeMo/examples/multimodal/speech_llm/export/output/speechllm_fc_llama2_7b_llm_merged.nemo \ + --llm_model_type="llama" \ + --model_type="salm" \ + --triton_model_name="salm" \ + --max_input_len=4096 \ + --max_output_len=256 \ + --max_multimodal_len=3072 \ + --triton_model_repository=/tmp/trt_model_dir/ +``` + +And on client side, you may run: +```bash +python3 NeMo/scripts/deploy/multimodal/query.py \ + --model_name="salm" \ + --model_type="salm" \ + --input_text="Q: what's the transcription of the audio? A:" \ + --input_media=/ws/data/test_audio.wav +``` + +For more details, please check `NeMo/scripts/deploy/multimodal/deploy_triton.py` and ` NeMo/scripts/deploy/multimodal/query.py`. \ No newline at end of file diff --git a/examples/multimodal/speech_llm/export/conf/salm_export.yaml b/examples/multimodal/speech_llm/export/conf/salm_export.yaml new file mode 100644 index 000000000000..54ab6e9180c5 --- /dev/null +++ b/examples/multimodal/speech_llm/export/conf/salm_export.yaml @@ -0,0 +1,16 @@ +name: speechllm_salm +infer: + output_dir: ./salm + max_batch_size: 1 + tensor_parallelism: 1 + max_input_len: 4096 + max_output_len: 256 + max_multimodal_len: 3072 + perception_max_batch_size: 1 + +model: + type: salm + precision: float16 + perception_model_path: /path/to/speechllm_llama2_7b_perception + llm_model_path: /path/to/speechllm_llama2_7b_llm.nemo + llm_model_type: llama diff --git a/examples/multimodal/speech_llm/export/export_salm.py b/examples/multimodal/speech_llm/export/export_salm.py new file mode 100644 index 000000000000..00500bf46f50 --- /dev/null +++ b/examples/multimodal/speech_llm/export/export_salm.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.core.config import hydra_runner +from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter + + +@hydra_runner(config_path='conf', config_name='salm_export') +def main(cfg): + exporter = TensorRTMMExporter(model_dir=cfg.infer.output_dir, load_model=False, modality='audio') + exporter.export( + visual_checkpoint_path=cfg.model.perception_model_path, + llm_checkpoint_path=cfg.model.llm_model_path, + model_type=cfg.model.type, + llm_model_type=cfg.model.llm_model_type, + tensor_parallel_size=cfg.infer.tensor_parallelism, + max_input_len=cfg.infer.max_input_len, + max_output_len=cfg.infer.max_output_len, + vision_max_batch_size=cfg.infer.perception_max_batch_size, + max_batch_size=cfg.infer.max_batch_size, + max_multimodal_len=cfg.infer.max_multimodal_len, + dtype=cfg.model.precision, + load_model=False, + ) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/speech_llm/export/extract_salm_weights.py b/examples/multimodal/speech_llm/export/extract_salm_weights.py new file mode 100644 index 000000000000..0698a411110e --- /dev/null +++ b/examples/multimodal/speech_llm/export/extract_salm_weights.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os +import tempfile + +import torch +from megatron.core import dist_checkpointing +from omegaconf import OmegaConf +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + + +def get_config_and_state_dict_from_nemo(filepath, map_location, output_dir, sharded_state_dict=None): + cwd = os.getcwd() + save_restore_connector = NLPSaveRestoreConnector() + + with tempfile.TemporaryDirectory() as tmpdir: + try: + if os.path.isfile(filepath): + save_restore_connector._unpack_nemo_file(path2file=filepath, out_folder=tmpdir) + else: + tmpdir = filepath + + os.chdir(tmpdir) + config_yaml = "model_config.yaml" + model_weights_ckpt = "model_weights.ckpt" + + # find file in tmpdir that endswith "tokenizer.model" + tokenizer = None + for file in os.listdir(tmpdir): + if file.endswith("tokenizer.model"): + tokenizer = file + break + if tokenizer is None: + raise ValueError(f"Tokenizer not found in {tmpdir}") + tokenizer_path = os.path.join(tmpdir, tokenizer) + # copy tokenizer_path to current directory + os.system(f"cp {tokenizer_path} {output_dir}") + tokenizer_path = os.path.join(output_dir, tokenizer) + + # load conf + with open(config_yaml) as f: + conf = OmegaConf.load(f) + + os.chdir(cwd) + model_weights = os.path.join(tmpdir, model_weights_ckpt) + model_weights = inject_model_parallel_rank(model_weights) + state_dict = save_restore_connector._load_state_dict_from_disk(model_weights, map_location=map_location) + + # distributed checkpointing + if state_dict is None and sharded_state_dict is not None: + checkpoint = dict(state_dict=sharded_state_dict) + tmp_model_weights_ckpt = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) + tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0] + assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.' + checkpoint = dist_checkpointing.load( + sharded_state_dict=checkpoint, + checkpoint_dir=tmp_model_weights_dir, + ) + state_dict = checkpoint["state_dict"] + + conf.tokenizer.model = tokenizer_path + return conf, state_dict + finally: + os.chdir(cwd) + + +def get_llm_model_state_dict(state_dict, lora_model_state_dict): + llm_model_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("model."): + if key not in lora_model_state_dict and value != None: + llm_model_state_dict[key] = value + return llm_model_state_dict + + +def get_lora_state_dict(state_dict): + lora_model_state_dict = {} + for key, value in state_dict.items(): + if "adapter_layer.lora" in key and value != None: + lora_model_state_dict[key] = value + return lora_model_state_dict + + +def get_perception_state_dict(state_dict): + perception_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("perception."): + key = key.replace("perception.", "", 1) + perception_state_dict[key] = value + return perception_state_dict + + +def save_llm_model(state_dict, nemo_config, output_path): + if nemo_config.get('megatron_amp_O2', False): + keys = list(state_dict.keys()) + for key in keys: + state_dict[key.replace('model.', 'model.module.', 1)] = state_dict['state_dict'].pop(key) + + trainer = Trainer(accelerator='cpu', strategy=NLPDDPStrategy()) + model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, state_dict) + model._save_restore_connector = NLPSaveRestoreConnector() + model.cfg.use_cpu_initialization = False + + model.save_to(output_path) + logging.info(f'llm model saved to: {output_path}') + + +def save_nemo_weights(state_dict, output_dir, config, save_nemo_model=True): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + weight_file = os.path.join(output_dir, "model_weights.ckpt") + torch.save(state_dict, weight_file) + # convert config to yaml + config_file = os.path.join(output_dir, "model_config.yaml") + with open(config_file, "w") as f: + f.write(OmegaConf.to_yaml(config)) + + if save_nemo_model: + # create nemo file + nemo_model_name = f"{output_dir}.nemo" + nemo_path = os.path.join(output_dir, nemo_model_name) + # tar model_config.yaml and model_weights.ckpt + os.system(f"tar -C {output_dir} -cvf {nemo_path} model_config.yaml model_weights.ckpt") + # remove model_config.yaml and model_weights.ckpt + os.system(f"rm {config_file} {weight_file}") + # remove the empty directory + os.system(f"rmdir {output_dir}") + + +def separate_speechllm_model(model_file_path, output_dir, map_location="cuda:0"): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + output_dir = os.path.abspath(output_dir) + + logging.info(f"Separating {model_file_path} into perception, lora, and llm model") + filepath = model_file_path + conf, state_dict = get_config_and_state_dict_from_nemo(filepath, map_location, output_dir) + + base_model_name = os.path.basename(filepath).split(".")[0] + + perception_state_dict = get_perception_state_dict(state_dict) + perception_model_dir = None + if perception_state_dict: + perception_model_dir = f"{base_model_name}_perception" + perception_model_dir = os.path.join(output_dir, perception_model_dir) + save_nemo_weights(perception_state_dict, perception_model_dir, conf.perception, save_nemo_model=False) + + # verify if the exported perception model is correct + perception = AudioPerceptionModule(cfg=conf.perception) + perception.load_state_dict(perception_state_dict) + perception.eval() + print(perception) + print(perception(input_signal=torch.randn(1, 1000), input_signal_length=torch.tensor([1000]))) + # absolute path of perception model + logging.info(f"Perception model saved to: {perception_model_dir}") + + lora_model_weights = get_lora_state_dict(state_dict) + lora_model_dir = None + if lora_model_weights: + lora_model_dir = f"{base_model_name}_lora" + lora_model_dir = os.path.join(output_dir, lora_model_dir) + save_nemo_weights(lora_model_weights, lora_model_dir, conf) + logging.info(f"Lora model saved to: {lora_model_dir}.nemo") + # hard code the target model for now + llm_model_weights = get_llm_model_state_dict(state_dict, lora_model_weights) + if llm_model_weights: + llm_model = f"{base_model_name}_llm.nemo" + llm_model = os.path.join(output_dir, llm_model) + conf.target = "nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel" + save_llm_model(llm_model_weights, conf, llm_model) + logging.info(f"LLM model saved to: {llm_model}") + + +# filepath = "/ws/speechllm_fc_llama2_7b.nemo" +# output_dir = "/ws/speechllm_fc_llama2_7b_separated" +# perception_model_dir, lora_model, llm_model = separate_speechllm_model(filepath, output_dir) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Separate speechllm model') + parser.add_argument('--model_file_path', type=str, help='Path to the speechllm model') + parser.add_argument('--output_dir', type=str, help='Output directory to save the separated models') + args = parser.parse_args() + separate_speechllm_model(args.model_file_path, args.output_dir) diff --git a/examples/nlp/duplex_text_normalization/data/en/data_preprocessing.py b/examples/nlp/duplex_text_normalization/data/en/data_preprocessing.py index 9523d0974db8..f902e771cde4 100644 --- a/examples/nlp/duplex_text_normalization/data/en/data_preprocessing.py +++ b/examples/nlp/duplex_text_normalization/data/en/data_preprocessing.py @@ -46,8 +46,8 @@ import os from argparse import ArgumentParser +from functools import cache -import inflect import regex as re from tqdm import tqdm @@ -60,12 +60,21 @@ ) from nemo.utils import logging -engine = inflect.engine() + +@cache +def inflect_engine(): + import inflect + + return inflect.engine() + # these are all words that can appear in a verbalized number, this list will be used later as a filter to detect numbers in verbalizations number_verbalizations = list(range(0, 20)) + list(range(20, 100, 10)) number_verbalizations = ( - [engine.number_to_words(x, zero="zero").replace("-", " ").replace(",", "") for x in number_verbalizations] + [ + inflect_engine().number_to_words(x, zero="zero").replace("-", " ").replace(",", "") + for x in number_verbalizations + ] + ["hundred", "thousand", "million", "billion", "trillion"] + ["point"] ) @@ -85,7 +94,7 @@ def process_url(o): """ def flatten(l): - """ flatten a list of lists """ + """flatten a list of lists""" return [item for sublist in l for item in sublist] if o != '' and '_letter' in o: @@ -129,6 +138,7 @@ def convert2digits(digits: str): Return: res: number verbalization of the integer prefix of the input """ + engine = inflect_engine() res = [] for i, x in enumerate(digits): if x in digit: @@ -145,6 +155,7 @@ def convert2digits(digits: str): def convert(example): + engine = inflect_engine() cls, written, spoken = example written = convert_fraction(written) @@ -288,7 +299,7 @@ def convert(example): def ignore(example): """ This function makes sure specific class types like 'PLAIN', 'ELECTRONIC' etc. are left unchanged. - + Args: example: data example """ @@ -300,7 +311,7 @@ def ignore(example): def process_file(fp): - """ Reading the raw data from a file of NeMo format and preprocesses it. Write is out to the output directory. + """Reading the raw data from a file of NeMo format and preprocesses it. Write is out to the output directory. For more info about the data format, refer to the `text_normalization doc `. diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml index f603ebb58eb7..62f0e452d3b5 100644 --- a/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml @@ -17,13 +17,15 @@ trainer: num_nodes: 1 accelerator: gpu logger: false # logger provided by exp_manager - precision: bf16 # 16, 32, or bf16 + precision: ${export.dtype} # 16, bf16, or 32 enable_checkpointing: false model: tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 restore_from_path: llama2-7b-fp16.nemo # Nemo file path + precision: ${export.dtype} # Model weights data type + megatron_amp_O2: true # Enable Megatron O2-style half-precision ## Activation Checkpoint activations_checkpoint_granularity: null # 'selective' or 'full' @@ -42,7 +44,7 @@ export: decoder_type: llama # gptnext, gpt2, llama inference_tensor_parallel: 1 # Default using 1 TP for inference inference_pipeline_parallel: 1 # Default using 1 PP for inference - dtype: ${trainer.precision} # Default precision data type + dtype: 16 # Default precision data type for non-quantized layers: 16 or bf16 save_path: llama2-7b-${quantization.algorithm}.qnemo # Path where the quantized model will be saved compress: false # Whether save_path should be a tarball or a directory sample_output: true # Whether to run a sample prompt before saving diff --git a/examples/tts/conf/audio_codec/audio_codec_22050.yaml b/examples/tts/conf/audio_codec/audio_codec_22050.yaml new file mode 100644 index 000000000000..c45f2c2a129c --- /dev/null +++ b/examples/tts/conf/audio_codec/audio_codec_22050.yaml @@ -0,0 +1,193 @@ +# This config contains the default values for training 22.05kHz NeMo Audio Codec model. +# If you want to train model on other dataset, you can change config values according to your dataset. +# Most dataset-specific arguments are in the head of the config file, see below. + +name: AudioCodec + +max_epochs: ??? +# Adjust batch size based on GPU memory +batch_size: 16 +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. +weighted_sampling_steps_per_epoch: null + +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +train_ds_meta: ??? +val_ds_meta: ??? + +log_ds_meta: ??? +log_dir: ??? + +# Modify these values based on your sample rate +sample_rate: 22050 +win_length: 1024 +hop_length: 256 +train_n_samples: 8192 # ~0.37 seconds +# The product of the down_sample_rates and up_sample_rates should match the hop_length. +# For example 2 * 2 * 8 * 8 = 256. +down_sample_rates: [2, 2, 8, 8] +up_sample_rates: [8, 8, 2, 2] + +num_codebooks: 8 +encoder_out_dim: 32 + +model: + + max_epochs: ${max_epochs} + steps_per_epoch: ${weighted_sampling_steps_per_epoch} + + sample_rate: ${sample_rate} + samples_per_frame: ${hop_length} + + mel_loss_l1_scale: 10.0 + mel_loss_l2_scale: 0.0 + stft_loss_scale: 10.0 + time_domain_loss_scale: 0.0 + commit_loss_scale: 0.0 + + # Probability of updating the discriminator during each training step + # For example, update the discriminator 1/2 times (1 update for every 2 batches) + disc_updates_per_period: 1 + disc_update_period: 2 + + # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length] + loss_resolutions: [ + [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024] + ] + mel_loss_dims: [5, 10, 20, 40, 80, 160] + mel_loss_log_guard: 1.0 + stft_loss_log_guard: 1.0 + feature_loss_type: absolute + + train_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + dataset_meta: ${train_ds_meta} + weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} + sample_rate: ${sample_rate} + n_samples: ${train_n_samples} + min_duration: 0.4 # seconds + max_duration: null + + dataloader_params: + batch_size: ${batch_size} + drop_last: true + num_workers: 4 + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss + dataset_meta: ${val_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + # Configures how audio samples are generated and saved during training. + # Remove this section to disable logging. + log_config: + log_dir: ${log_dir} + log_epochs: [10, 50] + epoch_frequency: 100 + log_tensorboard: false + log_wandb: false + + generators: + - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator + log_audio: true + log_encoding: false + log_dequantized: false + + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 # Only log the first 10 seconds of generated audio. + dataset_meta: ${log_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + audio_encoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANEncoder + down_sample_rates: ${down_sample_rates} + encoded_dim: ${encoder_out_dim} + base_channels: 48 + activation: "lrelu" + + audio_decoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder + up_sample_rates: ${up_sample_rates} + input_dim: ${encoder_out_dim} + base_channels: 768 + activation: "half_snake" + output_activation: "clamp" + + vector_quantizer: + _target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer + num_groups: ${num_codebooks} + num_levels_per_group: [8, 5, 5, 5] + + discriminator: + _target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator + discriminators: + - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator + - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT + resolutions: [[512, 128, 512], [1024, 256, 1024]] + stft_bands: [[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]] + + generator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss + + discriminator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss + + optim: + _target_: torch.optim.Adam + lr: 2e-4 + betas: [0.8, 0.99] + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: 16 + max_epochs: ${max_epochs} + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 10 + benchmark: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: false + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/tts/conf/audio_codec/audio_codec_44100.yaml b/examples/tts/conf/audio_codec/audio_codec_44100.yaml new file mode 100644 index 000000000000..eab13a0e440b --- /dev/null +++ b/examples/tts/conf/audio_codec/audio_codec_44100.yaml @@ -0,0 +1,193 @@ +# This config contains the default values for training 44.1kHz NeMo Audio Codec model. +# If you want to train model on other dataset, you can change config values according to your dataset. +# Most dataset-specific arguments are in the head of the config file, see below. + +name: AudioCodec + +max_epochs: ??? +# Adjust batch size based on GPU memory +batch_size: 16 +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. +weighted_sampling_steps_per_epoch: null + +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +train_ds_meta: ??? +val_ds_meta: ??? + +log_ds_meta: ??? +log_dir: ??? + +# Modify these values based on your sample rate +sample_rate: 44100 +win_length: 2048 +hop_length: 512 +train_n_samples: 16384 # ~0.37 seconds +# The product of the down_sample_rates and up_sample_rates should match the hop_length. +# For example 2 * 4 * 8 * 8 = 512. +down_sample_rates: [2, 4, 8, 8] +up_sample_rates: [8, 8, 4, 2] + +num_codebooks: 8 +encoder_out_dim: 32 + +model: + + max_epochs: ${max_epochs} + steps_per_epoch: ${weighted_sampling_steps_per_epoch} + + sample_rate: ${sample_rate} + samples_per_frame: ${hop_length} + + mel_loss_l1_scale: 10.0 + mel_loss_l2_scale: 0.0 + stft_loss_scale: 10.0 + time_domain_loss_scale: 0.0 + commit_loss_scale: 0.0 + + # Probability of updating the discriminator during each training step + # For example, update the discriminator 1/2 times (1 update for every 2 batches) + disc_updates_per_period: 1 + disc_update_period: 2 + + # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length] + loss_resolutions: [ + [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048] + ] + mel_loss_dims: [5, 10, 20, 40, 80, 160, 320] + mel_loss_log_guard: 1.0 + stft_loss_log_guard: 1.0 + feature_loss_type: absolute + + train_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + dataset_meta: ${train_ds_meta} + weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} + sample_rate: ${sample_rate} + n_samples: ${train_n_samples} + min_duration: 0.4 # seconds + max_duration: null + + dataloader_params: + batch_size: ${batch_size} + drop_last: true + num_workers: 4 + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss + dataset_meta: ${val_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + # Configures how audio samples are generated and saved during training. + # Remove this section to disable logging. + log_config: + log_dir: ${log_dir} + log_epochs: [10, 50] + epoch_frequency: 100 + log_tensorboard: false + log_wandb: false + + generators: + - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator + log_audio: true + log_encoding: false + log_dequantized: false + + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 # Only log the first 10 seconds of generated audio. + dataset_meta: ${log_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + audio_encoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANEncoder + down_sample_rates: ${down_sample_rates} + encoded_dim: ${encoder_out_dim} + base_channels: 48 + activation: "lrelu" + + audio_decoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder + up_sample_rates: ${up_sample_rates} + input_dim: ${encoder_out_dim} + base_channels: 768 + activation: "half_snake" + output_activation: "clamp" + + vector_quantizer: + _target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer + num_groups: ${num_codebooks} + num_levels_per_group: [8, 5, 5, 5] + + discriminator: + _target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator + discriminators: + - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator + - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT + resolutions: [[512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] + stft_bands: [[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]] + + generator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss + + discriminator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss + + optim: + _target_: torch.optim.Adam + lr: 2e-4 + betas: [0.8, 0.99] + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: 16 + max_epochs: ${max_epochs} + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 10 + benchmark: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: false + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/nemo/collections/asr/parts/k2/graph_decoders.py b/nemo/collections/asr/parts/k2/graph_decoders.py index 33218588b79f..981025e7c418 100644 --- a/nemo/collections/asr/parts/k2/graph_decoders.py +++ b/nemo/collections/asr/parts/k2/graph_decoders.py @@ -13,14 +13,28 @@ # limitations under the License. from abc import abstractmethod +from collections import defaultdict +from pathlib import Path from typing import List, Optional, Tuple, Union import torch +from jiwer import wer as word_error_rate from omegaconf import DictConfig from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig from nemo.collections.asr.parts.k2.loss_mixins import CtcK2Mixin, RnntK2Mixin -from nemo.collections.asr.parts.k2.utils import invert_permutation, load_graph +from nemo.collections.asr.parts.k2.utils import ( + create_supervision, + invert_permutation, + levenshtein_graph_k2, + load_graph, +) +from nemo.collections.asr.parts.submodules.wfst_decoder import ( + AbstractWFSTDecoder, + WfstNbestHypothesis, + collapse_tokenword_hypotheses, +) +from nemo.core.utils.k2_guard import k2 from nemo.utils import logging @@ -121,7 +135,8 @@ def _decode_impl( return lats else: shortest_path_fsas = k2.index_fsa( - k2.shortest_path(lats, True), invert_permutation(order).to(device=log_probs.device), + k2.shortest_path(lats, True), + invert_permutation(order).to(device=log_probs.device), ) return self._extract_labels_and_probabilities(shortest_path_fsas, return_ilabels, output_aligned) @@ -336,3 +351,336 @@ def update_graph(self, graph: 'k2.Fsa'): self.num_classes, self.blank, self.topo_type, self.topo_with_self_loops, self.device, token_lm ) self.base_graph = k2.create_fsa_vec([self.graph_compiler.base_graph]).to(self.device) + + +class K2WfstDecoder(AbstractWFSTDecoder): + """ + Used for performing WFST decoding of the logprobs with the k2 WFST decoder. + + Args: + lm_fst: + Kaldi-type language model WFST or its path. + + decoding_mode: + Decoding mode. Choices: `nbest`, `lattice`. + + beam_size: + Beam width (float) for the WFST decoding. + + config: + Riva Decoder config. + + tokenword_disambig_id: + Tokenword disambiguation index. Set to -1 to disable the tokenword mode. + + lm_weight: + Language model weight in decoding. + + nbest_size: + N-best size for decoding_mode == `nbest` + + device: + Device for running decoding. Choices: `cuda`, `cpu`. + """ + + def __init__( + self, + lm_fst: Union['k2.Fsa', Path, str], + decoding_mode: str = 'nbest', + beam_size: float = 10.0, + config: Optional[GraphIntersectDenseConfig] = None, + tokenword_disambig_id: int = -1, + lm_weight: float = 1.0, + nbest_size: int = 1, + device: str = "cuda", + ): + self._nbest_size = nbest_size + self._device = device + super().__init__(lm_fst, decoding_mode, beam_size, config, tokenword_disambig_id, lm_weight) + + def _set_decoder_config(self, config: Optional[GraphIntersectDenseConfig] = None): + if config is None: + config = GraphIntersectDenseConfig() + config.search_beam = 20.0 + config.output_beam = self._beam_size + config.max_active_states = 10000 + self._config = config + + def _set_decoding_mode(self, decoding_mode: str): + if decoding_mode not in ('nbest', 'lattice'): + raise ValueError(f"Unsupported mode: {decoding_mode}") + self._decoding_mode = decoding_mode + + @torch.inference_mode(False) + def _init_decoder(self): + lm_fst = load_graph(self._lm_fst) if isinstance(self._lm_fst, (Path, str)) else self._lm_fst.clone() + lm_fst.lm_scores = lm_fst.scores.clone() + self._lm_fst = lm_fst.to(device=self._device) + + if self._id2word is None: + self._id2word = { + int(line.split()[1]): line.split()[0] + for line in self._lm_fst.aux_labels_sym.to_str().strip().split("\n") + } + word2id = self._id2word.__class__(map(reversed, self._id2word.items())) + word_unk_id = word2id[""] + self._word2id = defaultdict(lambda: word_unk_id) + for k, v in word2id.items(): + self._word2id[k] = v + if self._id2token is None: + self._id2token = { + int(line.split()[1]): line.split()[0] for line in self._lm_fst.labels_sym.to_str().strip().split("\n") + } + token2id = self._id2token.__class__(map(reversed, self._id2token.items())) + token_unk_id = token2id[""] + self._token2id = defaultdict(lambda: token_unk_id) + for k, v in token2id.items(): + self._token2id[k] = v + + def _beam_size_setter(self, value: float): + if self._beam_size != value: + self._config.output_beam = value + self._beam_size = value + + def _lm_weight_setter(self, value: float): + if self._lm_weight != value: + self._lm_weight = value + + @property + def nbest_size(self): + return self._nbest_size + + @nbest_size.setter + def nbest_size(self, value: float): + self._nbest_size_setter(value) + + def _nbest_size_setter(self, value: float): + if self._nbest_size != value: + self._nbest_size = value + + def _decoding_mode_setter(self, value: str): + if self._decoding_mode != value: + self._set_decoding_mode(value) + + @torch.inference_mode(False) + def _decode_lattice(self, emissions_fsas: 'k2.DenseFsaVec', order: torch.Tensor) -> 'k2.Fsa': + """ + Decodes logprobs into k2-type lattices. + + Args: + emissions_fsas: + A k2.DenseFsaVec of the predicted log-probabilities. + order: + A torch.Tensor that stores the order of the emissions_fsas elements. + + Returns: + k2-type FsaVec. + """ + lats = k2.intersect_dense_pruned( + a_fsas=self._lm_fst, + b_fsas=emissions_fsas, + search_beam=self._config.search_beam, + output_beam=self._config.output_beam, + min_active_states=self._config.min_active_states, + max_active_states=self._config.max_active_states, + frame_idx_name="frame_idx", + allow_partial=True, + ) + lats = k2.connect(k2.expand_ragged_attributes(lats)) + lats.am_scores = lats.scores - lats.lm_scores + if self._lm_weight != 1.0: + lats.scores = lats.am_scores + self._lm_weight * lats.lm_scores + # just in case + lats.__dict__["_properties"] = None + return k2.index_fsa(lats, invert_permutation(order).to(device=self._device)) + + @torch.inference_mode(False) + def decode( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor + ) -> Union[List[WfstNbestHypothesis], List['k2.Fsa']]: + """ + Decodes logprobs into recognition hypotheses. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + Returns: + List of recognition hypotheses. + """ + supervisions = create_supervision(log_probs_length) + order = supervisions[:, 0] + emissions_fsas = k2.DenseFsaVec(log_probs.to(device=self._device), supervisions) + lats = self._decode_lattice(emissions_fsas, order) + hypotheses = self._post_decode(lats) + return hypotheses + + @torch.inference_mode(False) + def _post_decode(self, hypotheses: 'k2.Fsa') -> Union[List[WfstNbestHypothesis], List['k2.Fsa']]: + """ + Does various post-processing of the recognition hypotheses. + + Args: + hypotheses: + FsaVec of k2-type lattices. + + Returns: + List of processed recognition hypotheses. + """ + if self._decoding_mode == 'nbest': + hypotheses_fsa = hypotheses + hypotheses = [] + if self._nbest_size == 1: + shortest_path_fsas = k2.shortest_path(hypotheses_fsa, True) + scores = shortest_path_fsas.get_tot_scores(True, False).tolist() + # direct iterating does not work as expected + for i in range(shortest_path_fsas.shape[0]): + fsa = shortest_path_fsas[i] + non_eps_mask = fsa.aux_labels > 0 + words = [self._id2word[l] for l in fsa.aux_labels[non_eps_mask].tolist()] + alignment = fsa.labels[fsa.labels > 0].tolist() + # some timesteps may be 0 if self.open_vocabulary_decoding + timesteps = fsa.frame_idx[non_eps_mask] + timesteps_left = timesteps[:-1] + timesteps_right = timesteps[1:] + timesteps_right_zero_mask = timesteps_right == 0 + timesteps_right[timesteps_right_zero_mask] = timesteps_left[timesteps_right_zero_mask] + timesteps[1:] = timesteps_right + timesteps = timesteps.tolist() + hypotheses.append( + WfstNbestHypothesis( + tuple( + [ + tuple([tuple(words), tuple(timesteps), tuple(alignment), -scores[i]]), + ] + ) + ) + ) + else: + nbest_fsas = k2.Nbest.from_lattice(hypotheses_fsa, self._nbest_size) + nbest_fsas.fsa.frame_idx = k2.index_select(hypotheses_fsa.frame_idx, nbest_fsas.kept_path.values) + scores = nbest_fsas.fsa.get_tot_scores(True, False).tolist() + nbest_hypothesis_list = [[] for _ in range(nbest_fsas.shape.dim0)] + for i, j in enumerate(nbest_fsas.shape.row_ids(1)): + fsa = nbest_fsas.fsa[i] + non_eps_mask = fsa.aux_labels > 0 + words = [self._id2word[l] for l in fsa.aux_labels[non_eps_mask].tolist()] + alignment = fsa.labels[fsa.labels > 0].tolist() + # some timesteps may be 0 if self.open_vocabulary_decoding + timesteps = fsa.frame_idx[non_eps_mask] + timesteps_left = timesteps[:-1] + timesteps_right = timesteps[1:] + timesteps_right_zero_mask = timesteps_right == 0 + timesteps_right[timesteps_right_zero_mask] = timesteps_left[timesteps_right_zero_mask] + timesteps[1:] = timesteps_right + timesteps = timesteps.tolist() + nbest_hypothesis_list[j].append( + tuple([tuple(words), tuple(timesteps), tuple(alignment), -scores[i]]) + ) + for nbest_hypothesis in nbest_hypothesis_list: + hypotheses.append(WfstNbestHypothesis(tuple(nbest_hypothesis))) + return ( + collapse_tokenword_hypotheses(hypotheses, self._id2word[self._tokenword_disambig_id]) + if self._open_vocabulary_decoding + else hypotheses + ) + else: + return [hypotheses[i].to(device="cpu") for i in range(len(hypotheses))] + + @torch.inference_mode(False) + def calibrate_lm_weight( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor, reference_texts: List[str] + ) -> Tuple[float, float]: + """ + Calibrates LM weight to achieve the best WER for given logprob-text pairs. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + reference_texts: + List of reference word sequences. + + Returns: + Pair of (best_lm_weight, best_wer). + """ + assert len(log_probs) == len(reference_texts) + decoding_mode_backup = self.decoding_mode + lm_weight_backup = self.lm_weight + nbest_size_backup = self.nbest_size + self.decoding_mode = "lattice" + lattices = self.decode(log_probs, log_probs_length) + best_lm_weight, best_wer = -1.0, float('inf') + self.decoding_mode = "nbest" + self.nbest_size = 1 + for lm_weight in range(1, 21): # enough for most cases + lm_weight_act = lm_weight / 10 + for lat in lattices: + lat.scores = lat.am_scores + lm_weight_act * lat.lm_scores + hypotheses = self._post_decode(lattices) + wer = word_error_rate([" ".join(h[0].words) for h in hypotheses], reference_texts) + if wer < best_wer: + best_lm_weight, best_wer = lm_weight_act, wer + self.nbest_size = nbest_size_backup + self.decoding_mode = decoding_mode_backup + self.lm_weight = lm_weight_backup + return best_lm_weight, best_wer + + @torch.inference_mode(False) + def calculate_oracle_wer( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor, reference_texts: List[str] + ) -> Tuple[float, List[float]]: + """ + Calculates the oracle (the best possible WER for given logprob-text pairs. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + reference_texts: + List of reference word sequences. + + Returns: + Pair of (oracle_wer, oracle_wer_per_utterance). + """ + if self._open_vocabulary_decoding: + raise NotImplementedError + assert len(log_probs) == len(reference_texts) + word_ids = [[self._word2id[w] for w in text.split()] for text in reference_texts] + counts = torch.tensor([len(wid) for wid in word_ids]) + decoding_mode_backup = self.decoding_mode + self.decoding_mode = "lattice" + lattices = self.decode(log_probs, log_probs_length) + oracle_disambig = max(self._id2word.keys()) + 1 + lattices.aux_labels[lattices.aux_labels == 0] = oracle_disambig + lattices = lattices.invert() + delattr(lattices, 'aux_labels') + hyps = levenshtein_graph_k2(lattices).invert() + refs = levenshtein_graph_k2(k2.linear_fsa(word_ids)) + refs, arc_map = k2.add_epsilon_self_loops(refs, ret_arc_map=True) + labels = refs.labels.clone() + labels[arc_map == -1] = oracle_disambig + refs.labels = labels + refs.__dict__["_properties"] = None + refs = k2.arc_sort(refs) + ali_lats = k2.compose(hyps, refs, treat_epsilons_specially=False) + ali_lats = k2.remove_epsilon_self_loops(ali_lats) + # TODO: find out why it fails for some utterances + try: + alignment = k2.shortest_path(ali_lats, use_double_scores=True) + except RuntimeError as e: + logging.warning("calculate_oracle_wer failed") + return -1.0, [] + scores = -alignment.get_tot_scores(True, True).to(dtype=torch.int64) + wer_per_utt = scores / counts + self.decoding_mode = decoding_mode_backup + return (scores.sum() / counts.sum()).item(), wer_per_utt.tolist() diff --git a/nemo/collections/asr/parts/k2/utils.py b/nemo/collections/asr/parts/k2/utils.py index f55620a81356..eca2b2379b43 100644 --- a/nemo/collections/asr/parts/k2/utils.py +++ b/nemo/collections/asr/parts/k2/utils.py @@ -42,7 +42,12 @@ def create_supervision(input_lengths: torch.Tensor) -> torch.Tensor: These supervisions are required for some k2 methods. """ supervisions = torch.stack( - (torch.tensor(range(input_lengths.shape[0])), torch.zeros(input_lengths.shape[0]), input_lengths.cpu(),), 1, + ( + torch.tensor(range(input_lengths.shape[0])), + torch.zeros(input_lengths.shape[0]), + input_lengths.cpu(), + ), + 1, ).to(dtype=torch.int32) # the duration column has to be sorted in decreasing order return supervisions[torch.argsort(supervisions[:, -1], descending=True)] @@ -50,7 +55,7 @@ def create_supervision(input_lengths: torch.Tensor) -> torch.Tensor: def invert_permutation(indices: torch.Tensor) -> torch.Tensor: """Produces a tensor of reverse permutation for a given indices. - + Based on https://github.com/k2-fsa/snowfall/blob/master/snowfall/common.py """ ans = torch.zeros(indices.shape, device=indices.device, dtype=indices.dtype) @@ -59,8 +64,7 @@ def invert_permutation(indices: torch.Tensor) -> torch.Tensor: def make_non_pad_mask(input_lengths: torch.Tensor, seq_len: int): - """Converts input_lengths to a non-padding mask. The mask is 2D. - """ + """Converts input_lengths to a non-padding mask. The mask is 2D.""" batch_size = input_lengths.shape[0] seq_range = torch.arange(0, seq_len, device=input_lengths.device) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, seq_len) @@ -72,8 +76,7 @@ def make_non_pad_mask(input_lengths: torch.Tensor, seq_len: int): def make_non_pad_mask_3d( lengths_x: torch.Tensor, lengths_y: torch.Tensor, max_length_x: int, max_length_y: int ) -> torch.Tensor: - """Converts two orthogonal input_lengths to a non-padding mask. The mask is 3D. - """ + """Converts two orthogonal input_lengths to a non-padding mask. The mask is 3D.""" assert lengths_x.size() == lengths_y.size() return make_non_pad_mask(lengths_x, max_length_x).unsqueeze(2) & make_non_pad_mask( lengths_y, max_length_y @@ -81,8 +84,7 @@ def make_non_pad_mask_3d( def ragged_to_tensor_2axes_simple(rt: k2.RaggedTensor) -> Optional[torch.Tensor]: - """Converts k2.RaggedTensor to torch.Tensor if the RaggedTensor is shallow (has two axes). - """ + """Converts k2.RaggedTensor to torch.Tensor if the RaggedTensor is shallow (has two axes).""" rt_list = rt.tolist() result_list = [] for e in rt_list: @@ -96,8 +98,7 @@ def ragged_to_tensor_2axes_simple(rt: k2.RaggedTensor) -> Optional[torch.Tensor] def load_graph(graph_path: str) -> 'k2.Fsa': - """Fsa graph loading helper function. Loads graphs stored in different formats. - """ + """Fsa graph loading helper function. Loads graphs stored in different formats.""" if os.path.exists(graph_path): errors = [] try: @@ -122,8 +123,7 @@ def load_graph(graph_path: str) -> 'k2.Fsa': def intersect_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fsa': - """Intersection helper function. - """ + """Intersection helper function.""" assert hasattr(base_graph, "aux_labels") assert not hasattr(aux_graph, "aux_labels") aux_graph_with_self_loops = k2.arc_sort(k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device) @@ -133,8 +133,7 @@ def intersect_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2. def compose_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fsa': - """Composition helper function. - """ + """Composition helper function.""" aux_graph_with_self_loops = k2.arc_sort(k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device) return k2.compose(base_graph, aux_graph_with_self_loops, treat_epsilons_specially=False, inner_labels="phones") @@ -145,13 +144,16 @@ def create_sparse_wrapped( size: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, min_col_index: Optional[int] = None, ) -> torch.Tensor: - """Wraps up k2.create_sparse to create 2- or 3-dimensional sparse tensors. - """ + """Wraps up k2.create_sparse to create 2- or 3-dimensional sparse tensors.""" assert size is None or len(indices) == len(size) if len(indices) == 2: return k2.create_sparse( - rows=indices[0], cols=indices[1], values=values, size=size, min_col_index=min_col_index, + rows=indices[0], + cols=indices[1], + values=values, + size=size, + min_col_index=min_col_index, ) elif len(indices) == 3: assert indices[0].ndim == indices[1].ndim == indices[2].ndim == 1 @@ -164,28 +166,43 @@ def create_sparse_wrapped( values = values[kept_indices] if size is not None: return torch.sparse_coo_tensor( - torch.stack(indices), values, size=size, device=values.device, requires_grad=values.requires_grad, + torch.stack(indices), + values, + size=size, + device=values.device, + requires_grad=values.requires_grad, ) else: return torch.sparse_coo_tensor( - torch.stack(indices), values, device=values.device, requires_grad=values.requires_grad, + torch.stack(indices), + values, + device=values.device, + requires_grad=values.requires_grad, ) else: raise ValueError(f"len(indices) = {len(indices)}") def prep_padded_densefsavec(log_softmax: torch.Tensor, supervisions: torch.Tensor) -> 'k2.DenseFsaVec': - """Performs special epsilon-padding required for composition with some of the topologies. - """ + """Performs special epsilon-padding required for composition with some of the topologies.""" log_softmax_eps = torch.cat( [ log_softmax, - torch.full((log_softmax.shape[0], log_softmax.shape[1], 1), -float("inf"), device=log_softmax.device,), + torch.full( + (log_softmax.shape[0], log_softmax.shape[1], 1), + -float("inf"), + device=log_softmax.device, + ), ], axis=-1, ) log_softmax_padded = torch.zeros( - (log_softmax_eps.shape[0], log_softmax_eps.shape[1] * 2, log_softmax_eps.shape[2],), device=log_softmax.device, + ( + log_softmax_eps.shape[0], + log_softmax_eps.shape[1] * 2, + log_softmax_eps.shape[2], + ), + device=log_softmax.device, ) log_softmax_padded[:, ::2] = log_softmax_eps supervisions_padded = supervisions.clone() @@ -235,8 +252,7 @@ def add_self_loops(graph: 'k2.Fsa', label: int = 0, mode: str = "auto"): def get_arc_weights(graph: 'k2.Fsa') -> torch.Tensor: - """Returns 1d torch.Tensor with arc weights of a given graph. - """ + """Returns 1d torch.Tensor with arc weights of a given graph.""" if len(graph.shape) > 2: raise NotImplementedError("FsaVec is not supported at the moment.") weights_int = graph.arcs.values()[:, -1].tolist() @@ -254,7 +270,7 @@ def get_tot_objf_and_finite_mask(tot_scores: torch.Tensor, reduction: str) -> Tu Returns: Returns a tuple of 2 scalar tensors: (tot_score, finite_mask) where finite_mask is a tensor containing successful segment mask. - + Based on get_tot_objf_and_num_frames from https://github.com/k2-fsa/snowfall/blob/master/snowfall/objectives/common.py """ @@ -324,3 +340,53 @@ def apply_rnnt_prune_ranges( index=ranges.reshape((B, T, window_size_with_blank, 1)).expand((B, T, window_size_with_blank, D2)), ) return encoder_outputs_pruned, decoder_outputs_pruned + + +def levenshtein_graph_k2(fsa: 'k2.Fsa', ins_del_score: float = -0.501) -> 'k2.Fsa': + """Construct the levenshtein graph from a k2-type WFST or a lattice. + + See also levenshtein_graph from k2. + + Args: + fst: + K2-type source WFST or lattice. + + ins_del_score: + Insertion and deletion penalty. + Should be more than 0.5 for substitutions to be preferred over insertions/deletions, or less otherwise. + + Returns: + K2-type levenshtein WFST. + """ + sub_score = -0.5 + sub_score_int = struct.unpack('@i', struct.pack('@f', sub_score))[0] + arcs = fsa.arcs.values() + final_indices = (fsa.labels == -1).nonzero() + template_mask = ~torch.zeros(len(arcs) * 2, dtype=bool) + no_duplicate_final_mask = template_mask.clone() + no_duplicate_final_mask[final_indices * 2 + 1] = False + new_mask = ~template_mask + new_mask[1::2] = True + new_mask = new_mask[no_duplicate_final_mask] + duplicate_indices = torch.arange(len(arcs)).repeat_interleave(2)[no_duplicate_final_mask] + new_arcs = arcs[duplicate_indices] + new_arcs[:, -1] = torch.where(new_mask, sub_score_int, 0) + if len(fsa.shape) == 3: + new_shape, _ = fsa.arcs.shape().index(2, duplicate_indices.to(dtype=torch.int32)) + # apparently k2 does not support indexing RaggedArc with RaggedShape + new_splits = new_shape.row_splits(2)[new_shape.row_splits(1)] + levenshtein_fsa = k2.create_fsa_vec([k2.Fsa(new_arcs[i:j]) for i, j in zip(new_splits[:-1], new_splits[1:])]) + else: + levenshtein_fsa = k2.Fsa(new_arcs) + levenshtein_fsa.aux_labels = levenshtein_fsa.labels.clone() + labels = levenshtein_fsa.labels.clone() + labels[new_mask] = 0 + levenshtein_fsa.labels = labels + levenshtein_fsa.__dict__["_properties"] = None + levenshtein_fsa, arc_map = k2.add_epsilon_self_loops(levenshtein_fsa, ret_arc_map=True) + scores = levenshtein_fsa.scores.clone() + scores[arc_map == -1] = ins_del_score + levenshtein_fsa.scores = scores + levenshtein_fsa.__dict__["_properties"] = None + levenshtein_fsa = k2.arc_sort(levenshtein_fsa) + return levenshtein_fsa diff --git a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py index 5ed504fd9c45..0beab5f54cb1 100644 --- a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py @@ -19,6 +19,8 @@ import torch +from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig +from nemo.collections.asr.parts.submodules.wfst_decoder import RivaDecoderConfig from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.core.classes import Typing, typecheck @@ -29,7 +31,8 @@ def pack_hypotheses( - hypotheses: List[rnnt_utils.NBestHypotheses], logitlen: torch.Tensor, + hypotheses: List[rnnt_utils.NBestHypotheses], + logitlen: torch.Tensor, ) -> List[rnnt_utils.NBestHypotheses]: if logitlen is not None: @@ -51,6 +54,39 @@ def pack_hypotheses( return hypotheses +def pack_wfst_hypotheses( + hypotheses: List['WfstNbestHypothesis'], + logits: torch.Tensor, + logitlen: torch.Tensor, +) -> List[rnnt_utils.NBestHypotheses]: + + logitlen_cpu = logitlen.to('cpu') + + new_hypotheses = [] + for idx, nbest_hyp in enumerate(hypotheses): # type: WfstNbestHypothesis + new_hyp = [] + y_sequence = logits[idx, : logitlen[idx]].to('cpu') + length = logitlen_cpu[idx] + for candidate_idx, cand in enumerate(nbest_hyp): + cand_hyp = rnnt_utils.Hypothesis( + y_sequence=[], + score=cand.score, + text=" ".join(cand.words), + timestep=list(cand.timesteps), + alignments=list(cand.alignment), + ) + cand_hyp.y_sequence = y_sequence + + if logitlen is not None: + cand_hyp.length = length + + new_hyp.append(cand_hyp) + + new_hypotheses.append(rnnt_utils.NBestHypotheses(new_hyp)) + + return new_hypotheses + + def _states_to_device(dec_state, device='cpu'): if torch.is_tensor(dec_state): dec_state = dec_state.to(device) @@ -74,8 +110,7 @@ class AbstractBeamCTCInfer(Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "decoder_output": NeuralType(('B', 'T', 'D'), LogprobsType()), "decoder_lengths": NeuralType(tuple('B'), LengthsType()), @@ -83,8 +118,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__(self, blank_id: int, beam_size: int): @@ -147,7 +181,9 @@ def set_tokenizer(self, tokenizer: TokenizerSpec): @typecheck() def forward( - self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor, + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, ) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]: """Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output token is generated auto-repressively. @@ -246,7 +282,9 @@ def __init__( @typecheck() def forward( - self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor, + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, ) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]: """Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output token is generated auto-repressively. @@ -568,6 +606,276 @@ def set_decoding_type(self, decoding_type: str): self.token_offset = DEFAULT_TOKEN_OFFSET +class WfstCTCInfer(AbstractBeamCTCInfer): + """A WFST-based beam CTC decoder. + + Provides a common abstraction for sample level and batch level beam decoding. + + Args: + TBD + + """ + + def __init__( + self, + blank_id: int, + beam_size: int, + search_type: str = "riva", # 'riva', 'k2' + return_best_hypothesis: bool = True, + preserve_alignments: bool = False, + compute_timestamps: bool = False, + decoding_mode: str = 'nbest', # 'nbest', 'mbr' ('mbr' works only for search_type == 'riva' and beam_size == 1) + open_vocabulary_decoding: bool = False, + beam_width: float = 10.0, + lm_weight: float = 1.0, + device: str = "cuda", + arpa_lm_path: str = None, + wfst_lm_path: str = None, + riva_decoding_cfg: Optional['RivaDecoderConfig'] = None, + k2_decoding_cfg: Optional['GraphIntersectDenseConfig'] = None, + ): + super().__init__(blank_id=blank_id, beam_size=beam_size) + + self.search_type = search_type + self.return_best_hypothesis = return_best_hypothesis + self.preserve_alignments = preserve_alignments + self.compute_timestamps = compute_timestamps + + self.decoding_algorithm = None + if search_type in ("default", "riva"): + self.decoding_algorithm = self._riva_decoding + elif search_type == "k2": + self.decoding_algorithm = self._k2_decoding + + # Log the WFST search_type + logging.info(f"WFST beam search search_type: {search_type}") + self.search_type = search_type + + if beam_size > 1 and decoding_mode != 'nbest': + logging.warning( + f"`beam_size` > 1 is supported only for `decoding_mode` == `nbest`\n" + f"(provided: `{decoding_mode}`).\n" + f"`beam_size` rewritten as 1" + ) + self.beam_size = 1 + self.decoding_mode = decoding_mode + + self.open_vocabulary_decoding = open_vocabulary_decoding + self._tokenword_disambig_id = -1 + self.beam_width = beam_width + self.lm_weight = lm_weight + self.device = device + + # Default beam search args + self.arpa_lm_path = arpa_lm_path + self.wfst_lm_path = wfst_lm_path + + self.riva_decoding_cfg = riva_decoding_cfg + self.k2_decoding_cfg = k2_decoding_cfg + + # Default beam search scorer functions + self.riva_decoder = None + self.k2_decoder = None + + @typecheck() + def forward( + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, + ) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]: + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + decoder_output: A tensor of size (batch, timesteps, features). + decoder_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + if self.vocab is None: + raise RuntimeError("Please set the vocabulary with `set_vocabulary()` before calling this function.") + + if self.decoding_type != 'subword': + raise ValueError( + f"`decoding_type` other than `subword` is not supported. Provided: `{self.decoding_type}`" + ) + elif self.tokenizer is None: + raise ValueError("Tokenizer must be provided for subword decoding. Use set_tokenizer().") + if self.decoding_algorithm is None: + raise NotImplementedError( + f"The decoding search_type ({self.search_type}) supplied is not supported!\n" + f"Please use one of : (default, riva, k2)" + ) + + with torch.no_grad(), torch.inference_mode(): + # Process each sequence independently + prediction_tensor = decoder_output + + if prediction_tensor.ndim != 3: + raise ValueError( + f"`decoder_output` must be a tensor of shape [B, T, V] (log probs, float). " + f"Provided shape = {prediction_tensor.shape}" + ) + + hypotheses = self.decoding_algorithm(prediction_tensor, decoder_lengths) + + # Pack results into Hypotheses + packed_result = pack_wfst_hypotheses(hypotheses, prediction_tensor, decoder_lengths) + + # Pack the result + if self.return_best_hypothesis and isinstance(packed_result[0], rnnt_utils.NBestHypotheses): + packed_result = [res.n_best_hypotheses[0] for res in packed_result] # type: Hypothesis + + return (packed_result,) + + def _prepare_decoding_lm_wfst(self) -> Union[str, 'kaldifst.StdFst', 'k2.Fsa']: + """TBD""" + arpa_lm_path_exists = self.arpa_lm_path is not None and os.path.exists(self.arpa_lm_path) + wfst_lm_path_exists = self.wfst_lm_path is not None and os.path.exists(self.wfst_lm_path) + lm_fst = None + if wfst_lm_path_exists: + if self.search_type == "riva" and not self.wfst_lm_path.endswith(".fst"): + raise ValueError( + f"Search type `riva` expects WFSTs in the `.fst` format. Provided: `{self.wfst_lm_path}`" + ) + if self.search_type == "k2" and not self.wfst_lm_path.endswith(".pt"): + raise ValueError( + f"Search type `k2` expects WFSTs in the `.pt` format. Provided: `{self.wfst_lm_path}`" + ) + if arpa_lm_path_exists: + logging.warning( + "Both `arpa_lm_path` and `wfst_lm_path` are provided and not empty. The latter will be used." + ) + lm_fst = self.wfst_lm_path + elif not arpa_lm_path_exists: + raise FileNotFoundError( + f"Arpa LM file not found at `{self.arpa_lm_path}` and WFST LM is not found at `{self.wfst_lm_path}`.\n" + f"Please set a valid path in the decoding config for at least one of those." + ) + else: + logging.warning( + f"Since WFST LM is not found at `{self.wfst_lm_path}`, " + f"it will be made from the Arpa LM at `{self.arpa_lm_path}`.\n" + f"This procedure will take some time." + ) + if self.wfst_lm_path is not None: + logging.info(f"WFST LM will be buffered at `{self.wfst_lm_path}`.") + write_tlg_path = self.wfst_lm_path + else: + logging.warning("Consider providing a write-permitted `wfst_lm_path` for WFST LM buffering.") + write_tlg_path = None + ctc_topology = "default" # there is no way to indicate the need of other topologies + target = "kaldi" if self.search_type == "riva" else "k2" + + from nemo.collections.asr.parts.utils.wfst_utils import mkgraph_ctc_ov + + lm_fst, tokenword_disambig_id = mkgraph_ctc_ov( + tokenizer=self.tokenizer, + lm_path=self.arpa_lm_path, + topology_name=ctc_topology, + write_tlg_path=write_tlg_path, + open_vocabulary=self.open_vocabulary_decoding, + target=target, + ) + self._tokenword_disambig_id = tokenword_disambig_id + + return lm_fst + + @torch.no_grad() + def _riva_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNbestHypothesis']: + """ + Riva Asrlib WFST decoder Algorithm. + + Args: + x: Tensor of shape [B, T, V+1], where B is the batch size, T is the maximum sequence length, + and V is the vocabulary size. The tensor contains log-probabilities. + out_len: Tensor of shape [B], contains lengths of each sequence in the batch. + + Returns: + A list of WfstNbestHypothesis objects, one for each sequence in the batch. + """ + if self.riva_decoder is None: + lm_fst = self._prepare_decoding_lm_wfst() + if self.open_vocabulary_decoding and self._tokenword_disambig_id == -1: + # trying to extract tokenword_disambig_id from the lm_fst + if isinstance(lm_fst, str): + # use importer instead of direct import to possibly get an installation message + from nemo.collections.asr.parts.utils.wfst_utils import kaldifst_importer + + kaldifst = kaldifst_importer() + lm_fst = kaldifst.StdVectorFst.read(self.wfst_lm_path) + tokenword_disambig_id = lm_fst.output_symbols.find("#1") + if tokenword_disambig_id == -1: + raise ValueError( + "Cannot determine `tokenword_disambig_id` " + "which is required if `open_vocabulary_decoding` == True" + ) + self._tokenword_disambig_id = tokenword_disambig_id + if not self.device.startswith("cuda"): + raise ValueError(f"Riva decoder does not support non-cuda device. Provided: `{self.device}`") + + from nemo.collections.asr.parts.submodules.wfst_decoder import RivaGpuWfstDecoder + + self.riva_decoder = RivaGpuWfstDecoder( + lm_fst=lm_fst, + decoding_mode=self.decoding_mode, + beam_size=self.beam_width, + config=self.riva_decoding_cfg, + tokenword_disambig_id=self._tokenword_disambig_id, + lm_weight=self.lm_weight, + nbest_size=self.beam_size, + ) + + return self.riva_decoder.decode(x.to(device=self.device), out_len.to(device=self.device)) + + @torch.no_grad() + def _k2_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNbestHypothesis']: + """ + K2 WFST decoder Algorithm. + + Args: + x: Tensor of shape [B, T, V+1], where B is the batch size, T is the maximum sequence length, + and V is the vocabulary size. The tensor contains log-probabilities. + out_len: Tensor of shape [B], contains lengths of each sequence in the batch. + + Returns: + A list of WfstNbestHypothesis objects, one for each sequence in the batch. + """ + if self.k2_decoder is None: + lm_fst = self._prepare_decoding_lm_wfst() + if self.open_vocabulary_decoding and self._tokenword_disambig_id == -1: + if isinstance(lm_fst, str): + from nemo.collections.asr.parts.k2.utils import load_graph + + with torch.inference_mode(False): + lm_fst = load_graph(lm_fst) + try: + tokenword_disambig_id = lm_fst.aux_labels_sym.get("#1") + self._tokenword_disambig_id = tokenword_disambig_id + except KeyError: + raise ValueError( + "Cannot determine `tokenword_disambig_id` " + "which is required if `open_vocabulary_decoding` == True" + ) + + from nemo.collections.asr.parts.k2.graph_decoders import K2WfstDecoder + + self.k2_decoder = K2WfstDecoder( + lm_fst=lm_fst, + decoding_mode=self.decoding_mode, + beam_size=self.beam_width, + config=self.k2_decoding_cfg, + tokenword_disambig_id=self._tokenword_disambig_id, + lm_weight=self.lm_weight, + nbest_size=self.beam_size, + device=self.device, + ) + + return self.k2_decoder.decode(x.to(device=self.device), out_len.to(device=self.device)) + + @dataclass class PyCTCDecodeConfig: # These arguments cannot be imported from pyctcdecode (optional dependency) @@ -604,3 +912,21 @@ class BeamCTCInferConfig: flashlight_cfg: Optional[FlashlightConfig] = field(default_factory=lambda: FlashlightConfig()) pyctcdecode_cfg: Optional[PyCTCDecodeConfig] = field(default_factory=lambda: PyCTCDecodeConfig()) + + +@dataclass +class WfstCTCInferConfig: + beam_size: int + search_type: str = "riva" # 'riva', 'k2' + return_best_hypothesis: bool = True + preserve_alignments: bool = False + compute_timestamps: bool = False + decoding_mode: str = 'nbest' # 'nbest', 'mbr' ('mbr' works only for search_type == 'riva' and beam_size == 1) + open_vocabulary_decoding: bool = False + beam_width: float = 10.0 + lm_weight: float = 1.0 + device: str = "cuda" + arpa_lm_path: Optional[str] = None + wfst_lm_path: Optional[str] = None + riva_decoding_cfg: Optional['RivaDecoderConfig'] = field(default_factory=lambda: RivaDecoderConfig()) + k2_decoding_cfg: Optional['GraphIntersectDenseConfig'] = field(default_factory=lambda: GraphIntersectDenseConfig()) diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index d2bfb629293e..ec27d3dbbd22 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -213,7 +213,7 @@ def __init__(self, decoding_cfg, blank_id: int): self.batch_dim_index = self.cfg.get('batch_dim_index', 0) self.word_seperator = self.cfg.get('word_seperator', ' ') - possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight'] + possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight', 'wfst'] if self.cfg.strategy not in possible_strategies: raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}") @@ -314,6 +314,28 @@ def __init__(self, decoding_cfg, blank_id: int): self.decoding.override_fold_consecutive_value = False + elif self.cfg.strategy == 'wfst': + + self.decoding = ctc_beam_decoding.WfstCTCInfer( + blank_id=blank_id, + beam_size=self.cfg.wfst.get('beam_size', 1), + search_type=self.cfg.wfst.get('search_type', 'riva'), + return_best_hypothesis=self.cfg.wfst.get('return_best_hypothesis', True), + preserve_alignments=self.preserve_alignments, + compute_timestamps=self.compute_timestamps, + decoding_mode=self.cfg.wfst.get('decoding_mode', 'nbest'), + open_vocabulary_decoding=self.cfg.wfst.get('open_vocabulary_decoding', False), + beam_width=self.cfg.wfst.get('beam_width', 10.0), + lm_weight=self.cfg.wfst.get('lm_weight', 1.0), + device=self.cfg.wfst.get('device', 'cuda'), + arpa_lm_path=self.cfg.wfst.get('arpa_lm_path', None), + wfst_lm_path=self.cfg.wfst.get('wfst_lm_path', None), + riva_decoding_cfg=self.cfg.wfst.get('riva_decoding_cfg', None), + k2_decoding_cfg=self.cfg.wfst.get('k2_decoding_cfg', None), + ) + + self.decoding.override_fold_consecutive_value = False + else: raise ValueError( f"Incorrect decoding strategy supplied. Must be one of {possible_strategies}\n" @@ -374,48 +396,56 @@ def ctc_decoder_predictions_tensor( hypotheses_list = hypotheses_list[0] # type: List[Hypothesis] if isinstance(hypotheses_list[0], NBestHypotheses): - hypotheses = [] - all_hypotheses = [] + if self.cfg.strategy == 'wfst': + all_hypotheses = [hyp.n_best_hypotheses for hyp in hypotheses_list] + hypotheses = [hyp[0] for hyp in all_hypotheses] + else: + hypotheses = [] + all_hypotheses = [] - for nbest_hyp in hypotheses_list: # type: NBestHypotheses - n_hyps = nbest_hyp.n_best_hypotheses # Extract all hypotheses for this sample - decoded_hyps = self.decode_hypothesis( - n_hyps, fold_consecutive - ) # type: List[Union[Hypothesis, NBestHypotheses]] + for nbest_hyp in hypotheses_list: # type: NBestHypotheses + n_hyps = nbest_hyp.n_best_hypotheses # Extract all hypotheses for this sample + decoded_hyps = self.decode_hypothesis( + n_hyps, fold_consecutive + ) # type: List[Union[Hypothesis, NBestHypotheses]] - # If computing timestamps - if self.compute_timestamps is True: - timestamp_type = self.cfg.get('ctc_timestamp_type', 'all') - for hyp_idx in range(len(decoded_hyps)): - decoded_hyps[hyp_idx] = self.compute_ctc_timestamps(decoded_hyps[hyp_idx], timestamp_type) + # If computing timestamps + if self.compute_timestamps is True: + timestamp_type = self.cfg.get('ctc_timestamp_type', 'all') + for hyp_idx in range(len(decoded_hyps)): + decoded_hyps[hyp_idx] = self.compute_ctc_timestamps(decoded_hyps[hyp_idx], timestamp_type) - hypotheses.append(decoded_hyps[0]) # best hypothesis - all_hypotheses.append(decoded_hyps) + hypotheses.append(decoded_hyps[0]) # best hypothesis + all_hypotheses.append(decoded_hyps) if return_hypotheses: return hypotheses, all_hypotheses best_hyp_text = [h.text for h in hypotheses] + # alaptev: The line below might contain a bug. Do we really want all_hyp_text to be flat? all_hyp_text = [h.text for hh in all_hypotheses for h in hh] return best_hyp_text, all_hyp_text else: - hypotheses = self.decode_hypothesis( - hypotheses_list, fold_consecutive - ) # type: List[Union[Hypothesis, NBestHypotheses]] + if self.cfg.strategy == 'wfst': + hypotheses = hypotheses_list + else: + hypotheses = self.decode_hypothesis( + hypotheses_list, fold_consecutive + ) # type: List[Union[Hypothesis, NBestHypotheses]] - # If computing timestamps - if self.compute_timestamps is True: - # greedy decoding, can get high-level confidence scores - if return_hypotheses and (self.preserve_word_confidence or self.preserve_token_confidence): - hypotheses = self.compute_confidence(hypotheses) - else: - # remove unused token_repetitions from Hypothesis.text - for hyp in hypotheses: - hyp.text = hyp.text[:2] - timestamp_type = self.cfg.get('ctc_timestamp_type', 'all') - for hyp_idx in range(len(hypotheses)): - hypotheses[hyp_idx] = self.compute_ctc_timestamps(hypotheses[hyp_idx], timestamp_type) + # If computing timestamps + if self.compute_timestamps is True: + # greedy decoding, can get high-level confidence scores + if return_hypotheses and (self.preserve_word_confidence or self.preserve_token_confidence): + hypotheses = self.compute_confidence(hypotheses) + else: + # remove unused token_repetitions from Hypothesis.text + for hyp in hypotheses: + hyp.text = hyp.text[:2] + timestamp_type = self.cfg.get('ctc_timestamp_type', 'all') + for hyp_idx in range(len(hypotheses)): + hypotheses[hyp_idx] = self.compute_ctc_timestamps(hypotheses[hyp_idx], timestamp_type) if return_hypotheses: return hypotheses, None @@ -1324,6 +1354,11 @@ class CTCDecodingConfig: default_factory=lambda: ctc_beam_decoding.BeamCTCInferConfig(beam_size=4) ) + # wfst decoding config + wfst: ctc_beam_decoding.WfstCTCInferConfig = field( + default_factory=lambda: ctc_beam_decoding.WfstCTCInferConfig(beam_size=4) + ) + # confidence config confidence_cfg: ConfidenceConfig = field(default_factory=lambda: ConfidenceConfig()) diff --git a/nemo/collections/asr/parts/submodules/wfst_decoder.py b/nemo/collections/asr/parts/submodules/wfst_decoder.py new file mode 100644 index 000000000000..373e041da1be --- /dev/null +++ b/nemo/collections/asr/parts/submodules/wfst_decoder.py @@ -0,0 +1,791 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc +import tempfile +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union + +import torch +from jiwer import wer as word_error_rate +from omegaconf import DictConfig + +from nemo.collections.asr.parts.utils.wfst_utils import TW_BREAK, kaldifst_importer + +RIVA_DECODER_INSTALLATION_MESSAGE = ( + "riva decoder is not installed or is installed incorrectly.\n" + "please run `bash scripts/installers/install_riva_decoder.sh` or `pip install riva-asrlib-decoder` to install." +) + + +def riva_decoder_importer(): + """Import helper function that returns Riva asrlib decoder package or raises ImportError exception.""" + try: + import riva.asrlib.decoder.python_decoder as riva_decoder + except (ImportError, ModuleNotFoundError): + raise ImportError(RIVA_DECODER_INSTALLATION_MESSAGE) + return riva_decoder + + +def _riva_config_to_dict(conf: Any) -> Dict[str, Any]: + """ + Helper function for parsing Riva configs (namely BatchedMappedDecoderCudaConfig) into a dictionary. + + Args: + conf: + Inner Riva config. + + Returns: + Dictionary corresponding to the Riva config. + """ + result = {} + for name in conf.__dir__(): + if not name.startswith("__"): + attribute = getattr(conf, name) + result[name] = ( + attribute if attribute.__class__.__module__ == 'builtins' else _riva_config_to_dict(attribute) + ) + return result + + +def _fill_inner_riva_config_(riva_conf, nemo_conf): + """ + Helper function for filling Riva configs (namely BatchedMappedDecoderCudaConfig) + according to the corresponding NeMo config. + + Note: in-place for the first argument. + + Args: + riva_conf: + Inner Riva config. + + nemo_conf: + Corresponding NeMo config. + """ + for nemo_k, nemo_v in nemo_conf.items(): + if isinstance(nemo_v, DictConfig): + _fill_inner_riva_config_(getattr(riva_conf, nemo_k), nemo_v) + else: + setattr(riva_conf, nemo_k, nemo_v) + + +class RivaDecoderConfig(DictConfig): + """ + NeMo config for the RivaGpuWfstDecoder. + """ + + def __init__(self): + try: + riva_decoder = riva_decoder_importer() + + config = riva_decoder.BatchedMappedDecoderCudaConfig() + config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0 + config.n_input_per_chunk = 50 + config.online_opts.decoder_opts.default_beam = 20.0 + config.online_opts.decoder_opts.max_active = 10000 + config.online_opts.determinize_lattice = True + config.online_opts.max_batch_size = 800 + config.online_opts.num_channels = 800 + config.online_opts.frame_shift_seconds = 1 # not actual frame shift + config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0 + + content = _riva_config_to_dict(config) + except ImportError: + content = {} + super().__init__(content) + + +class WfstNbestUnit(NamedTuple): + """ + Container for a single RivaGpuWfstDecoder n-best hypothesis. + """ + + words: Tuple[str] + timesteps: Tuple[int] + alignment: Tuple[int] + score: float + + +class WfstNbestHypothesis: + """ + Container for the RivaGpuWfstDecoder n-best results represented as a list of WfstNbestUnit objects. + """ + + def __init__(self, raw_hypotheses: Tuple[Tuple[Tuple[str], Tuple[int], Tuple[int], float]]): + for i, rh in enumerate(raw_hypotheses): + assert isinstance(rh[0], tuple), f"{rh[0]}" + assert isinstance(rh[1], tuple), f"{rh[1]}, {rh[0]}" + assert isinstance(rh[2], tuple), f"{rh[2]}" + assert isinstance(rh[3], float), f"{rh[3]}" + assert len(rh[0]) == len(rh[1]) or len(rh[1]) == 0, "words do not match timesteps" + + self._hypotheses = sorted([WfstNbestUnit(*rh) for rh in raw_hypotheses], key=lambda hyp: hyp.score) + self._shape0 = len(self._hypotheses) + self._shape1 = [len(h.words) for h in self._hypotheses] + self._has_timesteps = len(self._hypotheses[0].timesteps) > 0 + self._has_alignment = len(self._hypotheses[0].alignment) > 0 + + def __iter__(self): + yield from self._hypotheses + + def __getitem__(self, index): + return self._hypotheses[index] + + def __len__(self): + return self.shape0 + + def replace_unit_( + self, index: int, new_unit: Union[WfstNbestUnit, Tuple[Tuple[str], Tuple[int], Tuple[int], float]] + ): + """ + Replaces a WfstNbestUnit by index. + + Note: in-place operation. + + Args: + index: + Index of the unit to be replaced. + + new_unit: + Replacement unit. + """ + assert 0 <= index < self.shape0 + assert ( + self.has_timesteps + and len(new_unit[0]) == len(new_unit[1]) + or not self.has_timesteps + and len(new_unit[1]) == 0 + ) + assert ( + index == 0 + and (len(self._hypotheses) == 1 or new_unit[3] <= self._hypotheses[index + 1].score) + or index == self.shape0 - 1 + and self._hypotheses[index - 1].score <= new_unit[3] + or self._hypotheses[index - 1].score <= new_unit[3] <= self._hypotheses[index + 1].score + ) + + if not isinstance(new_unit, WfstNbestUnit): + new_unit = WfstNbestUnit(*new_unit) + self._hypotheses[index] = new_unit + self._shape1[index] = len(new_unit.words) + + @property + def shape0(self): + return self._shape0 + + @property + def shape1(self): + return self._shape1 + + @property + def has_timesteps(self): + return self._has_timesteps + + @property + def has_alignment(self): + return self._has_alignment + + +def collapse_tokenword_hypotheses( + hypotheses: List[WfstNbestHypothesis], tokenword_disambig_str: str +) -> List[WfstNbestHypothesis]: + """ + Searches for tokenwords in the input hypotheses and collapses them into words. + + Args: + hypotheses: + List of input WfstNbestHypothesis. + + tokenword_disambig_str: + Tokenword disambiguation symbol (e.g. `#1`). + + Returns: + List of WfstNbestHypothesis. + """ + new_hypotheses = copy.deepcopy(hypotheses) + for hyp in new_hypotheses: + for k, h_unit in enumerate(hyp): + twds_list = [] + for i, word in enumerate(h_unit.words): + if word == tokenword_disambig_str: + twds_list.append(i) + if len(twds_list) > 0: + # a rare case when the recognition stopped before completing the tokenword + old_words = list(h_unit.words) + old_timesteps = list(h_unit.timesteps) + words_len = len(old_words) + if len(twds_list) % 2 == 1: + twds_list.append(words_len) + new_words, new_timesteps = [], [] + j_prev = 0 + for i, j in zip(twds_list[::2], twds_list[1::2]): + new_words += old_words[j_prev:i] + # drop tokenword disambig -> remove token disanbig suffix -> remove word begin mark + new_word = "".join(old_words[i + 1 : j]).replace(f"{TW_BREAK}{tokenword_disambig_str}", "")[1:] + new_words.append(new_word) + new_timesteps += old_timesteps[j_prev:i] + [ + old_timesteps[i], + ] + j_prev = j + 1 + if j_prev < words_len: + new_words += old_words[j_prev:words_len] + new_timesteps += old_timesteps[j_prev:words_len] + hyp.replace_unit_(k, (tuple(new_words), tuple(new_timesteps), h_unit.alignment, h_unit.score)) + return new_hypotheses + + +class AbstractWFSTDecoder(ABC): + """ + Used for performing WFST decoding of the logprobs. + + Args: + lm_fst: + Language model WFST. + + decoding_mode: + Decoding mode. E.g. `nbest`. + + beam_size: + Beam width (float) for the WFST decoding. + + config: + Decoder config. + + tokenword_disambig_id: + Tokenword disambiguation index. Set to -1 to disable the tokenword mode. + + lm_weight: + Language model weight in decoding. + """ + + def __init__( + self, + lm_fst: Any, + decoding_mode: str, + beam_size: float, + config: Optional[Any], + tokenword_disambig_id: int = -1, + lm_weight: float = 1.0, + ): + self._lm_fst = lm_fst + self._beam_size = beam_size + self._tokenword_disambig_id = tokenword_disambig_id + self._open_vocabulary_decoding = self._tokenword_disambig_id >= 0 + self._lm_weight = lm_weight + self._id2word, self._word2id = None, None + self._id2token, self._token2id = None, None + self._decoding_mode, self._config, self._decoder = None, None, None + + self._set_decoding_mode(decoding_mode) + self._set_decoder_config(config) + self._init_decoder() + + @abstractmethod + def _set_decoder_config(self, config: Optional[Any] = None): + pass + + @abstractmethod + def _set_decoding_mode(self, decoding_mode: str): + pass + + @abstractmethod + def _init_decoder(self): + pass + + @property + def decoding_mode(self): + return self._decoding_mode + + @decoding_mode.setter + def decoding_mode(self, value: str): + self._decoding_mode_setter(value) + + @abstractmethod + def _decoding_mode_setter(self, value: str): + pass + + @property + def beam_size(self): + return self._beam_size + + @beam_size.setter + def beam_size(self, value: float): + self._beam_size_setter(value) + + @abstractmethod + def _beam_size_setter(self, value: float): + pass + + @property + def lm_weight(self): + return self._lm_weight + + @lm_weight.setter + def lm_weight(self, value: float): + self._lm_weight_setter(value) + + @abstractmethod + def _lm_weight_setter(self, value: float): + pass + + @property + def tokenword_disambig_id(self): + return self._tokenword_disambig_id + + @property + def open_vocabulary_decoding(self): + return self._open_vocabulary_decoding + + @abstractmethod + def decode(self, log_probs: torch.Tensor, log_probs_length: torch.Tensor) -> List[Any]: + """ + Decodes logprobs into recognition hypotheses. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + Returns: + List of recognition hypotheses. + """ + pass + + @abstractmethod + def _post_decode(self, hypotheses: List[Any]) -> List[Any]: + """ + Does various post-processing of the recognition hypotheses. + + Args: + hypotheses: + List of recognition hypotheses. + + Returns: + List of processed recognition hypotheses. + """ + pass + + @abstractmethod + def calibrate_lm_weight( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor, reference_texts: List[str] + ) -> Tuple[float, float]: + """ + Calibrates LM weight to achieve the best WER for given logprob-text pairs. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + reference_texts: + List of reference word sequences. + + Returns: + Pair of (best_lm_weight, best_wer). + """ + pass + + @abstractmethod + def calculate_oracle_wer( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor, reference_texts: List[str] + ) -> Tuple[float, List[float]]: + """ + Calculates the oracle (the best possible WER for given logprob-text pairs. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + reference_texts: + List of reference word sequences. + + Returns: + Pair of (oracle_wer, oracle_wer_per_utterance). + """ + pass + + +class RivaGpuWfstDecoder(AbstractWFSTDecoder): + """ + Used for performing WFST decoding of the logprobs with the Riva WFST decoder. + + Args: + lm_fst: + Kaldi-type language model WFST or its path. + + decoding_mode: + Decoding mode. Choices: `nbest`, `mbr`, `lattice`. + + beam_size: + Beam width (float) for the WFST decoding. + + config: + Riva Decoder config. + + tokenword_disambig_id: + Tokenword disambiguation index. Set to -1 to disable the tokenword mode. + + lm_weight: + Language model weight in decoding. + + nbest_size: + N-best size for decoding_mode == `nbest` + """ + + def __init__( + self, + lm_fst: Union['kaldifst.StdFst', Path, str], + decoding_mode: str = 'mbr', + beam_size: float = 10.0, + config: Optional['RivaDecoderConfig'] = None, + tokenword_disambig_id: int = -1, + lm_weight: float = 1.0, + nbest_size: int = 1, + ): + self._nbest_size = nbest_size + self._load_word_lattice = None + super().__init__(lm_fst, decoding_mode, beam_size, config, tokenword_disambig_id, lm_weight) + + def _set_decoder_config(self, config: Optional['RivaDecoderConfig'] = None): + if config is None or len(config) == 0: + config = RivaDecoderConfig() + if not hasattr(config, "online_opts"): + # most likely empty config + # call importer to raise the exception + installation message + riva_decoder_importer() + # just in case + raise RuntimeError("Unexpected config error. Please debug manually.") + config.online_opts.decoder_opts.lattice_beam = self._beam_size + config.online_opts.lattice_postprocessor_opts.lm_scale = ( + self._lm_weight * config.online_opts.lattice_postprocessor_opts.acoustic_scale + ) + config.online_opts.lattice_postprocessor_opts.nbest = self._nbest_size + self._config = config + + def _init_decoder(self): + + # use importers instead of direct import to possibly get an installation message + kaldifst = kaldifst_importer() + riva_decoder = riva_decoder_importer() + + from nemo.collections.asr.parts.utils.wfst_utils import load_word_lattice + + self._load_word_lattice = load_word_lattice + # BatchedMappedDecoderCuda supports filepaths only + # TODO: fix when possible + lm_fst = self._lm_fst + tmp_fst = None + tmp_fst_file = None + if isinstance(lm_fst, (Path, str)): + # We only read lm_fst to extract words.txt and num_tokens_with_blank + tmp_fst = kaldifst.StdVectorFst.read(lm_fst) + elif isinstance(lm_fst, (kaldifst.StdVectorFst, kaldifst.StdConstFst)): + tmp_fst = lm_fst + tmp_fst_file = tempfile.NamedTemporaryFile(mode='w+t') + tmp_fst.write(tmp_fst_file.name) + lm_fst = tmp_fst_file.name + else: + raise ValueError(f"Unsupported lm_fst type: {type(lm_fst)}") + + # we assume that lm_fst has at least one disambig after real tokens + num_tokens_with_blank = tmp_fst.input_symbols.find('#0') - 1 + if self._id2word is None: + self._id2word = { + int(line.split("\t")[1]): line.split("\t")[0] + for line in str(tmp_fst.output_symbols).strip().split("\n") + } + word2id = self._id2word.__class__(map(reversed, self._id2word.items())) + word_unk_id = word2id[""] + self._word2id = defaultdict(lambda: word_unk_id) + for k, v in word2id.items(): + self._word2id[k] = v + if self._id2token is None: + self._id2token = { + int(line.split("\t")[1]): line.split("\t")[0] + for line in str(tmp_fst.input_symbols).strip().split("\n") + } + token2id = self._id2token.__class__(map(reversed, self._id2token.items())) + token_unk_id = token2id[""] + self._token2id = defaultdict(lambda: token_unk_id) + for k, v in token2id.items(): + self._token2id[k] = v + with tempfile.NamedTemporaryFile(mode='w+t') as words_tmp: + tmp_fst.output_symbols.write_text(words_tmp.name) + config = riva_decoder.BatchedMappedDecoderCudaConfig() + _fill_inner_riva_config_(config, self._config) + self._decoder = riva_decoder.BatchedMappedDecoderCuda( + config, lm_fst, words_tmp.name, num_tokens_with_blank + ) + if tmp_fst_file: + tmp_fst_file.close() + + def _set_decoding_mode(self, decoding_mode: str): + if decoding_mode == 'nbest': + self._decode = self._decode_nbest + elif decoding_mode == 'mbr': + self._decode = self._decode_mbr + elif decoding_mode == 'lattice': + self._decode = self._decode_lattice + else: + raise ValueError(f"Unsupported mode: {decoding_mode}") + self._decoding_mode = decoding_mode + + def _beam_size_setter(self, value: float): + if self._beam_size != value: + self._release_gpu_memory() + self._config.online_opts.decoder_opts.lattice_beam = value + self._init_decoder() + self._beam_size = value + + def _lm_weight_setter(self, value: float): + if self._lm_weight != value: + self._release_gpu_memory() + self._config.online_opts.lattice_postprocessor_opts.lm_scale = ( + value * self._config.online_opts.lattice_postprocessor_opts.acoustic_scale + ) + self._init_decoder() + self._lm_weight = value + + def _decoding_mode_setter(self, value: str): + if self._decoding_mode != value: + self._set_decoding_mode(value) + + @property + def nbest_size(self): + return self._nbest_size + + @nbest_size.setter + def nbest_size(self, value: float): + self._nbest_size_setter(value) + + def _nbest_size_setter(self, value: float): + if self._nbest_size != value: + self._release_gpu_memory() + self._config.online_opts.lattice_postprocessor_opts.nbest = value + self._init_decoder() + self._nbest_size = value + + def _decode_nbest( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor + ) -> List[WfstNbestHypothesis]: # words, timesteps, alignment, score + """ + Decodes logprobs into recognition hypotheses via the N-best decoding decoding. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + Returns: + List of WfstNbestHypothesis with empty alignment and trivial score. + """ + hypotheses_nbest = self._decoder.decode_nbest(log_probs, log_probs_length) + hypotheses = [] + for nh in hypotheses_nbest: + nbest_container = [] + for h in nh: + words, timesteps = [], [] + for w, t in zip(h.words, h.word_start_times_seconds): + if w != 0: + words.append(self._id2word[w]) + timesteps.append(int(t)) + alignment = [ilabel - 1 for ilabel in h.ilabels] + score = h.score + nbest_container.append(tuple([tuple(words), tuple(timesteps), tuple(alignment), score])) + hypotheses.append(WfstNbestHypothesis(tuple(nbest_container))) + return hypotheses + + def _decode_mbr(self, log_probs: torch.Tensor, log_probs_length: torch.Tensor) -> List[WfstNbestHypothesis]: + """ + Decodes logprobs into recognition hypotheses via the Minimum Bayes Risk (MBR) decoding. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + Returns: + List of WfstNbestHypothesis with empty alignment and trivial score. + """ + hypotheses_mbr = self._decoder.decode_mbr(log_probs, log_probs_length) + hypotheses = [] + for h in hypotheses_mbr: + words, timesteps = [], [] + for e in h: + words.append(e[0]) + timesteps.append(int(e[1])) + hypotheses.append(WfstNbestHypothesis(tuple([tuple([tuple(words), tuple(timesteps), tuple(), 0.0])]))) + return hypotheses + + def _decode_lattice(self, log_probs: torch.Tensor, log_probs_length: torch.Tensor) -> List['KaldiWordLattice']: + """ + Decodes logprobs into kaldi-type lattices. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + Returns: + List of KaldiWordLattice. + """ + with tempfile.NamedTemporaryFile() as tmp_lat: + tmp_lat_name = f"{tmp_lat.name}.lats" + self._decoder.decode_write_lattice( + log_probs, log_probs_length, [str(i) for i in range(len(log_probs))], f"ark,t:{tmp_lat_name}" + ) + hypotheses_lattice = self._load_word_lattice( + tmp_lat_name, self._id2word, self._id2word + ) # input and output token ids are the same + hypotheses = [hypotheses_lattice[str(i)] for i in range(len(log_probs))] + return hypotheses + + def decode( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor + ) -> Union[List[WfstNbestHypothesis], List['KaldiWordLattice']]: + """ + Decodes logprobs into recognition hypotheses. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + Returns: + List of recognition hypotheses. + """ + log_probs = log_probs.contiguous() + log_probs_length = log_probs_length.to(torch.long).to('cpu').contiguous() + hypotheses = self._decode(log_probs, log_probs_length) + hypotheses = self._post_decode(hypotheses) + return hypotheses + + def _post_decode( + self, hypotheses: Union[List[WfstNbestHypothesis], List['KaldiWordLattice']] + ) -> Union[List[WfstNbestHypothesis], List['KaldiWordLattice']]: + """ + Does various post-processing of the recognition hypotheses. + + Args: + hypotheses: + List of recognition hypotheses. + + Returns: + List of processed recognition hypotheses. + """ + if self._open_vocabulary_decoding and self._decoding_mode in ('nbest', 'mbr'): + return collapse_tokenword_hypotheses(hypotheses, self._id2word[self._tokenword_disambig_id]) + else: + return hypotheses + + def calibrate_lm_weight( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor, reference_texts: List[str] + ) -> Tuple[float, float]: + """ + Calibrates LM weight to achieve the best WER for given logprob-text pairs. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + reference_texts: + List of reference word sequences. + + Returns: + Pair of (best_lm_weight, best_wer). + """ + assert len(log_probs) == len(reference_texts) + decoding_mode_backup = self.decoding_mode + lm_weight_backup = self.lm_weight + self.decoding_mode = "mbr" + best_lm_weight, best_wer = -1.0, float('inf') + for lm_weight in range(1, 21): # enough for most cases + self.lm_weight = lm_weight / 10 + hypotheses = self.decode(log_probs, log_probs_length) + wer = word_error_rate([" ".join(h[0].words) for h in hypotheses], reference_texts) + print(lm_weight, wer) + if wer < best_wer: + best_lm_weight, best_wer = self.lm_weight, wer + self.decoding_mode = decoding_mode_backup + self.lm_weight = lm_weight_backup + return best_lm_weight, best_wer + + def calculate_oracle_wer( + self, log_probs: torch.Tensor, log_probs_length: torch.Tensor, reference_texts: List[str] + ) -> Tuple[float, List[float]]: + """ + Calculates the oracle (the best possible WER for given logprob-text pairs. + + Args: + log_probs: + A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary]. + + log_probs_length: + A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements. + + reference_texts: + List of reference word sequences. + + Returns: + Pair of (oracle_wer, oracle_wer_per_utterance). + """ + if self._open_vocabulary_decoding: + raise NotImplementedError + assert len(log_probs) == len(reference_texts) + decoding_mode_backup = self.decoding_mode + self.decoding_mode = "lattice" + lattices = self.decode(log_probs, log_probs_length) + scores, counts, wer_per_utt = [], [], [] + for lattice, text in zip(lattices, reference_texts): + word_ids = [self._word2id[w] for w in text.strip().split()] + counts.append(len(word_ids) if word_ids else 1) + scores.append(lattice.edit_distance(word_ids)) + wer_per_utt.append(scores[-1] / counts[-1]) + self.decoding_mode = decoding_mode_backup + return sum(scores) / sum(counts), wer_per_utt + + def _release_gpu_memory(self): + """ + Forces freeing of GPU memory by deleting the Riva decoder object. + """ + try: + del self._decoder + except Exception: + # apparently self._decoder was previously deleted, do nothing + pass + gc.collect() + + def __del__(self): + self._release_gpu_memory() diff --git a/nemo/collections/asr/parts/utils/wfst_utils.py b/nemo/collections/asr/parts/utils/wfst_utils.py new file mode 100644 index 000000000000..31f394fb60ac --- /dev/null +++ b/nemo/collections/asr/parts/utils/wfst_utils.py @@ -0,0 +1,1478 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import tempfile +from abc import ABC, abstractmethod, abstractproperty +from collections import defaultdict, namedtuple +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union + +from nemo.utils import logging + + +TW_BREAK = "‡" + + +try: + import kaldifst + + # check that kaldifst package is not empty + # Note: pytorch_lightning.utilities.imports.package_available may not help here + kaldifst.StdVectorFst() + _KALDIFST_AVAILABLE = True +except (ImportError, ModuleNotFoundError, AttributeError): + _KALDIFST_AVAILABLE = False + + +try: + import graphviz + + _GRAPHVIZ_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + _GRAPHVIZ_AVAILABLE = False + + +try: + import kaldilm + + _KALDILM_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + _KALDILM_AVAILABLE = False + + +KALDIFST_INSTALLATION_MESSAGE = ( + "kaldifst is not installed or is installed incorrectly.\n" + "please run `pip install kaldifst` or `bash scripts/installers/install_riva_decoder.sh` to install." +) + + +GRAPHVIZ_INSTALLATION_MESSAGE = ( + "graphviz is not installed.\n" "please run `bash scripts/installers/install_graphviz.sh` to install." +) + + +KALDILM_INSTALLATION_MESSAGE = ( + "kaldilm is not installed.\n" + "please run `pip install kaldilm` or `bash scripts/installers/install_riva_decoder.sh` to install." +) + + +def _kaldifst_maybe_raise(): + if _KALDIFST_AVAILABLE is False: + raise ImportError(KALDIFST_INSTALLATION_MESSAGE) + + +def kaldifst_importer(): + """Import helper function that returns kaldifst package or raises ImportError exception.""" + _kaldifst_maybe_raise() + return kaldifst + + +def _graphviz_maybe_raise(): + if _GRAPHVIZ_AVAILABLE is False: + raise ImportError(GRAPHVIZ_INSTALLATION_MESSAGE) + + +def graphviz_importer(): + """Import helper function that returns graphviz package or raises ImportError exception.""" + _graphviz_maybe_raise() + return graphviz + + +def _kaldilm_maybe_raise(): + if _KALDILM_AVAILABLE is False: + raise ImportError(KALDILM_INSTALLATION_MESSAGE) + + +def kaldilm_importer(): + """Import helper function that returns kaldifst package or raises ImportError exception.""" + _kaldilm_maybe_raise() + return kaldilm + + +@dataclass +class LexiconUnit: + """A dataclass encapsulating the name of the language unit (e.g. wordpiece) and its mark (e.g. word begin).""" + + name: str + mark: str = "" + + +class Lexicon: + def __init__( + self, + wordid2tokenid: Dict[int, List[List[int]]], + id2word: Union[Dict[int, str], Dict[int, LexiconUnit]], + id2token: Union[Dict[int, str], Dict[int, LexiconUnit]], + disambig_pattern: str = re.compile(r"^#\d+$"), + ): + """ + Lexicon class which contains word-to-token-sequence, word-to-id, and token-to-id mappings. + + Args: + wordid2tokenid: + Lexicon. + Mapping from word_id to token1_id token2_id ... tokenN_id. + + id2word: + Word index. + Mapping from word_id to word_str. + + id2token: + Token index. + Mapping from token_id to token_str. + + disambig_pattern: + Pattern for disambiguation symbols. + """ + is_id2token_str = not isinstance(list(id2token.values())[0], LexiconUnit) + self.id2token = {k: LexiconUnit(v) for k, v in id2token.items()} if is_id2token_str else id2token + self.token2id = {v.name: k for k, v in self.id2token.items()} + is_id2word_str = not isinstance(list(id2word.values())[0], LexiconUnit) + self.id2word = {k: LexiconUnit(v) for k, v in id2word.items()} if is_id2word_str else id2word + self.word2id = {v.name: k for k, v in self.id2word.items()} + self.wordid2tokenid = wordid2tokenid + word2tokens = defaultdict(list) + for k, v in self.wordid2tokenid.items(): + word2tokens[self.id2word[k].name] += [[self.id2token[i].name for i in vp] for vp in v] + self.word2tokens = word2tokens + self.disambig_pattern = disambig_pattern + + max_disambig_id = -1 + num_disambigs = 0 + self.has_epsilon = False + self._default_disambig_mark = "disambig" + self._default_epsilon_mark = "epsilon" + self._default_epsilon_name = "" + for i, s in self.id2token.items(): + if self.disambig_pattern.match(s.name): + if is_id2token_str or not s.mark.startswith(self._default_disambig_mark): + s.mark = self._default_disambig_mark + if i > max_disambig_id: + max_disambig_id = i + num_disambigs += 1 + if s.name == self._default_epsilon_name or s.mark == self._default_epsilon_mark: + assert i == 0 + self.has_epsilon = True + self.max_disambig_id = max_disambig_id + self.num_disambigs = num_disambigs + + if is_id2word_str: + for i, s in self.id2word.items(): + if self.disambig_pattern.match(s.name): + s.mark = self._default_disambig_mark + elif s.name == self._default_epsilon_name: + s.mark == self._default_epsilon_mark + + def __iter__(self) -> Tuple[str, List[str]]: + for wordid, tokenid_list in self.wordid2tokenid.items(): + for tokenids in tokenid_list: + yield wordid, tokenids + + def __str__(self): + return str(self.word2tokens) + + @property + def token_ids(self) -> List[int]: + """Return a list of token IDs excluding those from + disambiguation symbols. + """ + ans = [] + for i, s in self.id2token.items(): + if not s.mark.startswith(self._default_epsilon_mark) and (not self.has_epsilon or i != 0): + ans.append(i) + ans.sort() + return ans + + +def arpa2fst(lm_path: str, attach_symbol_table: bool = True) -> 'kaldifst.StdVectorFst': + """ + Compiles an ARPA LM file into a grammar WFST (G.fst). + + Args: + lm_path: + Path to the ARPA LM file. + + attach_symbol_table: + Whether to attach the words for indices of the returned WFST. + + Returns: + Kaldi-type grammar WFST. + """ + _kaldifst_maybe_raise() + _kaldilm_maybe_raise() + + with tempfile.TemporaryDirectory() as tempdirname: + output_fst = os.path.join(tempdirname, "output.fst") + words_txt = os.path.join(tempdirname, "words.txt") + # with suppress_stdout_stderr(): + kaldilm.arpa2fst( + input_arpa=lm_path, + output_fst=output_fst, + disambig_symbol="#0", + write_symbol_table=words_txt, + ) + + G = kaldifst.StdVectorFst.read(output_fst) + + if attach_symbol_table: + osym = kaldifst.SymbolTable() + with open(words_txt, encoding="utf-8") as f: + for line in f: + w, i = line.strip().split() + osym.add_symbol(symbol=w, key=int(i)) + G.output_symbols = osym + + kaldifst.arcsort(G, sort_type="ilabel") + return G + + +def add_tokenwords_( + g_fst: 'kaldifst.StdVectorFst', + tokens: List[str], + word_weight: float = 2.0, + token_unigram_weight: float = 4.0, + token_oov: str = "", +) -> int: + """ + Adds special words representing individual tokens (tokenwords). + In-place operation. + + Args: + g_fst: + Kaldi-type grammar WFST. + Will be augmented with the tokenwords. + + tokens: + Token vocabulary. + + word_weight: + The weight of an Out Of Vocabulary (OOV) word emission. + + token_unigram_weight: + The weight of a tokenword emission. + + token_oov: + OOV token. + + Returns: + The id of the tokenword disambiguation token. + """ + _kaldifst_maybe_raise() + + unigram_state = 0 + # check if 0 is the unigram state (has no outgoing epsilon arcs) + assert kaldifst.ArcIterator(g_fst, unigram_state).value.ilabel not in (0, g_fst.output_symbols.find("#0")) + + # we put tokenword self-loops in a separate state wrapped with a tokenword_disambig token + tokenword_disambig_id = g_fst.output_symbols.available_key() + tokenword_disambig = "#1" + g_fst.output_symbols.add_symbol(tokenword_disambig, tokenword_disambig_id) + tokenword_state = g_fst.add_state() + # we keep olabel !=0 to mark tokenword segments in the recognition results + g_fst.add_arc( + state=unigram_state, + arc=kaldifst.StdArc( + ilabel=tokenword_disambig_id, + olabel=tokenword_disambig_id, + weight=word_weight, + nextstate=tokenword_state, + ), + ) + g_fst.add_arc( + state=tokenword_state, + arc=kaldifst.StdArc( + ilabel=tokenword_disambig_id, + olabel=tokenword_disambig_id, + weight=0.0, + nextstate=unigram_state, + ), + ) + label = tokenword_disambig_id + 1 + for t in tokens: + if t != token_oov: + g_fst.add_arc( + state=tokenword_state, + arc=kaldifst.StdArc( + ilabel=label, + olabel=label, + weight=token_unigram_weight, + nextstate=tokenword_state, + ), + ) + g_fst.output_symbols.add_symbol(f"{t}{TW_BREAK}{tokenword_disambig}", label) + label += 1 + + return tokenword_disambig_id + + +def generate_lexicon_sentencepiece( + tokenizer: 'TokenizerSpec', + id2word: Dict[int, str], + oov: str = "", + add_epsilon: bool = False, + first_tokenword_id: int = -1, + disambig_pattern: str = re.compile(r"^#\d+$"), +) -> Lexicon: + """ + Generate a Lexicon using a SentencePiece tokenizer. + + Args: + tokenizer: + NeMo SentencePiece tokenizer. + + id2word: + Word index. + Mapping from word_id to word_str. + + oov: + Out Of Vocabulary word in lexicon. + + Returns: + Lexicon object. + """ + word2id = {v: k for k, v in id2word.items()} + backoff_disambig = "#0" + tokenword_disambig = "#1" + word_begin_mark = "▁" + + tokenword_mode = first_tokenword_id != -1 + if tokenword_mode: + words, tokenwords = [], [] + for k, v in id2word.items(): + if disambig_pattern.match(v): + continue + words.append(v) if k < first_tokenword_id else tokenwords.append(v) + else: + words, tokenwords = [v for v in id2word.values() if not disambig_pattern.match(v)], [] + + # Use encode to avoid OOV tokens + words_piece_ids = tokenizer.encode(words, out_type=int) + + # tokenizer.get_vocab() gives indices starting with 1 + maybe_add_one = int(add_epsilon) + maybe_subtract_one = int(not add_epsilon) + vocab = tokenizer.get_vocab() + id2token = { + v - maybe_subtract_one: LexiconUnit(k, "begin" if k.startswith(word_begin_mark) else "") + for k, v in vocab.items() + } + + # Introduce unk, blank, and the first disambig ids + unk_id = tokenizer.piece_to_id(oov) + maybe_add_one + id2token[unk_id] = LexiconUnit(oov, "unk") + # We assume blank to have the last output id of the neural network output + max_token_id = max(id2token.keys()) + id2token[max_token_id + 1] = LexiconUnit("", "blank") + id2token[max_token_id + 2] = LexiconUnit(backoff_disambig, "disambig_backoff") + if tokenword_mode: + id2token[max_token_id + 3] = LexiconUnit(tokenword_disambig, "disambig_tokenword") + if add_epsilon: + # insert first + id2token[0] = LexiconUnit("", "epsilon") + id2token = {k: v for k, v in sorted(id2token.items(), key=lambda item: item[0])} + + if tokenword_mode: + words += tokenwords + words_piece_ids += [[vocab[tw.rstrip(f"{TW_BREAK}{tokenword_disambig}")] - maybe_add_one] for tw in tokenwords] + + wordid2tokenid = defaultdict(list) + + for word, piece_ids in zip(words, words_piece_ids): + if word.startswith("<") and word != "": # not a real word, probably some tag + continue + elif word == "": # we do not need to tokelize + continue + else: + wordid2tokenid[word2id[word]].append([p + maybe_add_one for p in piece_ids]) + + lexicon = Lexicon(wordid2tokenid, id2word, id2token) + # state disambig purpose explicitly for further use + lexicon.id2word[lexicon.word2id[backoff_disambig]].mark = "disambig_backoff" + if tokenword_mode: + lexicon.id2word[lexicon.word2id[tokenword_disambig]].mark = "disambig_tokenword" + for tw in tokenwords: + lexicon.id2word[lexicon.word2id[tw]].mark = "tokenword" + return lexicon + + +def add_disambig_symbols(lexicon: Lexicon) -> Lexicon: + """ + Adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + Lexicon object. + + Returns: + Return Lexicon augmented with subseqence disambiguation symbols. + """ + + tokenword_mode = "#1" in lexicon.word2id + if tokenword_mode: + first_tokenword_id = lexicon.word2id["#1"] + 1 + last_used_disambig_id = lexicon.token2id["#1"] + else: + last_used_disambig_id = lexicon.token2id["#0"] + + # (1) Work out the count of each token-sequence in the lexicon. + count = defaultdict(int) + for _, token_ids in lexicon: + count[tuple(token_ids)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for word_id, token_ids in lexicon: + if tokenword_mode and word_id >= first_tokenword_id: + continue + token_ids = token_ids.copy() + token_ids.pop() + while token_ids: + issubseq[tuple(token_ids)] = 1 + token_ids.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + wordid2tokenid = defaultdict(list) + id2token = lexicon.id2token.copy() + + first_allowed_disambig = lexicon.num_disambigs + first_allowed_disambig_id = last_used_disambig_id + 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_id_of = defaultdict(int) + + for word_id, token_ids in lexicon: + token_key = tuple(token_ids) + assert len(token_key) > 0 + if issubseq[token_key] == 0 and count[token_key] == 1 or tokenword_mode and word_id >= first_tokenword_id: + wordid2tokenid[word_id].append(token_ids) + continue + + cur_disambig_id = last_used_disambig_id_of[token_key] + if cur_disambig_id == 0: + cur_disambig = first_allowed_disambig + cur_disambig_id = first_allowed_disambig_id + else: + cur_disambig = int(id2token[cur_disambig_id].name.lstrip("#")) + 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + cur_disambig_id = max(id2token.keys()) + 1 + id2token[cur_disambig_id] = LexiconUnit(f"#{max_disambig}", "disambig_subsequence") + last_used_disambig_id_of[token_key] = cur_disambig_id + wordid2tokenid[word_id].append(token_ids + [cur_disambig_id]) + return Lexicon(wordid2tokenid, lexicon.id2word, id2token) + + +def make_lexicon_fst_no_silence( + lexicon: Lexicon, + attach_symbol_table: bool = True, +) -> 'kaldifst.StdVectorFst': + """ + Compiles a Lexicon into a lexicon WFST (L.fst). + + See also make_lexicon_fst.py from kaldi. + + Args: + lexicon: + Lexicon object. + + Returns: + Kaldi-type lexicon WFST. + """ + _kaldifst_maybe_raise() + + backoff_disambig = "#0" + tokenword_disambig = "#1" + tokenword_mode = tokenword_disambig in lexicon.word2id + if tokenword_mode: + first_tokenword_id = lexicon.word2id[tokenword_disambig] + 1 + + fst = kaldifst.StdVectorFst() + start_state = fst.add_state() + fst.start = start_state + fst.set_final(state=start_state, weight=0) + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=lexicon.token2id[backoff_disambig], + olabel=lexicon.word2id[backoff_disambig], + weight=0, + nextstate=start_state, + ), + ) + if tokenword_mode: + tokenword_state_begin = fst.add_state() + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=lexicon.token2id[tokenword_disambig], + olabel=lexicon.word2id[tokenword_disambig], + weight=0, + nextstate=tokenword_state_begin, + ), + ) + + for word_id, token_ids in lexicon: + cur_state = start_state + + if not tokenword_mode or word_id < first_tokenword_id - 1: + for i, token_id in enumerate(token_ids[:-1]): + next_state = fst.add_state() + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=token_id, + olabel=word_id if i == 0 else 0, + weight=0, + nextstate=next_state, + ), + ) + cur_state = next_state + i = len(token_ids) - 1 # note: i == -1 if tokens is empty. + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=token_ids[-1] if i >= 0 else 0, + olabel=word_id if i <= 0 else 0, + weight=0, + nextstate=start_state, + ), + ) + if tokenword_mode: + tokenword_begin, tokenword_other = [], [] + for word_id in range(first_tokenword_id, max(lexicon.id2word) + 1): + token_id = lexicon.token2id[lexicon.id2word[word_id].name.rstrip(f"{TW_BREAK}{tokenword_disambig}")] + token_unit = lexicon.id2token[token_id] + if token_unit.mark.startswith("begin"): + tokenword_begin.append((token_id, word_id)) + elif token_unit.mark == "": + tokenword_other.append((token_id, word_id)) + else: + raise RuntimeError(f"Unexpected mark `{token_unit.mark}` for tokenword `{token_unit.name}`") + + tokenword_state_main = fst.add_state() + for token_id, word_id in tokenword_begin: + fst.add_arc( + state=tokenword_state_begin, + arc=kaldifst.StdArc( + ilabel=token_id, + olabel=word_id, + weight=0, + nextstate=tokenword_state_main, + ), + ) + tokenword_state_end = fst.add_state() + for token_id, word_id in tokenword_other: + fst.add_arc( + state=tokenword_state_main, + arc=kaldifst.StdArc( + ilabel=token_id, + olabel=word_id, + weight=0, + nextstate=tokenword_state_main, + ), + ) + fst.add_arc( + state=tokenword_state_main, + arc=kaldifst.StdArc( + ilabel=token_id, + olabel=word_id, + weight=0, + nextstate=tokenword_state_end, + ), + ) + fst.add_arc( + state=tokenword_state_end, + arc=kaldifst.StdArc( + ilabel=lexicon.token2id[tokenword_disambig], + olabel=lexicon.word2id[tokenword_disambig], + weight=0, + nextstate=start_state, + ), + ) + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for p, i in lexicon.token2id.items(): + isym.add_symbol(symbol=p, key=i) + fst.input_symbols = isym + + osym = kaldifst.SymbolTable() + for w, i in lexicon.word2id.items(): + osym.add_symbol(symbol=w, key=i) + fst.output_symbols = osym + + kaldifst.arcsort(fst, sort_type="ilabel") + return fst + + +def build_topo( + name: str, token2id: Dict[str, int], with_self_loops: bool = True, attach_symbol_table: bool = True +) -> 'kaldifst.StdVectorFst': + """Helper function to build a topology WFST (T.fst). + + Args: + name: + Topology name. Choices: default, compact, minimal + + token2id: + Token index. + Mapping from token_str to token_id. + + with_self_loops: + Whether to add token-to-epsilon self-loops to the topology. + + attach_symbol_table: + Whether to attach the token names for indices of the returned WFST. + + Returns: + Kaldi-type topology WFST. + """ + _kaldifst_maybe_raise() + + if name == "default": + fst = build_default_topo(token2id, with_self_loops) + elif name == "compact": + fst = build_compact_topo(token2id, with_self_loops) + elif name == "minimal": + fst = build_minimal_topo(token2id) + else: + raise ValueError(f"Unknown topo name: {name}") + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for t, i in token2id.items(): + isym.add_symbol(symbol=t, key=i) + fst.input_symbols = isym + fst.output_symbols = fst.input_symbols.copy() + return fst + + +def build_default_topo(token2id: Dict[str, int], with_self_loops: bool = True) -> 'kaldifst.StdVectorFst': + """Build the default (correct) CTC topology.""" + _kaldifst_maybe_raise() + + disambig_pattern = re.compile(r"^#\d+$") + blank_id = token2id[""] + fst = kaldifst.StdVectorFst() + start_state = fst.add_state() + fst.start = start_state + fst.set_final(state=start_state, weight=0) + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=blank_id, + olabel=0, + weight=0, + nextstate=start_state, # token2id[""] is always 0 + ), + ) + + disambig_ids = [] + token_ids = {} + for s, i in token2id.items(): + if s == "" or s == "": + continue + elif disambig_pattern.match(s): + disambig_ids.append(i) + else: + state = fst.add_state() + fst.set_final(state=state, weight=0) + token_ids[state] = i + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=i, + weight=0, + nextstate=state, + ), + ) + if with_self_loops: + fst.add_arc( + state=state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=0, + weight=0, + nextstate=state, # token2id[""] is always 0 + ), + ) + fst.add_arc( + state=state, + arc=kaldifst.StdArc( + ilabel=blank_id, + olabel=0, + weight=0, + nextstate=start_state, # token2id[""] is always 0 + ), + ) + + for istate in kaldifst.StateIterator(fst): + if istate > 0: + for ostate in kaldifst.StateIterator(fst): + if ostate > 0 and istate != ostate: + label = token_ids[ostate] + fst.add_arc( + state=istate, + arc=kaldifst.StdArc( + ilabel=label, + olabel=label, + weight=0, + nextstate=ostate, + ), + ) + for disambig_id in disambig_ids: + fst.add_arc( + state=istate, + arc=kaldifst.StdArc( + ilabel=0, + olabel=disambig_id, + weight=0, + nextstate=istate, # token2id[""] is always 0 + ), + ) + + return fst + + +def build_compact_topo(token2id: Dict[str, int], with_self_loops: bool = True) -> 'kaldifst.StdVectorFst': + """Build the Compact CTC topology.""" + _kaldifst_maybe_raise() + + disambig_pattern = re.compile(r"^#\d+$") + blank_id = token2id[""] + fst = kaldifst.StdVectorFst() + start_state = fst.add_state() + fst.start = start_state + fst.set_final(state=start_state, weight=0) + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=blank_id, + olabel=0, + weight=0, + nextstate=start_state, # token2id[""] is always 0 + ), + ) + + for s, i in token2id.items(): + if s == "" or s == "": + continue + elif disambig_pattern.match(s): + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=i, + weight=0, + nextstate=start_state, # token2id[""] is always 0 + ), + ) + else: + state = fst.add_state() + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=i, + weight=0, + nextstate=state, + ), + ) + if with_self_loops: + fst.add_arc( + state=state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=0, + weight=0, + nextstate=state, # token2id[""] is always 0 + ), + ) + fst.add_arc( + state=state, + arc=kaldifst.StdArc( + ilabel=0, # token2id[""] is always 0 + olabel=0, # token2id[""] is always 0 + weight=0, + nextstate=start_state, + ), + ) + + return fst + + +def build_minimal_topo(token2id: Dict[str, int]) -> 'kaldifst.StdVectorFst': + """Build the Minimal CTC topology.""" + _kaldifst_maybe_raise() + + disambig_pattern = re.compile(r"^#\d+$") + blank_id = token2id[""] + fst = kaldifst.StdVectorFst() + start_state = fst.add_state() + fst.start = start_state + fst.set_final(state=start_state, weight=0) + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=blank_id, + olabel=0, + weight=0, + nextstate=start_state, # token2id[""] is always 0 + ), + ) + + for s, i in token2id.items(): + if s == "" or s == "": + continue + elif disambig_pattern.match(s): + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=i, + weight=0, + nextstate=start_state, # token2id[""] is always 0 + ), + ) + else: + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=i, + weight=0, + nextstate=start_state, + ), + ) + + return fst + + +def mkgraph_ctc_ov( + tokenizer: 'TokenizerSpec', + lm_path: Union[Path, str], + topology_name: str = "default", + write_tlg_path: Optional[Union[Path, str]] = None, + open_vocabulary: bool = False, + open_vocabulary_weights: Tuple[float, float] = (2.0, 4.0), + target: str = "kaldi", # "kaldi", "k2" +) -> Tuple[Union['kaldifst.StdVectorFst', 'k2.Fsa'], int]: + """ + Builds a decoding WFST (TLG.fst or TLG.pt). + + See also mkgraph.sh from kaldi. + + Args: + tokenizer: + NeMo SentencePiece tokenizer. + + lm_path: + Path to the ARPA LM file. + + topology_name: + Topology name. Choices: default, compact, minimal. + + write_tlg_path: + Where to buffer the TLG. + + open_vocabulary: + Whether to build a decoding WFST suitable for the open vocabulary decoding. + + open_vocabulary_weights: + Pair of weights (oov_word_weight, token_unigram_weight). + + target: + What type to build the WFST for. Choices: kaldi, k2. + + Returns: + A pair of kaldi- or k2-type decoding WFST and its id of the tokenword disambiguation token. + """ + _kaldifst_maybe_raise() + + logging.info("Compiling G.fst ...") + G = arpa2fst(lm_path) + if open_vocabulary: + # in-place for g_fst + tokenword_disambig_id = add_tokenwords_( + g_fst=G, + tokens=tokenizer.tokenizer.get_vocab().keys(), + word_weight=open_vocabulary_weights[0], + token_unigram_weight=open_vocabulary_weights[1], + ) + else: + tokenword_disambig_id = -1 + + logging.info("Building L.fst ...") + id2word = {int(line.split("\t")[1]): line.split("\t")[0] for line in str(G.output_symbols).strip().split("\n")} + lexicon = generate_lexicon_sentencepiece( + tokenizer.tokenizer, id2word, add_epsilon=True, first_tokenword_id=tokenword_disambig_id + ) + lexicon_disambig = add_disambig_symbols(lexicon) + + L = make_lexicon_fst_no_silence(lexicon_disambig) + kaldifst.arcsort(L, sort_type="olabel") + + logging.info("Building LG.fst ...") + LG = kaldifst.compose(L, G) + kaldifst.determinize_star(LG) + kaldifst.minimize_encoded(LG) + kaldifst.arcsort(LG, sort_type="ilabel") + + logging.info("Building TLG.fst ...") + T = build_topo(topology_name, lexicon_disambig.token2id) + kaldifst.arcsort(T, sort_type="olabel") + TLG = kaldifst.compose(T, LG) + + if target == "kaldi": + if write_tlg_path: + logging.info(f"Buffering TLG.fst into {write_tlg_path} ...") + TLG.write(write_tlg_path) + elif target == "k2": + logging.info("Converting TLG.fst to k2 ...") + import torch + + from nemo.core.utils.k2_guard import k2 + + blank_id = [i for i, t in lexicon_disambig.id2token.items() if t.mark == "blank"][0] + first_token_disambig_id = [i for i, t in lexicon_disambig.id2token.items() if t.mark == "disambig_backoff"][0] + word_disambig_id = lexicon_disambig.word2id[lexicon_disambig.id2token[first_token_disambig_id].name] + assert lexicon_disambig.id2word[word_disambig_id].mark == "disambig_backoff" + input_symbols = "\n".join( + [f"{k} {v - 1}" for k, v in lexicon_disambig.token2id.items() if 0 < v < first_token_disambig_id] + ) + output_symbols = str(TLG.output_symbols) + TLG.input_symbols = None + TLG.output_symbols = None + # k2 does not support torch.inference_mode enabled + with torch.inference_mode(False): + TLG = k2.Fsa.from_openfst(TLG.to_str(show_weight_one=True), acceptor=False) + TLG.labels[TLG.labels >= first_token_disambig_id] = blank_id + TLG.aux_labels[TLG.aux_labels.values == word_disambig_id] = 0 + TLG.__dict__["_properties"] = None + TLG = k2.arc_sort(k2.connect(k2.remove_epsilon(TLG))) + TLG.labels[TLG.labels > 0] = TLG.labels[TLG.labels > 0] - 1 + TLG.__dict__["_properties"] = None + TLG.labels_sym = k2.SymbolTable.from_str(input_symbols) + TLG.aux_labels_sym = k2.SymbolTable.from_str(output_symbols) + TLG = k2.arc_sort(TLG) + if write_tlg_path: + logging.info(f"Buffering TLG.pt into {write_tlg_path} ...") + torch.save(TLG.as_dict(), write_tlg_path) + else: + raise ValueError(f"Unsupported target: `{target}`") + + return TLG, tokenword_disambig_id + + +class KaldiFstMask(Enum): + Acceptor = 65536 + Error = 4 + TopSorted = 274877906944 + Acyclic = 34359738368 + IlabelSorted = 268435456 + OlabelSorted = 1073741824 + IlabelDeterministic = 262144 + OlabelDeterministic = 1048576 + HasEpsilons = 4194304 + HasIEpsilons = 16777216 + Accessible = 1099511627776 + Coaccessible = 4398046511104 + Weighted = 4294967296 + + +class LatticeProperties(NamedTuple): + Acceptor: bool + Valid: bool + Nonempty: bool + TopSorted: bool + Acyclic: bool + ArcSorted: bool + Deterministic: bool + EpsilonFree: bool + InputEpsilonFree: bool + Connected: bool + Weighted: bool + + +class AbstractLattice(ABC): + """A lattice wrapper with high-level capabilities.""" + + def __init__(self, lattice: Any): + self._lattice = lattice + self._properties = None + + @abstractmethod + def as_tensor(self) -> 'torch.Tensor': + """Represents the lattice as a tensor. + + Returns: + torch.Tensor + """ + pass + + @abstractmethod + def draw( + self, filename: Optional[Union[Path, str]] = None, title: Optional[Union[Path, str]] = None, zoom: float = 1.0 + ) -> Union['graphviz.Digraph', 'IPython.display.HTML']: + """Render FSA as an image via graphviz, and return the Digraph object; and optionally save to file filename. + filename must have a suffix that graphviz understands, such as pdf, svg or png. + + Note: + You need to install graphviz to use this function:: + + ./scripts/installers/install_graphviz.sh + + Args: + filename: + Filename to (optionally) save to, e.g. ‘foo.png’, ‘foo.svg’, ‘foo.png’. + + title: + Title to be displayed in image, e.g. ‘A simple lattice example’. + + zoom: + Zoom-in lattice in IPython notebook (needed for large lattices). + + Returns: + graphviz.Digraph or IPython.display.HTML + """ + pass + + @abstractmethod + def edit_distance(self, reference_sequence: List[int]) -> int: + """Get the edit distance from a reference sequence to the lattice. + + Args: + reference_sequence: + List of word- or token-ids. + + Returns: + Number of edits. + """ + + @property + def lattice(self): + self._properties = None + return self._lattice + + @abstractproperty + def properties(self) -> LatticeProperties: + pass + + @abstractproperty + def symbol_table(self) -> Optional[Dict[int, str]]: + pass + + @abstractproperty + def auxiliary_tables(self) -> Optional[Tuple[Any]]: + pass + + +class KaldiWordLattice(AbstractLattice): + """A Kaldi lattice wrapper with high-level capabilities.""" + + def __init__( + self, + lattice: 'kaldifst.Lattice', + symbol_table: Optional[Dict[int, str]] = None, + auxiliary_tables: Optional[Dict[str, Any]] = None, + ): + _kaldifst_maybe_raise() + + if not isinstance(lattice, kaldifst.Lattice): + raise ValueError(f"Wrong lattice type: `{type(lattice)}`") + super().__init__(lattice) + + kaldi_symbols2dict = lambda symbols: { + int(line.split("\t")[1]): line.split("\t")[0] for line in str(symbols).strip().split("\n") + } + self._symbol_table = None + # most likely lattice will have empty input_symbols + if symbol_table is not None: + self._symbol_table = symbol_table + elif self._lattice.output_symbols is not None: + # we suppose that lattice.input_symbols will not be changed + self._symbol_table = kaldi_symbols2dict(self._lattice.output_symbols) + + self._auxiliary_tables = None + if auxiliary_tables is not None: + attributes, values = list(auxiliary_tables.keys()), list(auxiliary_tables.values()) + if "input_symbols" not in attributes and self._lattice.input_symbols is not None: + # rare but possible case + attributes.append("input_symbols") + values.append(kaldi_symbols2dict(self._lattice.input_symbols)) + self._auxiliary_tables = namedtuple("KaldiAuxiliaryTables", attributes)(*values) + elif self._lattice.input_symbols is not None: + self._auxiliary_tables = namedtuple("KaldiAuxiliaryTables", "input_symbols")( + kaldi_symbols2dict(self._lattice.input_symbols) + ) + + @property + def properties(self) -> LatticeProperties: + if self._properties is None: + acceptor = self._lattice.properties(KaldiFstMask.Acceptor.value, True) == KaldiFstMask.Acceptor.value + valid = self._lattice.properties(KaldiFstMask.Error.value, True) != KaldiFstMask.Error.value + nonempty = self._lattice.num_states > 0 + top_sorted = self._lattice.properties(KaldiFstMask.TopSorted.value, True) == KaldiFstMask.TopSorted.value + acyclic = self._lattice.properties(KaldiFstMask.Acyclic.value, True) == KaldiFstMask.Acyclic.value + arc_sorted = ( + self._lattice.properties(KaldiFstMask.IlabelSorted.value, True) == KaldiFstMask.IlabelSorted.value + and self._lattice.properties(KaldiFstMask.OlabelSorted.value, True) == KaldiFstMask.OlabelSorted.value + ) + deterministic = ( + self._lattice.properties(KaldiFstMask.IlabelDeterministic.value, True) + == KaldiFstMask.IlabelDeterministic.value + and self._lattice.properties(KaldiFstMask.OlabelDeterministic.value, True) + == KaldiFstMask.OlabelDeterministic.value + ) + epsilon_free = ( + self._lattice.properties(KaldiFstMask.HasEpsilons.value, True) != KaldiFstMask.HasEpsilons.value + ) + input_epsilon_free = ( + self._lattice.properties(KaldiFstMask.HasIEpsilons.value, True) != KaldiFstMask.HasIEpsilons.value + ) + connected = ( + self._lattice.properties(KaldiFstMask.Accessible.value, True) == KaldiFstMask.Accessible.value + and self._lattice.properties(KaldiFstMask.Coaccessible.value, True) == KaldiFstMask.Coaccessible.value + ) + weighted = self._lattice.properties(KaldiFstMask.Weighted.value, True) == KaldiFstMask.Weighted.value + self._properties = LatticeProperties( + Acceptor=acceptor, + Valid=valid, + Nonempty=nonempty, + TopSorted=top_sorted, + Acyclic=acyclic, + ArcSorted=arc_sorted, + Deterministic=deterministic, + EpsilonFree=epsilon_free, + InputEpsilonFree=input_epsilon_free, + Connected=connected, + Weighted=weighted, + ) + return self._properties + + @property + def symbol_table(self) -> Optional[Dict[int, str]]: + return self._symbol_table + + @property + def auxiliary_tables(self) -> Optional[Tuple[Any]]: + return self._auxiliary_tables + + def as_tensor(self) -> 'torch.Tensor': + """Represents the lattice as a tensor. + + Returns: + torch.Tensor + """ + raise NotImplementedError("Tensor representation is not supported yet.") + + def edit_distance(self, reference_sequence: List[int]) -> int: + """Get the edit distance from a reference sequence to the lattice. + + Args: + reference_sequence: + List of word- or token-ids. + + Returns: + Number of edits. + """ + _kaldifst_maybe_raise() + + if not self.properties.InputEpsilonFree: + logging.warning(f"Lattice contains input epsilons. Edit distance calculations may not be accurate.") + if not all(reference_sequence): + raise ValueError(f"reference_sequence contains zeros, which is not allowed.") + ref = levenshtein_graph_kaldi(kaldifst.make_linear_acceptor(reference_sequence)) + hyp = levenshtein_graph_kaldi(self._lattice) + kaldifst.invert(hyp) + ali_fst = kaldifst.compose(hyp, ref) + succeeded, _, _, total_weight = kaldifst.get_linear_symbol_sequence(kaldifst.shortest_path(ali_fst)) + if not succeeded: + raise RuntimeError("Something went wrong while calculating edit_distance. Please check input manually.") + return round(total_weight.value) + + def draw( + self, filename: Optional[Union[Path, str]] = None, title: Optional[Union[Path, str]] = None, zoom: float = 1.0 + ) -> Union['graphviz.Digraph', 'IPython.display.HTML']: + """Render FSA as an image via graphviz, and return the Digraph object; and optionally save to file filename. + filename must have a suffix that graphviz understands, such as pdf, svg or png. + + Note: + You need to install graphviz to use this function:: + + ./scripts/installers/install_graphviz.sh + + Args: + filename: + Filename to (optionally) save to, e.g. ‘foo.png’, ‘foo.svg’, ‘foo.png’. + + title: + Title to be displayed in image, e.g. ‘A simple lattice example’. + + zoom: + Zoom-in lattice in IPython notebook (needed for large lattices). + + Returns: + graphviz.Digraph or IPython.display.HTML + """ + _kaldifst_maybe_raise() + _graphviz_maybe_raise() + + isym, osym = None, None + if self._symbol_table: + osym = kaldifst.SymbolTable() + for i, w in self._symbol_table.items(): + osym.add_symbol(symbol=w, key=i) + + if ( + self._auxiliary_tables + and hasattr(self._auxiliary_tables, "input_symbols") + and self._auxiliary_tables.input_symbols + ): + isym = kaldifst.SymbolTable() + for i, t in self._auxiliary_tables.input_symbols.items(): + isym.add_symbol(symbol=t, key=i) + + fst_dot = kaldifst.draw( + self._lattice, acceptor=False, portrait=True, isymbols=isym, osymbols=osym, show_weight_one=True + ) + source = graphviz.Source(fst_dot) + source_lines = str(source).splitlines() + # Remove 'digraph tree {' + source_lines.pop(0) + # Remove the closing brackets '}' + source_lines.pop(-1) + graph_attr = { + 'rankdir': 'LR', + 'size': '8.5,11', + 'center': '1', + 'orientation': 'Portrait', + 'ranksep': '0.4', + 'nodesep': '0.25', + 'margin': '0.0', + } + if title is not None: + graph_attr['label'] = title + digraph = graphviz.Digraph(graph_attr=graph_attr) + digraph.body += source_lines + if filename: + _, extension = os.path.splitext(filename) + if extension == '' or extension[0] != '.': + raise ValueError(f"Filename needs to have a suffix like .png, .pdf, .svg, or .gv: `{filename}`") + with tempfile.TemporaryDirectory() as tmp_dir: + temp_fn = digraph.render(filename='temp', directory=tmp_dir, format=extension[1:], cleanup=True) + + shutil.move(temp_fn, filename) + if _is_notebook(): + import warnings + + from IPython.display import HTML + + with tempfile.TemporaryDirectory() as tmp_dir: + temp_fn = digraph.render(filename='temp', directory=tmp_dir, format="svg", cleanup=True) + svg, (width, height) = _svg_srcdoc_resize(temp_fn, zoom) + # IFrame requires src file to be present when rendering + # so we use HTML with iframe srcdoc instead + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return HTML( + f"""""" + ) + return digraph + + +def _is_notebook() -> bool: + try: + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell' or 'Shell': + return True # Jupyter notebook, Google Colab notebook, or qtconsole + elif shell == 'TerminalInteractiveShell': + return False # Terminal running IPython + else: + return False # Other type + except NameError: + return False # Probably standard Python interpreter + + +def _svg_srcdoc_resize(filename: Union[Path, str], zoom: float) -> Tuple[str, Tuple[int, int]]: + with open(filename, "rt", encoding="utf-8") as f: + line = f.readline() + while not line.startswith(" 'kaldifst.StdFst': + """Construct the levenshtein graph from a kaldi-type WFST or a lattice. + + See also levenshtein_graph from k2. + + Args: + fst: + Kaldi-type source WFST or lattice. + + ins_del_score: + Insertion and deletion penalty. + Should be more than 0.5 for substitutions to be preferred over insertions/deletions, or less otherwise. + + Returns: + Kaldi-type levenshtein WFST. + """ + _kaldifst_maybe_raise() + + if fst.properties(KaldiFstMask.Acceptor.value, True) != KaldiFstMask.Acceptor.value: + logging.warning( + "Levenshtein graph construction is not safe for WFSTs with different input and output symbols." + ) + if fst.properties(KaldiFstMask.Acyclic.value, True) != KaldiFstMask.Acyclic.value: + raise ValueError("Levenshtein graph is not defined for WFSTs with cycles.") + if isinstance(fst, kaldifst.StdFst): + lfst = fst.copy(safe=True) + elif isinstance(fst, kaldifst.Lattice): + # dropping lattice weights + lfst = kaldifst.compile(re.sub("[-\d.]+,[-\d.]+", "0", fst.to_str(show_weight_one=True))) + else: + raise ValueError(f"Levenshtein graph building is not supported for the type `{type(fst)}`.") + sub_score = 0.5 + eps = 0 + for state in kaldifst.StateIterator(lfst): + # epsilon self-loop for insertions and deletions + arcs_to_add = [ + kaldifst.StdArc( + ilabel=eps, + olabel=eps, + weight=ins_del_score, + nextstate=state, + ) + ] + for arc in kaldifst.ArcIterator(lfst, state): + # epsilon-to-ilabel arc for substitutions + arcs_to_add.append( + kaldifst.StdArc( + ilabel=eps, + olabel=arc.ilabel, + weight=sub_score, + nextstate=arc.nextstate, + ) + ) + # zero weight for correct ids (redundant for lattices) + arc.weight = 0.0 + for arc in arcs_to_add: + lfst.add_arc(state=state, arc=arc) + kaldifst.arcsort(lfst) + return lfst + + +def load_word_lattice( + lat_filename: Union[Path, str], id2word: Optional[Dict[int, str]] = None, id2token: Optional[Dict[int, str]] = None +) -> Dict[str, KaldiWordLattice]: + """Helper function to load riva-decoder recognition lattices. + + Args: + lat_filename: + Path to the riva-decoder recognition lattice file. + + id2word: + Word index. + Mapping from word_id to word_str. + + id2token: + Token index. + Mapping from token_id to token_str. + + Returns: + Dictionary with lattice names and corresponding lattices in KaldiWordLattice format. + """ + _kaldifst_maybe_raise() + + lattice_dict = {} + lattice = None + max_state = 0 + token_seq_list = [] + with open(lat_filename, "rt") as f: + for line in f.readlines(): + line_items = line.strip().split() + line_len = len(line_items) + if line_len == 0: # end of lattice + token_seq_list = [] + lattice = None + max_state = 0 + elif line_len == 1: # lattice identifier + assert lattice is None + assert max_state == 0 + assert len(token_seq_list) == 0 + lat_id = line_items[0] + lattice = kaldifst.Lattice() + lattice_dict[lat_id] = KaldiWordLattice( + lattice=lattice, + symbol_table=id2word, + auxiliary_tables={"token_seq_list": token_seq_list, "input_symbols": id2token}, + ) + start = lattice.add_state() + lattice.start = start + max_state += 1 + elif line_len in (3, 4): # arc + if line_len == 4: # regular arc + state, next_state, label = [int(i) for i in line_items[:-1]] + trunk = line_items[-1].split(',') + graph_cost, acoustic_cost = [float(i) for i in trunk[:-1]] + else: # arc without weight + logging.warning( + f"""An arc without weight is detected for lattice `{lat_id}`. + Weights and token sequences will be set trivially.""" + ) + state, next_state, label = [int(i) for i in line_items] + trunk = [""] + graph_cost, acoustic_cost = 0.0, 0.0 + if next_state >= max_state: + for i in range(max_state, next_state + 1): + lattice.add_state() + max_state = next_state + 1 + ark = kaldifst.LatticeArc( + ilabel=label, + olabel=label, + weight=kaldifst.LatticeWeight(graph_cost=graph_cost, acoustic_cost=acoustic_cost), + nextstate=next_state, + ) + lattice.add_arc(state=state, arc=ark) + token_seq_list.append((ark, [int(i) for i in trunk[-1].split(TW_BREAK)] if trunk[-1] != "" else [])) + elif line_len == 2: # final state + state = int(line_items[0]) + trunk = line_items[-1].split(',') + graph_cost, acoustic_cost = [float(i) for i in trunk[:-1]] + lattice.set_final( + state=state, weight=kaldifst.LatticeWeight(graph_cost=graph_cost, acoustic_cost=acoustic_cost) + ) + else: + raise RuntimeError(f"Broken line: `{line}`") + return lattice_dict diff --git a/nemo/collections/audio/data/audio_to_audio_lhotse.py b/nemo/collections/audio/data/audio_to_audio_lhotse.py index 27d8a0ed28d7..d8978c19d692 100644 --- a/nemo/collections/audio/data/audio_to_audio_lhotse.py +++ b/nemo/collections/audio/data/audio_to_audio_lhotse.py @@ -44,19 +44,29 @@ class LhotseAudioToTargetDataset(torch.utils.data.Dataset): EMBEDDING_KEY = "embedding_vector" def __getitem__(self, cuts: CutSet) -> dict[str, torch.Tensor]: - src_audio, src_audio_lens = collate_audio(cuts) + # In the rare case, the collate_audio function would raise the FileSeek error when loading .flac (https://github.com/bastibe/python-soundfile/issues/274) + # A workaround is to use fault_tolerant and skip failed data, resulting in a smaller batch size for the few problematic cases. + src_audio, src_audio_lens, retained_padded_cuts = collate_audio(cuts, fault_tolerant=True) ans = { "input_signal": src_audio, "input_length": src_audio_lens, } - if _key_available(cuts, self.TARGET_KEY): - tgt_audio, tgt_audio_lens = collate_audio(cuts, recording_field=self.TARGET_KEY) + # keep only the first non-padding cuts + retained_cuts = [ + cut._first_non_padding_cut if isinstance(cut, MixedCut) else cut for cut in retained_padded_cuts + ] + retained_cuts = CutSet.from_cuts(retained_cuts) + + if _key_available(retained_cuts, self.TARGET_KEY): + # TODO: use fault_tolerant=True for robust loading of target + tgt_audio, tgt_audio_lens = collate_audio(retained_cuts, recording_field=self.TARGET_KEY) ans.update(target_signal=tgt_audio, target_length=tgt_audio_lens) - if _key_available(cuts, self.REFERENCE_KEY): - ref_audio, ref_audio_lens = collate_audio(cuts, recording_field=self.REFERENCE_KEY) + if _key_available(retained_cuts, self.REFERENCE_KEY): + # TODO: use fault_tolerant=True for robust loading of target + ref_audio, ref_audio_lens = collate_audio(retained_cuts, recording_field=self.REFERENCE_KEY) ans.update(reference_signal=ref_audio, reference_length=ref_audio_lens) if _key_available(cuts, self.EMBEDDING_KEY): - emb = collate_custom_field(cuts, field=self.EMBEDDING_KEY) + emb = collate_custom_field(retained_cuts, field=self.EMBEDDING_KEY) ans.update(embedding_signal=emb) return ans diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index ef9ce648f1a2..e1732c1658b7 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -483,4 +483,35 @@ def on_after_backward(self): if valid_gradients < 1: logging.warning('detected inf or nan values in gradients! Setting gradients to zero.') - self.zero_grad() + self.zero_grad(set_to_none=False) + + def configure_callbacks(self): + """ + Create an callback to add audio/spectrogram into tensorboard & wandb. + """ + self.log_config = self.cfg.get("log_config", None) + if not self.log_config: + return [] + + log_callbacks = [] + from nemo.collections.audio.parts.utils.callbacks import SpeechEnhancementLoggingCallback + + if isinstance(self._validation_dl, List): + data_loaders = self._validation_dl + else: + data_loaders = [self._validation_dl] + + for data_loader_idx, data_loader in enumerate(data_loaders): + log_callbacks.append( + SpeechEnhancementLoggingCallback( + data_loader=data_loader, + data_loader_idx=data_loader_idx, + loggers=self.trainer.loggers, + log_tensorboard=self.log_config.log_tensorboard, + log_wandb=self.log_config.log_wandb, + sample_rate=self.sample_rate, + max_utts=self.log_config.get("max_utts", None), + ) + ) + + return log_callbacks diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index e7fbc9023117..cd9f47b98096 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -30,6 +30,7 @@ 'ScoreBasedGenerativeAudioToAudioModel', 'PredictiveAudioToAudioModel', 'SchroedingerBridgeAudioToAudioModel', + 'FlowMatchingAudioToAudioModel', ] @@ -618,6 +619,274 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = return {f'{tag}_loss': loss} +class FlowMatchingAudioToAudioModel(AudioToAudioModel): + """This models uses a flow matching process to generate + an encoded representation of the enhanced signal. + + The model consists of the following blocks: + - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform) + - estimator: neural model, estimates a score for the diffusion process + - flow: ordinary differential equation (ODE) defining a flow and a vector field. + - sampler: sampler for the inference process, estimates coefficients of the target signal + - decoder: transforms sampler output into the time domain (synthesis transform) + - ssl_pretrain_masking: if it is defined, perform the ssl pretrain masking for self reconstruction in the training process + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate + + # Setup processing modules + self.encoder = self.from_config_dict(self._cfg.encoder) + self.decoder = self.from_config_dict(self._cfg.decoder) + + # Neural estimator + self.estimator = self.from_config_dict(self._cfg.estimator) + + # Flow + self.flow = self.from_config_dict(self._cfg.flow) + + # Sampler + self.sampler = hydra.utils.instantiate(self._cfg.sampler, estimator=self.estimator) + + # probability that the conditional input will be feed into the + # estimator in the training stage + self.p_cond = self._cfg.get('p_cond', 1.0) + + # Self-Supervised Pretraining + if self._cfg.get('ssl_pretrain_masking') is not None: + logging.debug('SSL-pretrain_masking is found and will be initialized') + self.ssl_pretrain_masking = self.from_config_dict(self._cfg.ssl_pretrain_masking) + else: + self.ssl_pretrain_masking = None + + # Normalization + self.normalize_input = self._cfg.get('normalize_input', False) + + # Metric evaluation + self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics') + + if self.max_utts_evaluation_metrics is not None: + logging.warning( + 'Metrics will be evaluated on first %d examples of the evaluation datasets.', + self.max_utts_evaluation_metrics, + ) + + # Regularization + self.eps = self._cfg.get('eps', 1e-8) + + # Setup optional Optimization flags + self.setup_optimization_flags() + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tdoing SSL-pretraining: %s', (self.ssl_pretrain_masking is not None)) + logging.debug('\tp_cond: %s', self.p_cond) + logging.debug('\tnormalize_input: %s', self.normalize_input) + logging.debug('\tloss: %s', self.loss) + logging.debug('\teps: %s', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return { + "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @typecheck() + @torch.inference_mode() + def forward(self, input_signal, input_length=None): + """Forward pass of the model to generate samples from the target distribution. + + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + + Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. + """ + batch_length = input_signal.size(-1) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + + # Encoder + encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + + if self.p_cond == 0: + encoded = torch.zeros_like(encoded) + elif self.ssl_pretrain_masking is not None: + encoded = self.ssl_pretrain_masking(input_spec=encoded, length=encoded_length) + + init_state = torch.randn_like(encoded) * self.flow.sigma_start + + # Sampler + generated, generated_length = self.sampler( + state=init_state, estimator_condition=encoded, state_length=encoded_length + ) + + # Decoder + output, output_length = self.decoder(input=generated, input_length=generated_length) + + if self.normalize_input: + # rescale to the original scale + output = output * norm_scale + + # Trim or pad the estimated signal to match input length + output = self.match_batch_length(input=output, batch_length=batch_length) + + return output, output_length + + @typecheck( + input_types={ + "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_length": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "loss": NeuralType(None, LossType()), + }, + ) + def _step(self, target_signal, input_signal, input_length=None): + batch_size = target_signal.size(0) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + # scale the target signal + target_signal = target_signal / (norm_scale + self.eps) + + # Apply encoder to both target and the input + input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length) + target_enc, _ = self.encoder(input=target_signal, input_length=input_length) + + # Self-Supervised Pretraining + if self.ssl_pretrain_masking is not None: + input_enc = self.ssl_pretrain_masking(input_spec=input_enc, length=input_enc_len) + + # Drop off conditional inputs (input_enc) with (1 - p_cond) probability. + # The dropped conditions will be set to zeros + keep_conditions = einops.rearrange((torch.rand(batch_size) < self.p_cond).float(), 'B -> B 1 1 1') + input_enc = input_enc * keep_conditions.to(input_enc.device) + + x_start = torch.zeros_like(input_enc) + + time = self.flow.generate_time(batch_size=batch_size).to(device=input_enc.device) + sample = self.flow.sample(time=time, x_start=x_start, x_end=target_enc) + + # we want to get a vector field estimate given current state + # at training time, current state is sampled from the conditional path + # the vector field model is also conditioned on input signal + estimator_input = torch.cat([sample, input_enc], dim=-3) + + # Estimate the vector using the neural estimator + estimate, estimate_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=time) + + conditional_vector_field = self.flow.vector_field(time=time, x_start=x_start, x_end=target_enc, point=sample) + + return self.loss(estimate=estimate, target=conditional_vector_field, input_length=input_enc_len) + + # PTL-specific methods + def training_step(self, batch, batch_idx): + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch.get('target_signal', input_signal.clone()) + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, "B T -> B 1 T") + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, "B T -> B 1 T") + + # Calculate the loss + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + + # Logs + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch.get('target_signal', input_signal.clone()) + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate loss + loss = self._step( + target_signal=target_signal, + input_signal=input_signal, + input_length=input_length, + ) + + # Update metrics + update_metrics = False + if self.max_utts_evaluation_metrics is None: + # Always update if max is not configured + update_metrics = True + # Number of examples to process + num_examples = input_signal.size(0) # batch size + else: + # Check how many examples have been used for metric calculation + first_metric_name = next(iter(self.metrics[tag][dataloader_idx])) + num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples + # Update metrics if some examples were not processed + update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics + # Number of examples to process + num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0)) + + if update_metrics: + # Generate output signal + output_signal, _ = self.forward( + input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples] + ) + + # Update metrics + if hasattr(self, 'metrics') and tag in self.metrics: + # Update metrics for this (tag, dataloader_idx) + for name, metric in self.metrics[tag][dataloader_idx].items(): + metric.update( + preds=output_signal, + target=target_signal[:num_examples, ...], + input_length=input_length[:num_examples], + ) + + # Log global step + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return {f'{tag}_loss': loss} + + class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel): """This models is using a Schrödinger Bridge process to generate an encoded representation of the enhanced signal. diff --git a/nemo/collections/audio/modules/ssl_pretrain_masking.py b/nemo/collections/audio/modules/ssl_pretrain_masking.py new file mode 100644 index 000000000000..ba0722f180d8 --- /dev/null +++ b/nemo/collections/audio/modules/ssl_pretrain_masking.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import einops +import torch + +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType + +__all__ = ['SSLPretrainWithMaskedPatch'] + + +class SSLPretrainWithMaskedPatch(NeuralModule): + """ + Zeroes out fixed size time patches of the spectrogram. + All samples in batch are guaranteed to have the same amount of masked time steps. + Note that this may be problematic when we do pretraining on a unbalanced dataset. + + For example, say a batch contains two spectrograms of length 87 and 276. + With mask_fraction=0.7 and patch_size=10, we'll obrain mask_patches=7. + Each of the two data will then have 7 patches of 10-frame mask. + + Args: + patch_size (int): up to how many time steps does one patch consist of. + Defaults to 10. + mask_fraction (float): how much fraction in each sample to be masked (number of patches is rounded up). + Range from 0.0 to 1.0. Defaults to 0.7. + """ + + @property + def input_types(self): + """Returns definitions of module input types""" + return { + "input_spec": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output types""" + return {"augmented_spec": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType())} + + def __init__( + self, + patch_size: int = 10, + mask_fraction: float = 0.7, + ): + super().__init__() + self.patch_size = patch_size + if mask_fraction > 1.0 or mask_fraction < 0.0: + raise ValueError('mask_patches cannot be negative') + else: + self.mask_fraction = mask_fraction + + @typecheck() + def forward(self, input_spec, length): + """ + Apply Patched masking on the input_spec. + + + During the training stage, the mask is generated randomly, with + approximately `self.mask_fraction` of the time frames being masked out. + + In the validation stage, the masking pattern is fixed to ensure + consistent evaluation of checkpoints and to prevent overfitting. Note + that the same masking pattern is applied to all data, regardless of + their lengths. On average, approximately `self.mask_fraction` of the + time frames will be masked out. + + """ + augmented_spec = input_spec + + min_len = torch.min(length) + if self.training: + len_fraction = int(min_len * self.mask_fraction) + mask_patches = len_fraction // self.patch_size + int(len_fraction % self.patch_size != 0) + + if min_len < self.patch_size * mask_patches: + mask_patches = min_len // self.patch_size + + for idx, cur_len in enumerate(length.tolist()): + patches = range(cur_len // self.patch_size) + masked_patches = random.sample(patches, mask_patches) + for mp in masked_patches: + augmented_spec[idx, :, :, mp * self.patch_size : (mp + 1) * self.patch_size] = 0.0 + else: + chunk_length = self.patch_size // self.mask_fraction + mask = torch.arange(augmented_spec.size(-1), device=augmented_spec.device) + mask = (mask % chunk_length) >= self.patch_size + mask = einops.rearrange(mask, 'T -> 1 1 1 T').float() + augmented_spec = augmented_spec * mask + + return augmented_spec diff --git a/nemo/collections/audio/parts/submodules/flow.py b/nemo/collections/audio/parts/submodules/flow.py new file mode 100644 index 000000000000..748d4c6c6d3b --- /dev/null +++ b/nemo/collections/audio/parts/submodules/flow.py @@ -0,0 +1,252 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Tuple + +import einops +import torch + +from nemo.collections.common.parts.utils import mask_sequence_tensor +from nemo.utils import logging + + +class ConditionalFlow(ABC): + """ + Abstract class for different conditional flow-matching (CFM) classes + + Time horizon is [time_min, time_max (should be 1)] + + every path is "conditioned" on endpoints of the path + endpoints are just our paired data samples + subclasses need to implement mean, std, and vector_field + + """ + + def __init__(self, time_min: float = 1e-8, time_max: float = 1.0): + self.time_min = time_min + self.time_max = time_max + + @abstractmethod + def mean(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: + """ + Return the mean of p_t(x | x_start, x_end) at time t + """ + pass + + @abstractmethod + def std(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: + """ + Return the standard deviation of p_t(x | x_start, x_end) at time t + """ + pass + + @abstractmethod + def vector_field( + self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor + ) -> torch.Tensor: + """ + Compute the conditional vector field v_t( point | x_start, x_end) + """ + pass + + @staticmethod + def _broadcast_time(time: torch.Tensor, n_dim: int) -> torch.Tensor: + """ + Broadcast time tensor to the desired number of dimensions + """ + if time.ndim == 1: + target_shape = ' '.join(['B'] + ['1'] * (n_dim - 1)) + time = einops.rearrange(time, f'B -> {target_shape}') + + return time + + def generate_time(self, batch_size: int) -> torch.Tensor: + """ + Randomly sample a batchsize of time_steps from U[0~1] + """ + return torch.clamp(torch.rand((batch_size,)), self.time_min, self.time_max) + + def sample(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: + """ + Generate a sample from p_t(x | x_start, x_end) at time t. + Note that this implementation assumes all path marginals are normally distributed. + """ + time = self._broadcast_time(time, n_dim=x_start.ndim) + + mean = self.mean(time=time, x_start=x_start, x_end=x_end) + std = self.std(time=time, x_start=x_start, x_end=x_end) + return mean + std * torch.randn_like(mean) + + def flow( + self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor + ) -> torch.Tensor: + """ + Compute the conditional flow phi_t( point | x_start, x_end). + This is an affine flow. + """ + mean = self.mean(time=time, x_start=x_start, x_end=x_end) + std = self.std(time=time, x_start=x_start, x_end=x_end) + return mean + std * (point - x_start) + + +class OptimalTransportFlow(ConditionalFlow): + """The OT-CFM model from [Lipman et at, 2023] + + Every conditional path the following holds: + p_0 = N(x_start, sigma_start) + p_1 = N(x_end, sigma_end), + + mean(x, t) = (time_max - t) * x_start + t * x_end + (linear interpolation between x_start and x_end) + + std(x, t) = (time_max - t) * sigma_start + t * sigma_end + + Every conditional path is optimal transport map from p_0(x_start, x_end) to p_1(x_start, x_end) + Marginal path is not guaranteed to be an optimal transport map from p_0 to p_1 + + To get the OT-CFM model from [Lipman et at, 2023] just pass zeroes for x_start + To get the I-CFM model, set sigma_min=sigma_max + To get the rectified flow model, set sigma_min=sigma_max=0 + + Args: + time_min: minimum time value used in the process + time_max: maximum time value used in the process + sigma_start: the standard deviation of the initial distribution + sigma_end: the standard deviation of the target distribution + """ + + def __init__( + self, time_min: float = 1e-8, time_max: float = 1.0, sigma_start: float = 1.0, sigma_end: float = 1e-4 + ): + super().__init__(time_min=time_min, time_max=time_max) + self.sigma_start = sigma_start + self.sigma_end = sigma_end + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\tsgima_start: %s', self.sigma_start) + logging.debug('\tsigma_end: %s', self.sigma_end) + + def mean(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + return (self.time_max - time) * x_start + time * x_end + + def std(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + return (self.time_max - time) * self.sigma_start + time * self.sigma_end + + def vector_field( + self, + *, + x_start: torch.Tensor, + x_end: torch.Tensor, + time: torch.Tensor, + point: torch.Tensor, + eps: float = 1e-6, + ) -> torch.Tensor: + time = self._broadcast_time(time, n_dim=x_start.ndim) + + if self.sigma_start == self.sigma_end: + return x_end - x_start + + num = self.sigma_end * (point - x_start) - self.sigma_start * (point - x_end) + denom = (1 - time) * self.sigma_start + time * self.sigma_end + return num / (denom + eps) + + +class ConditionalFlowMatchingSampler(ABC): + """ + Abstract class for different sampler to solve the ODE in CFM + + Args: + estimator: the NN-based conditional vector field estimator + num_steps: How many time steps to iterate in the process + time_min: minimum time value used in the process + time_max: maximum time value used in the process + + """ + + def __init__( + self, + estimator: torch.nn.Module, + num_steps: int = 5, + time_min: float = 1e-8, + time_max: float = 1.0, + ): + self.estimator = estimator + self.num_steps = num_steps + self.time_min = time_min + self.time_max = time_max + + @property + def time_step(self): + return (self.time_max - self.time_min) / self.num_steps + + @abstractmethod + def forward( + self, state: torch.Tensor, estimator_condition: torch.Tensor, state_length: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + +class ConditionalFlowMatchingEulerSampler(ConditionalFlowMatchingSampler): + """ + The Euler Sampler for solving the ODE in CFM on a uniform time grid + """ + + def __init__( + self, + estimator: torch.nn.Module, + num_steps: int = 5, + time_min: float = 1e-8, + time_max: float = 1.0, + ): + super().__init__( + estimator=estimator, + num_steps=num_steps, + time_min=time_min, + time_max=time_max, + ) + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.inference_mode() + def forward( + self, state: torch.Tensor, estimator_condition: torch.Tensor, state_length: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + time_steps = torch.linspace(self.time_min, self.time_max, self.num_steps + 1) + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + for t in time_steps: + time = t * torch.ones(state.shape[0], device=state.device) + + if estimator_condition is None: + estimator_input = state + else: + estimator_input = torch.cat([state, estimator_condition], dim=1) + + vector_field, _ = self.estimator(input=estimator_input, input_length=state_length, condition=time) + + state = state + vector_field * self.time_step + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + return state, state_length diff --git a/nemo/collections/audio/parts/submodules/transformerunet.py b/nemo/collections/audio/parts/submodules/transformerunet.py new file mode 100644 index 000000000000..b7c14d513bab --- /dev/null +++ b/nemo/collections/audio/parts/submodules/transformerunet.py @@ -0,0 +1,507 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2023 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math +from functools import partial +from typing import Dict, Optional + +import einops +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Module + +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import BoolType, FloatType, LengthsType, NeuralType, SpectrogramType +from nemo.utils import logging + +__all__ = ['TransformerUNet'] + + +class LearnedSinusoidalPosEmb(Module): + """The sinusoidal Embedding to encode time conditional information""" + + def __init__(self, dim: int): + super().__init__() + if (dim % 2) != 0: + raise ValueError(f"Input dimension {dim} is not divisible by 2!") + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, t: torch.Tensor) -> torch.Tensor: + """ + Args: + t: input time tensor, shape (B) + + Return: + fouriered: the encoded time conditional embedding, shape (B, D) + """ + t = einops.rearrange(t, 'b -> b 1') + freqs = t * einops.rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + +class ConvPositionEmbed(Module): + """The Convolutional Embedding to encode time information of each frame""" + + def __init__(self, dim: int, kernel_size: int, groups: Optional[int] = None): + super().__init__() + if (kernel_size % 2) == 0: + raise ValueError(f"Kernel size {kernel_size} is divisible by 2!") + + if groups is None: + groups = dim + + self.dw_conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.GELU() + ) + + def forward(self, x, mask=None): + """ + Args: + x: input tensor, shape (B, T, D) + + Return: + out: output tensor with the same shape (B, T, D) + """ + + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(mask, 0.0) + + x = einops.rearrange(x, 'b n c -> b c n') + x = self.dw_conv1d(x) + out = einops.rearrange(x, 'b c n -> b n c') + + if mask is not None: + out = out.masked_fill(mask, 0.0) + + return out + + +class RMSNorm(Module): + """The Root Mean Square Layer Normalization + + References: + - Zhang et al., Root Mean Square Layer Normalization, 2019 + """ + + def __init__(self, dim): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +class AdaptiveRMSNorm(Module): + """ + Adaptive Root Mean Square Layer Normalization given a conditional embedding. + This enables the model to consider the conditional input during normalization. + """ + + def __init__(self, dim: int, cond_dim: Optional[int] = None): + super().__init__() + if cond_dim is None: + cond_dim = dim + self.scale = dim**0.5 + + self.to_gamma = nn.Linear(cond_dim, dim) + self.to_beta = nn.Linear(cond_dim, dim) + + # init adaptive normalization to identity + + nn.init.zeros_(self.to_gamma.weight) + nn.init.ones_(self.to_gamma.bias) + + nn.init.zeros_(self.to_beta.weight) + nn.init.zeros_(self.to_beta.bias) + + def forward(self, x: torch.Tensor, cond: torch.Tensor): + normed = F.normalize(x, dim=-1) * self.scale + + gamma, beta = self.to_gamma(cond), self.to_beta(cond) + gamma = einops.rearrange(gamma, 'B D -> B 1 D') + beta = einops.rearrange(beta, 'B D -> B 1 D') + + return normed * gamma + beta + + +class GEGLU(Module): + """The GeGLU activation implementation""" + + def forward(self, x: torch.Tensor): + x, gate = x.chunk(2, dim=-1) + return F.gelu(gate) * x + + +def get_feedforward_layer(dim: int, mult: int = 4, dropout: float = 0.0): + """ + Return a Feed-Forward layer for the Transformer Layer. + GeGLU activation is used in this FF layer + """ + dim_inner = int(dim * mult * 2 / 3) + return nn.Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim)) + + +class TransformerUNet(NeuralModule): + """ + Implementation of the transformer Encoder Model with U-Net structure used in + VoiceBox and AudioBox + + References: + Le et al., Voicebox: Text-Guided Multilingual Universal Speech Generation at Scale, 2023 + Vyas et al., Audiobox: Unified Audio Generation with Natural Language Prompts, 2023 + """ + + def __init__( + self, + dim: int, + depth: int, + heads: int = 8, + ff_mult: int = 4, + attn_dropout: float = 0.0, + ff_dropout: float = 0.0, + max_positions: int = 6000, + adaptive_rmsnorm: bool = False, + adaptive_rmsnorm_cond_dim_in: Optional[int] = None, + use_unet_skip_connection: bool = True, + skip_connect_scale: Optional[int] = None, + ): + """ + Args: + dim: Embedding dimension + depth: Number of Transformer Encoder Layers + heads: Number of heads in MHA + ff_mult: The multiplier for the feedforward dimension (ff_dim = ff_mult * dim) + attn_dropout: dropout rate for the MHA layer + ff_dropout: droupout rate for the feedforward layer + max_positions: The maximum time length of the input during training and inference + adaptive_rmsnorm: Whether to use AdaptiveRMS layer. + Set to True if the model has a conditional embedding in forward() + adaptive_rms_cond_dim_in: Dimension of the conditional embedding + use_unet_skip_connection: Whether to use U-Net or not + skip_connect_scale: The scale of the U-Net connection. + """ + super().__init__() + if (depth % 2) != 0: + raise ValueError(f"Number of layers {depth} is not divisible by 2!") + self.layers = nn.ModuleList([]) + self.init_alibi(max_positions=max_positions, heads=heads) + + if adaptive_rmsnorm: + rmsnorm_class = partial(AdaptiveRMSNorm, cond_dim=adaptive_rmsnorm_cond_dim_in) + else: + rmsnorm_class = RMSNorm + + if skip_connect_scale is None: + self.skip_connect_scale = 2**-0.5 + else: + self.skip_connect_scale = skip_connect_scale + + for ind in range(depth): + layer = ind + 1 + has_skip = use_unet_skip_connection and layer > (depth // 2) + + self.layers.append( + nn.ModuleList( + [ + nn.Linear(dim * 2, dim) if has_skip else None, + rmsnorm_class(dim=dim), + nn.MultiheadAttention( + embed_dim=dim, + num_heads=heads, + dropout=attn_dropout, + batch_first=True, + ), + rmsnorm_class(dim=dim), + get_feedforward_layer(dim=dim, mult=ff_mult, dropout=ff_dropout), + ] + ) + ) + + self.final_norm = RMSNorm(dim) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tembedding dim: %s', dim) + logging.debug('\tNumber of Layer: %s', depth) + logging.debug('\tfeedforward dim: %s', dim * ff_mult) + logging.debug('\tnumber of heads: %s', heads) + logging.debug('\tDropout rate of MHA: %s', attn_dropout) + logging.debug('\tDropout rate of FF: %s', ff_dropout) + logging.debug('\tnumber of heads: %s', heads) + logging.debug('\tmaximun time length: %s', max_positions) + logging.debug('\tuse AdaptiveRMS: %s', adaptive_rmsnorm) + logging.debug('\tConditional dim: %s', adaptive_rmsnorm_cond_dim_in) + logging.debug('\tUse UNet connection: %s', use_unet_skip_connection) + logging.debug('\tskip connect scale: %s', self.skip_connect_scale) + + def init_alibi( + self, + max_positions: int, + heads: int, + ): + """Initialize the Alibi bias parameters + + References: + - Press et al., Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation, 2021 + """ + + def get_slopes(n): + ratio = 2 ** (-8 / n) + return ratio ** torch.arange(1, n + 1) + + if not math.log2(heads).is_integer(): + logging.warning( + "It is recommend to set number of attention heads to be the power of 2 for the Alibi bias!" + ) + logging.warning(f"Current value of heads: {heads}") + + self.slopes = nn.Parameter(einops.rearrange(get_slopes(heads), "B -> B 1 1")) + + pos_matrix = ( + -1 * torch.abs(torch.arange(max_positions).unsqueeze(0) - torch.arange(max_positions).unsqueeze(1)).float() + ) + pos_matrix = einops.rearrange(pos_matrix, "T1 T2 -> 1 T1 T2") + self.register_buffer('pos_matrix', pos_matrix, persistent=False) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "x": NeuralType(('B', 'T', 'D'), FloatType()), + "key_padding_mask": NeuralType(('B', 'T'), BoolType(), optional=True), + "adaptive_rmsnorm_cond": NeuralType(('B', 'D'), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'T', 'D'), FloatType()), + } + + @typecheck() + def forward(self, x, key_padding_mask: Optional[torch.Tensor] = None, adaptive_rmsnorm_cond=None): + """Forward pass of the model. + + Args: + input: input tensor, shape (B, C, D, T) + key_padding_mask: mask tensor indicating the padding parts, shape (B, T) + adaptive_rmsnorm_cond: conditional input for the model, shape (B, D) + """ + batch_size, seq_len, *_ = x.shape + skip_connects = [] + alibi_bias = self.get_alibi_bias(batch_size=batch_size, seq_len=seq_len) + + rmsnorm_kwargs = dict() + if adaptive_rmsnorm_cond is not None: + rmsnorm_kwargs = dict(cond=adaptive_rmsnorm_cond) + + for skip_combiner, attn_prenorm, attn, ff_prenorm, ff in self.layers: + + if skip_combiner is None: + skip_connects.append(x) + else: + skip_connect = skip_connects.pop() * self.skip_connect_scale + x = torch.cat((x, skip_connect), dim=-1) + x = skip_combiner(x) + + attn_input = attn_prenorm(x, **rmsnorm_kwargs) + if key_padding_mask is not None: + # Since Alibi_bias is a float-type attn_mask, the padding_mask need to be float-type. + float_key_padding_mask = key_padding_mask.float() + float_key_padding_mask = float_key_padding_mask.masked_fill(key_padding_mask, float('-inf')) + else: + float_key_padding_mask = None + + attn_output, _ = attn( + query=attn_input, + key=attn_input, + value=attn_input, + key_padding_mask=float_key_padding_mask, + need_weights=False, + attn_mask=alibi_bias, + ) + x = x + attn_output + + ff_input = ff_prenorm(x, **rmsnorm_kwargs) + x = ff(ff_input) + x + + return self.final_norm(x) + + def get_alibi_bias(self, batch_size: int, seq_len: int): + """ + Return the alibi_bias given batch size and seqence length + """ + pos_matrix = self.pos_matrix[:, :seq_len, :seq_len] + alibi_bias = pos_matrix * self.slopes + alibi_bias = alibi_bias.repeat(batch_size, 1, 1) + + return alibi_bias + + +class SpectrogramTransformerUNet(NeuralModule): + """This model handles complex-valued inputs by stacking real and imaginary components. + Stacked tensor is processed using TransformerUNet and the output is projected to generate real + and imaginary components of the output channels. + + Convolutional Positional Embedding is applied for the input sequence + """ + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + freq_dim: int = 256, + dim: int = 1024, + depth: int = 24, + heads: int = 16, + ff_mult: int = 4, + ff_dropout: float = 0.0, + attn_dropout: float = 0.0, + max_positions: int = 6000, + time_hidden_dim: Optional[int] = None, + conv_pos_embed_kernel_size: int = 31, + conv_pos_embed_groups: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + dim_in = freq_dim * in_channels * 2 + + if time_hidden_dim is None: + time_hidden_dim = dim * 4 + + self.proj_in = nn.Linear(dim_in, dim) + + self.sinu_pos_emb = nn.Sequential(LearnedSinusoidalPosEmb(dim), nn.Linear(dim, time_hidden_dim), nn.SiLU()) + + self.conv_embed = ConvPositionEmbed( + dim=dim, kernel_size=conv_pos_embed_kernel_size, groups=conv_pos_embed_groups + ) + + self.transformerunet = TransformerUNet( + dim=dim, + depth=depth, + heads=heads, + ff_mult=ff_mult, + ff_dropout=ff_dropout, + attn_dropout=attn_dropout, + max_positions=max_positions, + adaptive_rmsnorm=True, + adaptive_rmsnorm_cond_dim_in=time_hidden_dim, + use_unet_skip_connection=True, + ) + + # 2x the frequency dimension as the model operates in the complex-value domain + dim_out = freq_dim * out_channels * 2 + + self.proj_out = nn.Linear(dim, dim_out) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + logging.debug('\tInput frequency dimension: %s', freq_dim) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + "condition": NeuralType(('B',), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @staticmethod + def _get_key_padding_mask(input_length: torch.Tensor, max_length: int): + """ + Return the self_attention masking according to the input length. + 0 indicates the frame is in the valid range, while 1 indicates the frame is a padding frame. + Args: + input_length: shape (B) + max_length (int): The maximum length of the input sequence + + return: + key_padding_mask: shape (B, T) + """ + key_padding_mask = torch.arange(max_length).expand(len(input_length), max_length).to(input_length.device) + key_padding_mask = key_padding_mask >= input_length.unsqueeze(1) + return key_padding_mask + + @typecheck() + def forward(self, input, input_length=None, condition=None): + """Forward pass of the model. + + Args: + input: input tensor, shape (B, C, D, T) + input_length: length of the valid time steps for each example in the batch, shape (B,) + condition: scalar condition (time) for the model, will be embedded using `self.time_embedding` + """ + # Stack real and imaginary components + B, C_in, D, T = input.shape + if C_in != self.in_channels: + raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}') + + input_real_imag = torch.stack([input.real, input.imag], dim=2) + input = einops.rearrange(input_real_imag, 'B C RI D T -> B T (C RI D)') + + x = self.proj_in(input) + key_padding_mask = self._get_key_padding_mask(input_length, max_length=T) + x = self.conv_embed(x, mask=key_padding_mask) + x + + if condition is None: + raise NotImplementedError + + time_emb = self.sinu_pos_emb(condition) + + x = self.transformerunet(x=x, key_padding_mask=key_padding_mask, adaptive_rmsnorm_cond=time_emb) + + output = self.proj_out(x) + output = einops.rearrange(output, "B T (C RI D) -> B C D T RI", C=self.out_channels, RI=2, D=D) + output = torch.view_as_complex(output.contiguous()) + + return output, input_length diff --git a/nemo/collections/audio/parts/utils/callbacks.py b/nemo/collections/audio/parts/utils/callbacks.py new file mode 100644 index 000000000000..093d5a11f419 --- /dev/null +++ b/nemo/collections/audio/parts/utils/callbacks.py @@ -0,0 +1,177 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Type + +import einops +import torch +from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.loggers.wandb import WandbLogger + +from nemo.utils import logging +from nemo.utils.decorators import experimental + +HAVE_WANDB = True +try: + import wandb +except ModuleNotFoundError: + HAVE_WANDB = False + + +def _get_logger(loggers: List[Logger], logger_type: Type[Logger]): + for logger in loggers: + if isinstance(logger, logger_type): + if hasattr(logger, "experiment"): + return logger.experiment + else: + return logger + raise ValueError(f"Could not find {logger_type} logger in {loggers}.") + + +@experimental +class SpeechEnhancementLoggingCallback(Callback): + """ + Callback which can log artifacts (eg. model predictions, graphs) to local disk, Tensorboard, and/or WandB. + + Args: + data_loader: Data to log artifacts for. + output_dir: Optional local directory. If provided, artifacts will be saved in output_dir. + loggers: Optional list of loggers to use if logging to tensorboard or wandb. + log_tensorboard: Whether to log artifacts to tensorboard. + log_wandb: Whether to log artifacts to WandB. + """ + + def __init__( + self, + data_loader, + data_loader_idx: int, + loggers: Optional[List[Logger]] = None, + log_tensorboard: bool = False, + log_wandb: bool = False, + sample_rate: int = 16000, + max_utts: Optional[int] = None, + ): + self.data_loader = data_loader + self.data_loader_idx = data_loader_idx + self.loggers = loggers if loggers else [] + self.log_tensorboard = log_tensorboard + self.log_wandb = log_wandb + self.sample_rate = sample_rate + self.max_utts = max_utts + + if log_tensorboard: + logging.info('Creating tensorboard logger') + self.tensorboard_logger = _get_logger(self.loggers, TensorBoardLogger) + else: + logging.debug('Not using tensorbord logger') + self.tensorboard_logger = None + + if log_wandb: + if not HAVE_WANDB: + raise ValueError("Wandb not installed.") + logging.info('Creating wandb logger') + self.wandb_logger = _get_logger(self.loggers, WandbLogger) + else: + logging.debug('Not using wandb logger') + self.wandb_logger = None + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tlog_tensorboard: %s', self.log_tensorboard) + logging.debug('\tlog_wandb: %s', self.log_wandb) + + def _log_audio(self, audios: torch.Tensor, lengths: torch.Tensor, step: int, label: str = "input"): + + num_utts = audios.size(0) + for audio_idx in range(num_utts): + length = lengths[audio_idx] + if self.tensorboard_logger: + self.tensorboard_logger.add_audio( + tag=f"{label}_{audio_idx}", + snd_tensor=audios[audio_idx, :length], + global_step=step, + sample_rate=self.sample_rate, + ) + + if self.wandb_logger: + wandb_audio = ( + wandb.Audio(audios[audio_idx], sample_rate=self.sample_rate, caption=f"{label}_{audio_idx}"), + ) + self.wandb_logger.log({f"{label}_{audio_idx}": wandb_audio}) + + def on_validation_epoch_end(self, trainer: Trainer, model: LightningModule): + """Log artifacts at the end of an epoch.""" + epoch = 1 + model.current_epoch + output_signal_list = [] + output_length_list = [] + num_examples_uploaded = 0 + + logging.info(f"Logging processed speech for validation dataset {self.data_loader_idx}...") + for batch in self.data_loader: + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch.get('target_signal', input_signal.clone()) + else: + input_signal, input_length, target_signal, _ = batch + + if self.max_utts is None: + num_examples = input_signal.size(0) # batch size + do_upload = True + else: + do_upload = num_examples_uploaded < self.max_utts + num_examples = min(self.max_utts - num_examples_uploaded, input_signal.size(0)) + num_examples_uploaded += num_examples + + if do_upload: + # Only pick the required numbers of speech to the logger + input_signal = input_signal[:num_examples, ...] + target_signal = target_signal[:num_examples, ...] + input_length = input_length[:num_examples] + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + input_signal = input_signal.to(model.device) + input_length = input_length.to(model.device) + + output_signal, output_length = model(input_signal=input_signal, input_length=input_length) + output_signal_list.append(output_signal.to(target_signal.device)) + output_length_list.append(output_length.to(target_signal.device)) + + if len(output_signal_list) == 0: + logging.debug('List are empty, no artifacts to log at epoch %d.', epoch) + return + + output_signals = torch.concat(output_signal_list, dim=0) + output_lengths = torch.concat(output_length_list, dim=0) + if output_signals.size(1) != 1: + logging.error( + f"Currently only supports single-channel audio! Current output shape: {output_signals.shape}" + ) + raise NotImplementedError + + output_signals = einops.rearrange(output_signals, "B 1 T -> B T") + + self._log_audio( + audios=output_signals, + lengths=output_lengths, + step=model.global_step, + label=f"dataloader_{self.data_loader_idx}_processed", + ) diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index 2a4b71a18880..3c5ced5d4018 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -24,7 +24,7 @@ import lhotse.serialization import soundfile from cytoolz import groupby -from lhotse import AudioSource, Recording, SupervisionSegment +from lhotse import AudioSource, MonoCut, Recording, SupervisionSegment from lhotse.audio.backend import LibsndfileBackend from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed @@ -112,11 +112,9 @@ def __iter__(self) -> Generator[Cut, None, None]: audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path)) duration = data.pop("duration") offset = data.pop("offset", None) - recording = self._create_recording(audio_path, duration, data.pop("sampling_rate", None)) - cut = recording.to_cut() - if offset is not None: - cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) - cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}" + cut = self._create_cut( + audio_path=audio_path, offset=offset, duration=duration, sampling_rate=data.pop("sampling_rate", None) + ) # Note that start=0 and not start=offset because supervision's start if relative to the # start of the cut; and cut.start is already set to offset cut.supervisions.append( @@ -140,6 +138,42 @@ def __len__(self) -> int: def __add__(self, other): return LazyIteratorChain(self, other) + def _create_cut( + self, + audio_path: str, + offset: float, + duration: float, + sampling_rate: int | None = None, + ) -> Cut: + if not self.metadata_only: + recording = self._create_recording(audio_path, duration, sampling_rate) + cut = recording.to_cut() + if offset is not None: + cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) + cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}" + else: + # Only metadata requested. + # We'll provide accurate metadata for Cut but inaccurate metadata for Recording to avoid + # incurring IO penalty (note that Lhotse manifests contain more information than + # NeMo manifests, so for actual dataloading we have to fill it using the audio file). + sr = ifnone(sampling_rate, 16000) # fake sampling rate + offset = ifnone(offset, 0.0) + cut = MonoCut( + id=audio_path, + start=offset, + duration=duration, + channel=0, + supervisions=[], + recording=Recording( + id=audio_path, + sources=[AudioSource(type="dummy", channels=[0], source="")], + sampling_rate=sr, + duration=offset + duration, + num_samples=compute_num_samples(offset + duration, sr), + ), + ) + return cut + def _create_recording( self, audio_path: str, @@ -156,15 +190,6 @@ def _create_recording( duration=duration, channel_ids=[0], ) - elif self.metadata_only: - return Recording( - id=audio_path, - sources=[AudioSource(type="file", channels=[0], source=audio_path)], - sampling_rate=-1, - num_samples=-1, - duration=duration, - channel_ids=[0], - ) else: return Recording.from_file(audio_path) diff --git a/nemo/collections/common/parts/preprocessing/cleaners.py b/nemo/collections/common/parts/preprocessing/cleaners.py index 40c80115786a..0697abe8792e 100644 --- a/nemo/collections/common/parts/preprocessing/cleaners.py +++ b/nemo/collections/common/parts/preprocessing/cleaners.py @@ -14,7 +14,6 @@ import re -import inflect from text_unidecode import unidecode from nemo.utils import logging @@ -139,7 +138,14 @@ ] -inflect = inflect.engine() +from functools import cache + + +@cache +def inflect_engine(): + import inflect + + return inflect.engine() def clean_text(string, table, punctuation_to_replace, abbreviation_version=None): @@ -194,11 +200,12 @@ def reset(self): self.currency = None def format_final_number(self, whole_num, decimal): + inflect = inflect_engine() if self.currency: return_string = inflect.number_to_words(whole_num) return_string += " dollar" if whole_num == 1 else " dollars" if decimal: - return_string += " and " + inflect.number_to_words(decimal) + return_string += " and " + inflect_engine().number_to_words(decimal) return_string += " cent" if whole_num == decimal else " cents" self.reset() return return_string @@ -210,11 +217,12 @@ def format_final_number(self, whole_num, decimal): else: # Check if there are non-numbers def convert_to_word(match): - return " " + inflect.number_to_words(match.group(0)) + " " + return " " + inflect_engine().number_to_words(match.group(0)) + " " return re.sub(r'[0-9,]+', convert_to_word, whole_num) def clean(self, match): + inflect = inflect_engine() ws = match.group(2) number = match.group(3) _proceeding_symbol = match.group(7) diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py index 75783815548a..e08f7d710183 100644 --- a/nemo/collections/common/parts/utils.py +++ b/nemo/collections/common/parts/utils.py @@ -159,3 +159,57 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): raise ValueError('Can only mask tensors of shape B x L, B x D x L and B x D1 x D2 x L') return tensor * mask + + +class ClampActivation(nn.Module): + + def __init__(self, min_value: float = -1.0, max_value: float = 1.0): + super().__init__() + self.min_value = min_value + self.max_value = max_value + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.clamp(input, min=self.min_value, max=self.max_value) + + +@torch.jit.script +def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor: + """ + equation for snake activation function: x + (alpha + eps)^-1 * sin(alpha * x)^2 + """ + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + eps).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake(nn.Module): + """ + Snake activation function introduced in 'https://arxiv.org/abs/2006.08195' + """ + + def __init__(self, channels: int): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return snake(x, self.alpha) + + +class HalfSnake(nn.Module): + """ + Activation which applies snake to the first half of input elements and leaky relu to the second half. + """ + + def __init__(self, channels: int): + super().__init__() + self.snake_channels = channels // 2 + self.snake_act = Snake(self.snake_channels) + self.lrelu = torch.nn.LeakyReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + snake_out = self.snake_act(x[:, : self.snake_channels, :]) + lrelu_out = self.lrelu(x[:, self.snake_channels :, :]) + out = torch.cat([snake_out, lrelu_out], dim=1) + return out diff --git a/nemo/collections/common/tokenizers/en_ja_tokenizers.py b/nemo/collections/common/tokenizers/en_ja_tokenizers.py index cf58130834e9..c72ae1853deb 100644 --- a/nemo/collections/common/tokenizers/en_ja_tokenizers.py +++ b/nemo/collections/common/tokenizers/en_ja_tokenizers.py @@ -14,9 +14,6 @@ import re from typing import List -from pangu import spacing -from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer - try: import ipadic import MeCab @@ -36,6 +33,8 @@ class EnJaProcessor: """ def __init__(self, lang_id: str): + from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer + self.lang_id = lang_id self.moses_tokenizer = MosesTokenizer(lang=lang_id) self.moses_detokenizer = MosesDetokenizer(lang=lang_id) @@ -81,6 +80,8 @@ def __init__(self): self.mecab_tokenizer = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati") def detokenize(self, text: List[str]) -> str: + from pangu import spacing + RE_WS_IN_FW = re.compile( r'([\u2018\u2019\u201c\u201d\u2e80-\u312f\u3200-\u32ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff\uff00-\uffef])\s+(?=[\u2018\u2019\u201c\u201d\u2e80-\u312f\u3200-\u32ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff\uff00-\uffef])' ) diff --git a/nemo/collections/common/tokenizers/indic_tokenizers.py b/nemo/collections/common/tokenizers/indic_tokenizers.py index 3b9192c8885b..eaf3aa5c7b64 100644 --- a/nemo/collections/common/tokenizers/indic_tokenizers.py +++ b/nemo/collections/common/tokenizers/indic_tokenizers.py @@ -14,8 +14,6 @@ from typing import List -from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer - class IndicProcessor: """ @@ -26,6 +24,8 @@ class IndicProcessor: def __init__(self, lang_id: str): if lang_id != 'hi': raise NotImplementedError + from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer + self.moses_tokenizer = MosesTokenizer(lang=lang_id) self.moses_detokenizer = MosesDetokenizer(lang=lang_id) self.normalizer = MosesPunctNormalizer(lang=lang_id) diff --git a/nemo/collections/common/tokenizers/moses_tokenizers.py b/nemo/collections/common/tokenizers/moses_tokenizers.py index 27e91e6c5262..717427090dd2 100644 --- a/nemo/collections/common/tokenizers/moses_tokenizers.py +++ b/nemo/collections/common/tokenizers/moses_tokenizers.py @@ -14,8 +14,6 @@ from typing import List -from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer - class MosesProcessor: """ @@ -23,6 +21,8 @@ class MosesProcessor: """ def __init__(self, lang_id: str): + from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer + self.moses_tokenizer = MosesTokenizer(lang=lang_id) self.moses_detokenizer = MosesDetokenizer(lang=lang_id) self.normalizer = MosesPunctNormalizer(lang=lang_id) diff --git a/nemo/collections/llm/README.md b/nemo/collections/llm/README.md new file mode 100644 index 000000000000..3e25f84a0c54 --- /dev/null +++ b/nemo/collections/llm/README.md @@ -0,0 +1,11 @@ +NeMo LLM Collection +=================== + +The NeMo LLM Collection introduces NeMo 2.0, a redesign that enhances the user experience by adopting a more PyTorch Lightning-like approach. This redesign aims to simplify NeMo and make it more modular. + +The following models are currently reimplemented in 2.0 as part of this collection: +- **GPT** +- **LLaMA** +- **Mixtral** + +For detailed tutorials and documentation on NeMo 2.0, refer to the [docs](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo_2.0/index.html). diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 7b2b38e50bc3..52c353ba16d7 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -39,6 +39,9 @@ Llama2Config70B, Llama3Config8B, Llama3Config70B, + Llama31Config8B, + Llama31Config70B, + Llama31Config405B, LlamaConfig, LlamaModel, MaskedTokenLossReduction, @@ -48,6 +51,27 @@ MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel, + Nemotron3Config4B, + Nemotron3Config8B, + Nemotron4Config15B, + Nemotron4Config22B, + Nemotron4Config340B, + NemotronConfig, + NemotronModel, + Qwen2Config, + Qwen2Config1P5B, + Qwen2Config7B, + Qwen2Config72B, + Qwen2Config500M, + Qwen2Model, + Starcoder2Config, + Starcoder2Config3B, + Starcoder2Config7B, + Starcoder2Config15B, + Starcoder2Model, + StarcoderConfig, + StarcoderConfig15B, + StarcoderModel, gpt_data_step, gpt_forward_step, ) @@ -73,12 +97,25 @@ "MixtralConfig8x7B", "MixtralConfig8x22B", "MixtralModel", + "Starcoder2Config15B", + "Starcoder2Config", + "Starcoder2Model", + "NemotronModel", + "Nemotron3Config4B", + "Nemotron3Config8B", + "Nemotron4Config15B", + "Nemotron4Config22B", + "Nemotron4Config340B", + "NemotronConfig", "LlamaConfig", "Llama2Config7B", "Llama2Config13B", "Llama2Config70B", "Llama3Config8B", "Llama3Config70B", + "Llama31Config8B", + "Llama31Config70B", + "Llama31Config405B", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", @@ -97,6 +134,12 @@ "ChatGLM2Config6B", "ChatGLM3Config6B", "ChatGLMModel", + "Qwen2Model", + "Qwen2Config7B", + "Qwen2Config", + "Qwen2Config500M", + "Qwen2Config1P5B", + "Qwen2Config72B", "PreTrainingDataModule", "FineTuningDataModule", "SquadDataModule", diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 46d94d26b03b..8bead26e653e 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -8,25 +8,10 @@ from typing_extensions import Annotated from nemo.collections.llm.utils import Config, task -from nemo.deploy import DeployPyTriton from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform from nemo.utils import logging -trt_llm_supported = True -try: - from nemo.export.tensorrt_llm import TensorRTLLM -except ImportError as error: - logging.warning(f"TensorRTLLM could not be imported from nemo.export: {error}") - trt_llm_supported = False - -uvicorn_supported = True -try: - import uvicorn -except ImportError as error: - logging.warning(f"uvicorn could not be imported: {error}") - uvicorn_supported = False - TokenizerType = Any @@ -253,6 +238,8 @@ def get_trtllm_deployable( max_batch_size, dtype, ): + from nemo.export.tensorrt_llm import TensorRTLLM + if triton_model_repository is None: trt_llm_path = "/tmp/trt_llm_model_dir/" Path(trt_llm_path).mkdir(parents=True, exist_ok=True) @@ -274,8 +261,6 @@ def get_trtllm_deployable( if nemo_checkpoint is not None and model_type is None: raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") - if not trt_llm_supported: - raise ValueError("TensorRT-LLM engine is not supported in this environment.") trt_llm_exporter = TensorRTLLM( model_dir=trt_llm_path, load_model=(nemo_checkpoint is None), @@ -334,6 +319,8 @@ def deploy( rest_service_port: int = 8000, openai_format_response: bool = False, ): + from nemo.deploy import DeployPyTriton + if start_rest_service: if triton_port == rest_service_port: logging.error("REST service port and Triton server port cannot use the same port.") @@ -370,6 +357,13 @@ def deploy( logging.error("Error message has occurred during deploy function. Error message: " + str(error)) return + uvicorn_supported = True + try: + import uvicorn + except ImportError as error: + logging.warning(f"uvicorn could not be imported: {error}") + uvicorn_supported = False + try: logging.info("Model serving on Triton is will be started.") if start_rest_service and uvicorn_supported: diff --git a/nemo/collections/llm/fn/activation.py b/nemo/collections/llm/fn/activation.py index 89b5ba93f0f6..fb638ee31f86 100644 --- a/nemo/collections/llm/fn/activation.py +++ b/nemo/collections/llm/fn/activation.py @@ -9,3 +9,9 @@ def gelu_impl(x): def openai_gelu(x): return gelu_impl(x) + + +@torch.jit.script +def squared_relu(x): + """Squared ReLU activation function.""" + return torch.pow(torch.nn.functional.relu(x), 2) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index e2d940e02d32..7de5d5b5b5f4 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -27,6 +27,9 @@ Llama2Config70B, Llama3Config8B, Llama3Config70B, + Llama31Config8B, + Llama31Config70B, + Llama31Config405B, LlamaConfig, LlamaModel, ) @@ -37,6 +40,31 @@ MixtralConfig8x22B, MixtralModel, ) +from nemo.collections.llm.gpt.model.nemotron import ( + Nemotron3Config4B, + Nemotron3Config8B, + Nemotron4Config15B, + Nemotron4Config22B, + Nemotron4Config340B, + NemotronConfig, + NemotronModel, +) +from nemo.collections.llm.gpt.model.qwen2 import ( + Qwen2Config, + Qwen2Config1P5B, + Qwen2Config7B, + Qwen2Config72B, + Qwen2Config500M, + Qwen2Model, +) +from nemo.collections.llm.gpt.model.starcoder import StarcoderConfig, StarcoderConfig15B, StarcoderModel +from nemo.collections.llm.gpt.model.starcoder2 import ( + Starcoder2Config, + Starcoder2Config3B, + Starcoder2Config7B, + Starcoder2Config15B, + Starcoder2Model, +) __all__ = [ "GPTConfig", @@ -45,13 +73,32 @@ "MistralModel", "MixtralConfig8x3B", "MixtralConfig8x7B", + "MixtralConfig8x22B", "MixtralModel", + "Starcoder2Config", + "Starcoder2Model", + "Starcoder2Config15B", + "Starcoder2Config7B", + "Starcoder2Config3B", + "StarcoderConfig", + "StarcoderConfig15B", + "StarcoderModel", "LlamaConfig", "Llama2Config7B", "Llama2Config13B", "Llama2Config70B", "Llama3Config8B", "Llama3Config70B", + "Llama31Config8B", + "Llama31Config70B", + "Llama31Config405B", + "NemotronConfig", + "Nemotron3Config4B", + "Nemotron3Config8B", + "Nemotron4Config15B", + "Nemotron4Config22B", + "Nemotron4Config340B", + "NemotronModel", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", @@ -70,6 +117,12 @@ "ChatGLM2Config6B", "ChatGLM3Config6B", "ChatGLMModel", + "Qwen2Config", + "Qwen2Config500M", + "Qwen2Config1P5B", + "Qwen2Config7B", + "Qwen2Config72B", + "Qwen2Model", "MaskedTokenLossReduction", "gpt_data_step", "gpt_forward_step", diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 2badfa2b1915..c108415a085e 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -13,6 +13,7 @@ from nemo.lightning import get_vocab_size, io from nemo.lightning.megatron_parallel import MaskedTokenLossReduction from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule +from nemo.utils import logging HAVE_TE = True try: @@ -131,10 +132,19 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel": if not isinstance(transformer_layer_spec, ModuleSpec): transformer_layer_spec = transformer_layer_spec(self) + if hasattr(self, 'vocab_size'): + vocab_size = self.vocab_size + logging.info( + f"Use preset vocab_size: {vocab_size}, original vocab_size: {tokenizer.vocab_size}, dummy tokens:" + f" {vocab_size - tokenizer.vocab_size}." + ) + else: + vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by) + return MCoreGPTModel( self, transformer_layer_spec=transformer_layer_spec, - vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), + vocab_size=vocab_size, max_sequence_length=self.seq_length, fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, parallel_output=self.parallel_output, diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index 425170c07707..4f7dd4d37a90 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Annotated, Callable, Optional @@ -9,6 +10,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown +from nemo.utils import logging if TYPE_CHECKING: from transformers import LlamaConfig as HFLlamaConfig @@ -66,13 +68,13 @@ class Llama3Config(GPTConfig): num_query_groups: int = 8 hidden_dropout: float = 0.0 attention_dropout: float = 0.0 - normalization = "RMSNorm" + normalization: str = "RMSNorm" init_method_std: float = 0.01 layernorm_epsilon: float = 1.0e-05 add_bias_linear: bool = False activation_func: Callable = F.silu gated_linear_unit: bool = True - apply_query_key_layer_scaling: bool = True + apply_query_key_layer_scaling: bool = False # Fusions bias_activation_fusion: bool = True masked_softmax_fusion: bool = True @@ -80,10 +82,31 @@ class Llama3Config(GPTConfig): bias_dropout_fusion: bool = True apply_rope_fusion: bool = True share_embeddings_and_output_weights: bool = False - position_embedding_type = "rope" + position_embedding_type: str = "rope" rotary_percent: float = 1.0 +@dataclass +class Llama31Config(Llama3Config): + scale_factor: int = 8 + low_freq_factor: int = 1 + high_freq_factor: int = 4 + old_context_len: int = 8192 + init_method_std: float = 0.02 + + def configure_model(self, tokenizer) -> "MCoreGPTModel": + model = super().configure_model(tokenizer) + # Apply rope scaling for Llama3.1 model + model.rotary_pos_emb.inv_freq = apply_rope_scaling( + model.rotary_pos_emb.inv_freq, + factor=self.scale_factor, + low_freq_factor=self.low_freq_factor, + high_freq_factor=self.high_freq_factor, + old_context_len=self.old_context_len, + ) + return model + + @dataclass class Llama3Config8B(Llama3Config): rotary_base: int = 500_000 @@ -106,6 +129,38 @@ class Llama3Config70B(Llama3Config): make_vocab_size_divisible_by: int = 128 +@dataclass +class Llama31Config8B(Llama31Config): + rotary_base: int = 500_000 + seq_length: int = 131072 + num_layers: int = 32 + hidden_size: int = 4096 + ffn_hidden_size: int = 14336 + num_attention_heads: int = 32 + + +@dataclass +class Llama31Config70B(Llama31Config): + rotary_base: int = 500_000 + seq_length: int = 131072 + num_layers: int = 80 + hidden_size: int = 8192 + ffn_hidden_size: int = 28672 + num_attention_heads: int = 64 + make_vocab_size_divisible_by: int = 128 + + +@dataclass +class Llama31Config405B(Llama31Config): + rotary_base: int = 500_000 + seq_length: int = 131072 + num_layers: int = 126 + hidden_size: int = 16384 + ffn_hidden_size: int = 53248 + num_attention_heads: int = 128 + make_vocab_size_divisible_by: int = 128 + + @dataclass class CodeLlamaConfig7B(Llama2Config7B): rotary_base: int = 1_000_000 @@ -365,6 +420,33 @@ def _export_linear_fc1(linear_fc1): return gate_proj, up_proj +def apply_rope_scaling( + inv_freq, + factor: int = 8, + low_freq_factor: int = 1, + high_freq_factor: int = 4, + old_context_len: int = 8192, +): + logging.info( + f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, old_context_len={old_context_len}." + ) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama + + __all__ = [ "LlamaConfig", "Llama2Config7B", @@ -372,6 +454,9 @@ def _export_linear_fc1(linear_fc1): "Llama2Config70B", "Llama3Config8B", "Llama3Config70B", + "Llama31Config8B", + "Llama31Config70B", + "Llama31Config405B", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py new file mode 100644 index 000000000000..d946e5f48cce --- /dev/null +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -0,0 +1,345 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch +from torch import nn + +from nemo.collections.llm.fn.activation import squared_relu +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import NemotronConfig as HFNemotronConfig + from transformers import NemotronForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +@dataclass +class NemotronConfig(GPTConfig): + # configs that are common across model sizes + normalization: str = "LayerNorm" + activation_func: Callable = squared_relu + position_embedding_type: str = "rope" + share_embeddings_and_output_weights: bool = False + add_bias_linear: bool = False + + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + apply_query_key_layer_scaling: bool = True + rotary_percent: float = 0.5 + masked_softmax_fusion: bool = True + persist_layer_norm: bool = True + bias_dropout_add_fusion: bool = False + layernorm_zero_centered_gamma: bool = True + + # Nemotron3Config4B as default configs + num_layers: int = 32 + seq_length: int = 4096 + hidden_size: int = 3072 + ffn_hidden_size: int = 9216 + num_attention_heads: int = 24 + num_query_groups: Optional[int] = 8 + kv_channels: Optional[int] = 128 + init_method_std: float = 0.0134 + + +@dataclass +class Nemotron3Config4B(NemotronConfig): + num_layers: int = 32 + seq_length: int = 4096 + hidden_size: int = 3072 + ffn_hidden_size: int = 9216 + num_attention_heads: int = 24 + num_query_groups: int = 8 + kv_channels: Optional[int] = 128 + init_method_std: float = 0.0134 + + +@dataclass +class Nemotron3Config8B(NemotronConfig): + num_layers: int = 32 + seq_length: int = 4096 + hidden_size: int = 4096 + ffn_hidden_size: int = 16384 + num_attention_heads: int = 32 + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + init_method_std: float = 0.010 + + +@dataclass +class Nemotron4Config15B(NemotronConfig): + num_layers: int = 32 + seq_length: int = 4096 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + num_query_groups: Optional[int] = 8 + kv_channels: Optional[int] = None + init_method_std: float = 0.0134 + + +@dataclass +class Nemotron4Config22B(NemotronConfig): + num_layers: int = 40 + seq_length: int = 4096 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + init_method_std: float = 0.008 + + +@dataclass +class Nemotron4Config340B(NemotronConfig): + num_layers: int = 96 + seq_length: int = 4096 + hidden_size: int = 18432 + ffn_hidden_size: int = 73728 + num_attention_heads: int = 96 + num_query_groups: Optional[int] = 8 + kv_channels: Optional[int] = None + init_method_std: float = 0.0063 + + +class NemotronModel(GPTModel): + def __init__( + self, + config: Annotated[Optional[NemotronConfig], Config[NemotronConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(config or NemotronConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + + +@io.model_importer(NemotronModel, "hf") +class HFNemotronImporter(io.ModelConnector["NemotronForCausalLM", NemotronModel]): + def init(self) -> NemotronModel: + return NemotronModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import NemotronForCausalLM + + source = NemotronForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Nemotron model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.up_proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.input_layernorm.bias": "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.layers.*.post_attention_layernorm.bias": "decoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "model.norm.weight": "decoder.final_layernorm.weight", + "model.norm.bias": "decoder.final_layernorm.bias", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv]) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> NemotronConfig: + from transformers import NemotronConfig as HFNemotronConfig + + source = HFNemotronConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = NemotronConfig( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + seq_length=source.max_position_embeddings, + layernorm_epsilon=source.norm_eps, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + rotary_percent=source.partial_rotary_factor, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(NemotronModel, "hf") +class HFNemotronExporter(io.ModelConnector[NemotronModel, "NemotronForCausalLM"]): + def init(self) -> "NemotronForCausalLM": + return NemotronForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc1.weight": "model.layers.*.mlp.up_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.layers.*.input_layernorm.bias", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.layers.*.post_attention_layernorm.bias", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.final_layernorm.bias": "model.norm.bias", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv]) + + @property + def tokenizer(self): + return io.load_context(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFNemotronConfig": + from transformers import NemotronConfig as HFNemotronConfig + + source: NemotronConfig = io.load_context(str(self)).model.config + + return HFNemotronConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + head_dim=( + source.kv_channels + if source.kv_channels is not None + else source.hidden_size // source.num_attention_heads + ), + tie_word_embeddings=source.share_embeddings_and_output_weights, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + partial_rotary_factor=source.rotary_percent, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +__all__ = [ + "NemotronConfig", + "Nemotron3Config4B", + "Nemotron3Config8B", + "Nemotron4Config15B", + "Nemotron4Config22B", + "Nemotron4Config340B", + "NemotronModel", +] diff --git a/nemo/collections/llm/gpt/model/qwen2.py b/nemo/collections/llm/gpt/model/qwen2.py new file mode 100644 index 000000000000..eb67dd9d4f0d --- /dev/null +++ b/nemo/collections/llm/gpt/model/qwen2.py @@ -0,0 +1,392 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import AutoModelForCausalLM + from transformers import Qwen2Config as HFQwen2Config + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +@dataclass +class Qwen2Config(GPTConfig): + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = True + seq_length: int = 4096 + init_method_std: int = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + vocab_size: int = 151936 + share_embeddings_and_output_weights: Optional[bool] = False + layernorm_epsilon: float = 1e-6 + rotary_base: float = 1000000.0 + position_embedding_type: str = "rope" + apply_query_key_layer_scaling: bool = True + + +@dataclass +class Qwen2Config500M(Qwen2Config): + num_layers: int = 24 + hidden_size: int = 896 + num_attention_heads: int = 14 + num_query_groups: int = 2 + ffn_hidden_size: int = 4864 + + +@dataclass +class Qwen2Config1P5B(Qwen2Config): + num_layers: int = 28 + hidden_size: int = 1536 + num_attention_heads: int = 12 + num_query_groups: int = 2 + ffn_hidden_size: int = 8960 + + +@dataclass +class Qwen2Config7B(Qwen2Config): + num_layers: int = 28 + hidden_size: int = 3584 + num_attention_heads: int = 28 + num_query_groups: int = 4 + ffn_hidden_size: int = 18944 + vocab_size: int = 152064 + + +@dataclass +class Qwen2Config72B(Qwen2Config): + num_layers: int = 80 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 29568 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-5 + vocab_size: int = 152064 + + +class Qwen2Model(GPTModel): + def __init__( + self, + config: Annotated[Optional[Qwen2Config], Config[Qwen2Config]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(config or Qwen2Config(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + + +@io.model_importer(Qwen2Model, "hf") +class HFQwen2Importer(io.ModelConnector["AutoModelForCausalLM", Qwen2Model]): + def init(self) -> Qwen2Model: + return Qwen2Model(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import AutoModelForCausalLM + + source = AutoModelForCausalLM.from_pretrained(str(self), trust_remote_code=True) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Qwen model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms( + source, target, mapping=mapping, transforms=[_import_qkv, _import_qkv_bias, _import_linear_fc1] + ) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self), trust_remote_code=True) + + @property + def config(self) -> Qwen2Config: + from transformers import AutoConfig as HFAutoConfig + + source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True) + + output = Qwen2Config( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + num_query_groups=source.num_key_value_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=128, + rotary_base=source.rope_theta, + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(Qwen2Model, "hf") +class HFQwen2Exporter(io.ModelConnector[Qwen2Model, "AutoModelForCausalLM"]): + def init(self) -> "AutoModelForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config, trust_remote_code=True) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms( + source, target, mapping=mapping, transforms=[_export_qkv, _export_qkv_bias, _export_linear_fc1] + ) + + @property + def tokenizer(self): + return io.load_context(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFQwen2Config": + from transformers import Qwen2Config as HFQwen2Config + + source: Qwen2Config = io.load_context(str(self)).model.config + + return HFQwen2Config( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + rms_norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + vocab_size=getattr(source, 'vocab_size', self.tokenizer.vocab_size), + sliding_window=source.seq_length, + tie_word_embeddings=False, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.bias", + "model.layers.*.self_attn.k_proj.bias", + "model.layers.*.self_attn.v_proj.bias", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.bias", +) +def _import_qkv_bias(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + new_q_tensor_shape = (head_num, head_size) + new_kv_tensor_shape = (num_query_groups, head_size) + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_bias = torch.empty((0, head_size)) + for i in range(num_query_groups): + qkv_bias = torch.cat((qkv_bias, q[i * heads_per_group : (i + 1) * heads_per_group, :])) + qkv_bias = torch.cat((qkv_bias, k[i : i + 1, :])) + qkv_bias = torch.cat((qkv_bias, v[i : i + 1, :])) + qkv_bias = qkv_bias.reshape( + [ + head_size * (head_num + 2 * num_query_groups), + ] + ) + return qkv_bias + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.bias", + target_key=( + "model.layers.*.self_attn.q_proj.bias", + "model.layers.*.self_attn.k_proj.bias", + "model.layers.*.self_attn.v_proj.bias", + ), +) +def _export_qkv_bias(ctx: io.TransformCTX, qkv_bias): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_bias = qkv_bias[q_slice].reshape(-1).cpu() + k_bias = qkv_bias[k_slice].reshape(-1).cpu() + v_bias = qkv_bias[v_slice].reshape(-1).cpu() + + return q_bias, k_bias, v_bias + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0).float() + + +@io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), +) +def _export_linear_fc1(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj + + +__all__ = [ + "Qwen2Config", + "Qwen2Config500M", + "Qwen2Config1P5B", + "Qwen2Config7B", + "Qwen2Config72B", + "Qwen2Model", +] diff --git a/nemo/collections/llm/gpt/model/starcoder.py b/nemo/collections/llm/gpt/model/starcoder.py new file mode 100644 index 000000000000..e99b707964fe --- /dev/null +++ b/nemo/collections/llm/gpt/model/starcoder.py @@ -0,0 +1,206 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch.nn.functional as F +from torch import nn + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import GPTBigCodeConfig as HFStarcoderConfig + from transformers import GPTBigCodeForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +@dataclass +class StarcoderConfig(GPTConfig): + # configs that are common across model sizes + normalization: str = "LayerNorm" + activation_func: Callable = F.gelu + add_bias_linear: bool = True + seq_length: int = 8192 + position_embedding_type: str = "learned_absolute" + hidden_dropout: float = 0.2 + attention_dropout: float = 0.2 + init_method_std: float = 0.01 + layernorm_epsilon: float = 1e-5 + share_embeddings_and_output_weights: bool = False + kv_channels: int = None + num_query_groups: int = 1 + attention_softmax_in_fp32: bool = True + bias_activation_fusion: bool = True + bias_dropout_fusion: bool = True + + +@dataclass +class StarcoderConfig15B(StarcoderConfig): + num_layers: int = 40 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + init_method_std: float = 0.02 + + +class StarcoderModel(GPTModel): + def __init__( + self, + config: Annotated[Optional[StarcoderConfig], Config[StarcoderConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__( + config or StarcoderConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform + ) + + +@io.model_importer(StarcoderModel, "hf") +class HFStarcoderImporter(io.ModelConnector["GPTBigCodeForCausalLM", StarcoderModel]): + def init(self) -> StarcoderModel: + return StarcoderModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import GPTBigCodeForCausalLM + + source = GPTBigCodeForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Starcoder model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "transformer.wte.weight": "embedding.word_embeddings.weight", + "transformer.wpe.weight": "embedding.position_embeddings.weight", + "transformer.h.*.attn.c_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "transformer.h.*.attn.c_proj.bias": "decoder.layers.*.self_attention.linear_proj.bias", + "transformer.h.*.attn.c_attn.weight": "decoder.layers.*.self_attention.linear_qkv.weight", + "transformer.h.*.attn.c_attn.bias": "decoder.layers.*.self_attention.linear_qkv.bias", + "transformer.h.*.mlp.c_fc.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "transformer.h.*.mlp.c_fc.bias": "decoder.layers.*.mlp.linear_fc1.bias", + "transformer.h.*.mlp.c_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "transformer.h.*.mlp.c_proj.bias": "decoder.layers.*.mlp.linear_fc2.bias", + "transformer.h.*.ln_1.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "transformer.h.*.ln_1.bias": "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "transformer.h.*.ln_2.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "transformer.h.*.ln_2.bias": "decoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "transformer.ln_f.weight": "decoder.final_layernorm.weight", + "transformer.ln_f.bias": "decoder.final_layernorm.bias", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping) + + @property + def tokenizer(self) -> "AutoTokenizer": + return AutoTokenizer(str(self)) + + @property + def config(self) -> StarcoderConfig: + from transformers import GPTBigCodeConfig as HFStarcoderConfig + + source = HFStarcoderConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = StarcoderConfig( + num_layers=source.n_layer, + hidden_size=source.n_embd, + ffn_hidden_size=source.n_inner, + num_attention_heads=source.n_head, + init_method_std=source.initializer_range, + seq_length=source.n_positions, + layernorm_epsilon=source.layer_norm_epsilon, + num_query_groups=1, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(StarcoderModel, "hf") +class HFStarcoderExporter(io.ModelConnector[StarcoderModel, "GPTBigCodeForCausalLM"]): + def init(self) -> "GPTBigCodeForCausalLM": + from transformers import GPTBigCodeForCausalLM + + return GPTBigCodeForCausalLM._from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "transformer.wte.weight", + "embedding.position_embeddings.weight": "transformer.wpe.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "transformer.h.*.attn.c_proj.weight", + "decoder.layers.*.self_attention.linear_proj.bias": "transformer.h.*.attn.c_proj.bias", + "decoder.layers.*.self_attention.linear_qkv.weight": "transformer.h.*.attn.c_attn.weight", + "decoder.layers.*.self_attention.linear_qkv.bias": "transformer.h.*.attn.c_attn.bias", + "decoder.layers.*.mlp.linear_fc1.weight": "transformer.h.*.mlp.c_fc.weight", + "decoder.layers.*.mlp.linear_fc1.bias": "transformer.h.*.mlp.c_fc.bias", + "decoder.layers.*.mlp.linear_fc2.weight": "transformer.h.*.mlp.c_proj.weight", + "decoder.layers.*.mlp.linear_fc2.bias": "transformer.h.*.mlp.c_proj.bias", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "transformer.h.*.ln_1.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "transformer.h.*.ln_1.bias", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "transformer.h.*.ln_2.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "transformer.h.*.ln_2.bias", + "decoder.final_layernorm.weight": "transformer.ln_f.weight", + "decoder.final_layernorm.bias": "transformer.ln_f.bias", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping) + + @property + def tokenizer(self): + return io.load_context(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFStarcoderConfig": + from transformers import sGPTBigCodeConfig as HFStarcoderConfig + + source: StarcoderConfig = io.load_context(str(self)).model.config + + return HFStarcoderConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + head_dim=( + source.kv_channels + if source.kv_channels is not None + else source.hidden_size // source.num_attention_heads + ), + tie_word_embeddings=source.share_embeddings_and_output_weights, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + vocab_size=self.tokenizer.vocab_size, + ) diff --git a/nemo/collections/llm/gpt/model/starcoder2.py b/nemo/collections/llm/gpt/model/starcoder2.py new file mode 100644 index 000000000000..e53f1bde7012 --- /dev/null +++ b/nemo/collections/llm/gpt/model/starcoder2.py @@ -0,0 +1,383 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, List, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import Starcoder2Config as HFStarcoder2Config + from transformers import Starcoder2ForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +@dataclass +class Starcoder2Config(GPTConfig): + # configs that are common across model sizes + normalization: str = "LayerNorm" + activation_func: Callable = F.gelu + add_bias_linear: bool = True + seq_length: int = 16384 + position_embedding_type: str = "rope" + rotary_percent: float = 1.0 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + init_method_std: float = 0.01 + share_embeddings_and_output_weights: bool = False + kv_channels: int = None + num_query_groups: int = None + window_size: Optional[List[int]] = None + apply_query_key_layer_scaling: bool = True + attention_softmax_in_fp32: bool = True + bias_activation_fusion: bool = True + bias_dropout_fusion: bool = True + layernorm_epsilon: float = 1e-5 + + +@dataclass +class Starcoder2Config3B(Starcoder2Config): + num_layers: int = 30 + hidden_size: int = 3072 + ffn_hidden_size: int = 12288 + num_query_groups: int = 2 + num_attention_heads: int = 24 + init_method_std: float = 0.018042 + rotary_base: float = 999999.4420358813 + + +@dataclass +class Starcoder2Config7B(Starcoder2Config): + num_layers: int = 32 + hidden_size: int = 4608 + ffn_hidden_size: int = 18432 + num_query_groups: int = 4 + num_attention_heads: int = 36 + init_method_std: float = 0.018042 + rotary_base: float = 1_000_000 + + +@dataclass +class Starcoder2Config15B(Starcoder2Config): + num_layers: int = 40 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_query_groups: int = 4 + num_attention_heads: int = 48 + init_method_std: float = 0.01275 + rotary_base: float = 100_000 + + +class Starcoder2Model(GPTModel): + def __init__( + self, + config: Annotated[Optional[Starcoder2Config], Config[Starcoder2Config]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__( + config or Starcoder2Config(), optim=optim, tokenizer=tokenizer, model_transform=model_transform + ) + + +@io.model_importer(Starcoder2Model, "hf") +class HFStarcoder2Importer(io.ModelConnector["Starcoder2ForCausalLM", Starcoder2Model]): + def init(self) -> Starcoder2Model: + return Starcoder2Model(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import Starcoder2ForCausalLM + + source = Starcoder2ForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Starcoder2 model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.self_attn.o_proj.bias": "decoder.layers.*.self_attention.linear_proj.bias", + "model.layers.*.mlp.c_fc.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "model.layers.*.mlp.c_fc.bias": "decoder.layers.*.mlp.linear_fc1.bias", + "model.layers.*.mlp.c_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.mlp.c_proj.bias": "decoder.layers.*.mlp.linear_fc2.bias", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.input_layernorm.bias": "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.layers.*.post_attention_layernorm.bias": "decoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "model.norm.weight": "decoder.final_layernorm.weight", + "model.norm.bias": "decoder.final_layernorm.bias", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv_bias, _import_qkv_weight]) + + @property + def tokenizer(self) -> "AutoTokenizer": + return AutoTokenizer(str(self)) + + @property + def config(self) -> Starcoder2Config: + from transformers import Starcoder2Config as HFStarcoder2Config + + source = HFStarcoder2Config.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = Starcoder2Config( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + seq_length=source.max_position_embeddings, + layernorm_epsilon=source.norm_epsilon, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(Starcoder2Model, "hf") +class HFStarcoder2Exporter(io.ModelConnector[Starcoder2Model, "Starcoder2ForCausalLM"]): + def init(self) -> "Starcoder2ForCausalLM": + from transformers import Starcoder2ForCausalLM + + return Starcoder2ForCausalLM._from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.self_attention.linear_proj.bias": "model.layers.*.self_attn.o_proj.bias", + "decoder.layers.*.mlp.linear_fc1.weight": "model.layers.*.mlp.c_fc.weight", + "decoder.layers.*.mlp.linear_fc1.bias": "model.layers.*.mlp.c_fc.bias", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.c_proj.weight", + "decoder.layers.*.mlp.linear_fc2.bias": "model.layers.*.mlp.c_proj.bias", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.layers.*.input_layernorm.bias", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.layers.*.post_attention_layernorm.bias", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.final_layernorm.bias": "model.norm.bias", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv_weight, _export_qkv_bias]) + + @property + def tokenizer(self): + return io.load_context(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFStarcoder2Config": + from transformers import Starcoder2Config as HFStarcoder2Config + + source: Starcoder2Config = io.load_context(str(self)).model.config + + return HFStarcoder2Config( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + head_dim=( + source.kv_channels + if source.kv_channels is not None + else source.hidden_size // source.num_attention_heads + ), + tie_word_embeddings=source.share_embeddings_and_output_weights, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + partial_rotary_factor=source.rotary_percent, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv_weight(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.bias", + "model.layers.*.self_attn.k_proj.bias", + "model.layers.*.self_attn.v_proj.bias", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.bias", +) +def _import_qkv_bias(ctx: io.TransformCTX, qb, kb, vb): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + new_q_bias_tensor_shape = (head_num, head_size) + new_kv_bias_tensor_shape = (num_query_groups, head_size) + + qb = qb.view(*new_q_bias_tensor_shape) + kb = kb.view(*new_kv_bias_tensor_shape) + vb = vb.view(*new_kv_bias_tensor_shape) + + qkv_bias_l = [] + for i in range(num_query_groups): + qkv_bias_l.append(qb[i * heads_per_group : (i + 1) * heads_per_group, :]) + qkv_bias_l.append(kb[i : i + 1, :]) + qkv_bias_l.append(vb[i : i + 1, :]) + + qkv_bias = torch.cat(qkv_bias_l) + qkv_bias = qkv_bias.reshape([head_size * (head_num + 2 * num_query_groups)]) + + return qkv_bias + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv_weight(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.bias", + target_key=( + "model.layers.*.self_attn.q_proj.bias", + "model.layers.*.self_attn.k_proj.bias", + "model.layers.*.self_attn.v_proj.bias", + ), +) +def _export_qkv_bias(ctx: io.TransformCTX, qkv_bias): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_bias = qkv_bias[q_slice].reshape(-1).cpu() + k_bias = qkv_bias[k_slice].reshape(-1).cpu() + v_bias = qkv_bias[v_slice].reshape(-1).cpu() + + return q_bias, k_bias, v_bias diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index d9fb5cc61f38..950ca6db7ac6 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -1,4 +1,19 @@ -from nemo.collections.llm.recipes import llama3_8b, llama3_8b_16k, llama3_8b_64k, llama3_70b, mistral +from nemo.collections.llm.recipes import ( + llama3_8b, + llama3_8b_16k, + llama3_8b_64k, + llama3_70b, + llama3_70b_16k, + llama3_70b_64k, + mistral, + mixtral_8x3b, + mixtral_8x3b_16k, + mixtral_8x3b_64k, + mixtral_8x7b, + mixtral_8x7b_16k, + mixtral_8x7b_64k, + mixtral_8x22b, +) from nemo.collections.llm.recipes.log.default import default_log, default_resume from nemo.collections.llm.recipes.optim import adam @@ -7,7 +22,16 @@ "llama3_8b_16k", "llama3_8b_64k", "llama3_70b", + "llama3_70b_16k", + "llama3_70b_64k", "mistral", + "mixtral_8x3b", + "mixtral_8x3b_16k", + "mixtral_8x3b_64k", + "mixtral_8x7b", + "mixtral_8x7b_16k", + "mixtral_8x7b_64k", + "mixtral_8x22b", "adam", "default_log", "default_resume", diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index 4b99aef74a30..c784989ac370 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -13,6 +13,7 @@ from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin from nemo.collections.llm.utils import Config, Partial from nemo.utils.exp_manager import TimingCallback @@ -47,11 +48,6 @@ def trainer( ckpt_include_optimizer=True, ckpt_async_save=True, ckpt_parallel_load=True, - ddp=Config( - DistributedDataParallelConfig, - check_for_nan_in_grad=True, - grad_reduce_in_fp32=True, - ), ) trainer = Config( @@ -66,7 +62,7 @@ def trainer( log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, - plugins=Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), + plugins=bf16_mixed_plugin(), strategy=strategy, use_distributed_sampler=False, val_check_interval=2000, diff --git a/nemo/collections/llm/recipes/llama3_70b_16k.py b/nemo/collections/llm/recipes/llama3_70b_16k.py new file mode 100644 index 000000000000..8829aa6b407b --- /dev/null +++ b/nemo/collections/llm/recipes/llama3_70b_16k.py @@ -0,0 +1,59 @@ +from typing import Callable + +import torch + +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.recipes import llama3_70b +from nemo.collections.llm.utils import Partial + +NAME = "llama3_70b_16k" + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + recipe = llama3_70b.pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn + ) + + trainer = llama3_70b.trainer( + tensor_parallelism=2, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=5, + context_parallelism=2, + sequence_parallelism=True, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = llama3_70b.model() + model.config.seq_length = 16384 + + recipe.model = model + recipe.trainer = trainer + + return recipe + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = llama3_70b.finetune_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node + ) + + trainer = llama3_70b.trainer( + tensor_parallelism=2, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=5, + context_parallelism=2, + sequence_parallelism=True, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = llama3_70b.model() + model.config.seq_length = 16384 + + recipe.model = model + recipe.trainer = trainer + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b_64k.py b/nemo/collections/llm/recipes/llama3_70b_64k.py new file mode 100644 index 000000000000..33f46f767a4d --- /dev/null +++ b/nemo/collections/llm/recipes/llama3_70b_64k.py @@ -0,0 +1,59 @@ +from typing import Callable + +import torch + +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.recipes import llama3_70b +from nemo.collections.llm.utils import Partial + +NAME = "llama3_70b_64k" + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + recipe = llama3_70b.pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn + ) + + trainer = llama3_70b.trainer( + tensor_parallelism=8, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=5, + context_parallelism=8, + sequence_parallelism=True, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = llama3_70b.model() + model.config.seq_length = 65536 + + recipe.model = model + recipe.trainer = trainer + + return recipe + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = llama3_70b.finetune_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node + ) + + trainer = llama3_70b.trainer( + tensor_parallelism=2, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=5, + context_parallelism=2, + sequence_parallelism=True, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = llama3_70b.model() + model.config.seq_length = 65536 + + recipe.model = model + recipe.trainer = trainer + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index d70366f6c5ed..340cfbdf6e26 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -13,6 +13,7 @@ from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin from nemo.collections.llm.utils import Config, Partial from nemo.utils.exp_manager import TimingCallback @@ -47,11 +48,6 @@ def trainer( ckpt_include_optimizer=True, ckpt_async_save=True, ckpt_parallel_load=True, - ddp=Config( - DistributedDataParallelConfig, - check_for_nan_in_grad=True, - grad_reduce_in_fp32=True, - ), ) trainer = Config( @@ -66,7 +62,7 @@ def trainer( log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, - plugins=Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), + plugins=bf16_mixed_plugin(), strategy=strategy, use_distributed_sampler=False, val_check_interval=2000, diff --git a/nemo/collections/llm/recipes/llama3_8b_16k.py b/nemo/collections/llm/recipes/llama3_8b_16k.py index 8bb2b636eba0..a57b4ef37298 100644 --- a/nemo/collections/llm/recipes/llama3_8b_16k.py +++ b/nemo/collections/llm/recipes/llama3_8b_16k.py @@ -32,7 +32,7 @@ def pretrain_recipe( recipe.model = model recipe.trainer = trainer - return trainer + return recipe def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: @@ -56,4 +56,4 @@ def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: recipe.model = model recipe.trainer = trainer - return trainer + return recipe diff --git a/nemo/collections/llm/recipes/llama3_8b_64k.py b/nemo/collections/llm/recipes/llama3_8b_64k.py index b42e1e53399e..d06c9b08a716 100644 --- a/nemo/collections/llm/recipes/llama3_8b_64k.py +++ b/nemo/collections/llm/recipes/llama3_8b_64k.py @@ -32,7 +32,7 @@ def pretrain_recipe( recipe.model = model recipe.trainer = trainer - return trainer + return recipe def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: @@ -56,4 +56,4 @@ def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: recipe.model = model recipe.trainer = trainer - return trainer + return recipe diff --git a/nemo/collections/llm/recipes/log/default.py b/nemo/collections/llm/recipes/log/default.py index dc18565a0e06..4d5e9223b535 100644 --- a/nemo/collections/llm/recipes/log/default.py +++ b/nemo/collections/llm/recipes/log/default.py @@ -10,14 +10,19 @@ def tensorboard_logger(name: str, save_dir: str = "tb_logs") -> Config[TensorBoa return Config(TensorBoardLogger, save_dir=save_dir, name=name) -def wandb_logger(project: str, name: str) -> Config[WandbLogger]: - return Config( +def wandb_logger(project: str, name: str, entity: Optional[str] = None) -> Config[WandbLogger]: + cfg = Config( WandbLogger, project=project, name=name, config={}, ) + if entity: + cfg.entity = entity + + return cfg + def default_log( ckpt_dir: str, diff --git a/nemo/collections/llm/recipes/mixtral_8x22b.py b/nemo/collections/llm/recipes/mixtral_8x22b.py new file mode 100644 index 000000000000..aaf0149dbdac --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x22b.py @@ -0,0 +1,113 @@ +from typing import Callable, Optional + +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.utils import Config, Partial +from nemo.utils.exp_manager import TimingCallback + +NAME = "mixtral_8x22b" + + +def model() -> Config[pl.LightningModule]: + return Config(MixtralModel, config=Config(MixtralConfig8x22B)) + + +def trainer( + tensor_parallelism: int, + pipeline_parallelism: int, + pipeline_parallelism_type: Optional[torch.dtype], + virtual_pipeline_parallelism: Optional[int], + context_parallelism: int, + sequence_parallelism: bool, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[Config[Callback]]] = None, +) -> Config[nl.Trainer]: + strategy = Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + ), + ) + + trainer = Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + gradient_clip_val=1.0, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + return Partial( + fn, + model=model(), + trainer=trainer( + tensor_parallelism=8, + pipeline_parallelism=1, + pipeline_parallelism_type=None, + virtual_pipeline_parallelism=None, + context_parallelism=1, + sequence_parallelism=True, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[Config(TimingCallback)], + ), + data=Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(ckpt_dir=ckpt_dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + +def hf_resume() -> Config[nl.AutoResume]: + return Config(nl.AutoResume, import_path="hf://mistralai/Mixtral-8x22B-v0.1") + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune + ) + recipe.resume = hf_resume() + recipe.peft = Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + recipe.data = Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) + return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x22b_4k.py b/nemo/collections/llm/recipes/mixtral_8x22b_4k.py deleted file mode 100644 index 5a29cca38506..000000000000 --- a/nemo/collections/llm/recipes/mixtral_8x22b_4k.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytorch_lightning as pl - -from nemo import lightning as nl -from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.gpt.data.api import squad -from nemo.collections.llm.gpt.model.llama import MixtralConfig8x22B, MixtralModel -from nemo.collections.llm.peft.api import gpt_lora -from nemo.collections.llm.recipes.log.default import default_log -from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing -from nemo.collections.llm.utils import Partial, factory - -NAME = "mixtral_8x22b_4k" - - -@factory(name=NAME) -def model() -> pl.LightningModule: - return MixtralModel(MixtralConfig8x22B(seq_length=4096)) - - -@factory(name=NAME) -def trainer(devices=8) -> nl.Trainer: - strategy = nl.MegatronStrategy( - tensor_model_parallel_size=8, - sequence_parallel=True, - ) - - return nl.Trainer( - devices=devices, - max_steps=100, - accelerator="gpu", - strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), - ) - - -@factory(name=NAME + "_hf") -def hf_resume() -> nl.AutoResume: - return nl.AutoResume(import_path="hf://mistralai/Mixtral-8x22B-v0.1") - - -@factory(name=NAME, for_task="llm.pretrain") -def pretrain_recipe() -> Partial: - return Partial( - pretrain, - model=model, - trainer=trainer, - data=squad, - log=default_log, - optim=distributed_fused_adam_with_cosine_annealing(), - ) - - -@factory(name=NAME, for_task="llm.finetune") -def finetune_recipe() -> Partial: - return Partial( - finetune, - model=model, - trainer=trainer, - data=squad, - log=default_log, - optim=distributed_fused_adam_with_cosine_annealing(), - peft=gpt_lora, - resume=hf_resume, - ) diff --git a/nemo/collections/llm/recipes/mixtral_8x3b.py b/nemo/collections/llm/recipes/mixtral_8x3b.py new file mode 100644 index 000000000000..223fe68af05d --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x3b.py @@ -0,0 +1,116 @@ +from typing import Callable, Optional + +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x3B, MixtralModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.utils import Config, Partial +from nemo.utils.exp_manager import TimingCallback + +NAME = "mixtral_8x3b" + + +def model() -> Config[pl.LightningModule]: + return Config(MixtralModel, config=Config(MixtralConfig8x3B)) + + +def trainer( + tensor_parallelism: int, + pipeline_parallelism: int, + pipeline_parallelism_type: Optional[torch.dtype], + virtual_pipeline_parallelism: Optional[int], + context_parallelism: int, + sequence_parallelism: bool, + expert_parallelism: int, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[Config[Callback]]] = None, +) -> Config[nl.Trainer]: + strategy = Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + expert_model_parallel_size=expert_parallelism, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + ), + ) + + trainer = Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + gradient_clip_val=1.0, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + return Partial( + fn, + model=model(), + trainer=trainer( + tensor_parallelism=4, + pipeline_parallelism=1, + pipeline_parallelism_type=None, + virtual_pipeline_parallelism=None, + context_parallelism=1, + sequence_parallelism=True, + expert_parallelism=1, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[Config(TimingCallback)], + ), + data=Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(ckpt_dir=ckpt_dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + +def hf_resume() -> Config[nl.AutoResume]: + return Config(nl.AutoResume, import_path="hf://mistralai/Mixtral-8x7B-v0.1") + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune + ) + recipe.resume = hf_resume() + recipe.peft = Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + recipe.data = Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) + return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x3b_16k.py b/nemo/collections/llm/recipes/mixtral_8x3b_16k.py new file mode 100644 index 000000000000..e496349a35d6 --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x3b_16k.py @@ -0,0 +1,61 @@ +from typing import Callable + +import torch + +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.recipes import mixtral_8x3b +from nemo.collections.llm.utils import Partial + +NAME = "mixtral_8x3b_16k" + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + recipe = mixtral_8x3b.pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn + ) + + trainer = mixtral_8x3b.trainer( + tensor_parallelism=2, + pipeline_parallelism=2, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=2, + sequence_parallelism=True, + expert_parallelism=2, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x3b.model() + model.config.seq_length = 16384 + + recipe.model = model + recipe.trainer = trainer + + return recipe + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = mixtral_8x3b.finetune_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node + ) + + trainer = mixtral_8x3b.trainer( + tensor_parallelism=2, + pipeline_parallelism=2, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=2, + sequence_parallelism=True, + expert_parallelism=2, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x3b.model() + model.config.seq_length = 16384 + + recipe.model = model + recipe.trainer = trainer + + return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x3b_64k.py b/nemo/collections/llm/recipes/mixtral_8x3b_64k.py new file mode 100644 index 000000000000..f034f30ecd94 --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x3b_64k.py @@ -0,0 +1,61 @@ +from typing import Callable + +import torch + +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.recipes import mixtral_8x3b +from nemo.collections.llm.utils import Partial + +NAME = "mixtral_8x3b_64k" + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + recipe = mixtral_8x3b.pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn + ) + + trainer = mixtral_8x3b.trainer( + tensor_parallelism=4, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=4, + sequence_parallelism=True, + expert_parallelism=4, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x3b.model() + model.config.seq_length = 65536 + + recipe.model = model + recipe.trainer = trainer + + return recipe + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = mixtral_8x3b.finetune_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node + ) + + trainer = mixtral_8x3b.trainer( + tensor_parallelism=2, + pipeline_parallelism=2, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=4, + sequence_parallelism=True, + expert_parallelism=2, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x3b.model() + model.config.seq_length = 65536 + + recipe.model = model + recipe.trainer = trainer + + return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py new file mode 100644 index 000000000000..1710727bd711 --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -0,0 +1,116 @@ +from typing import Callable, Optional + +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.utils import Config, Partial +from nemo.utils.exp_manager import TimingCallback + +NAME = "mixtral_8x7b" + + +def model() -> Config[pl.LightningModule]: + return Config(MixtralModel, config=Config(MixtralConfig8x7B)) + + +def trainer( + tensor_parallelism: int, + pipeline_parallelism: int, + pipeline_parallelism_type: Optional[torch.dtype], + virtual_pipeline_parallelism: Optional[int], + context_parallelism: int, + sequence_parallelism: bool, + expert_parallelism: int, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[Config[Callback]]] = None, +) -> Config[nl.Trainer]: + strategy = Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + expert_model_parallel_size=expert_parallelism, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + ), + ) + + trainer = Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + gradient_clip_val=1.0, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + return Partial( + fn, + model=model(), + trainer=trainer( + tensor_parallelism=8, + pipeline_parallelism=1, + pipeline_parallelism_type=None, + virtual_pipeline_parallelism=None, + context_parallelism=1, + sequence_parallelism=True, + expert_parallelism=1, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[Config(TimingCallback)], + ), + data=Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(ckpt_dir=ckpt_dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + +def hf_resume() -> Config[nl.AutoResume]: + return Config(nl.AutoResume, import_path="hf://mistralai/Mixtral-8x7B-v0.1") + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune + ) + recipe.resume = hf_resume() + recipe.peft = Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + recipe.data = Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) + return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_16k.py b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py new file mode 100644 index 000000000000..352069fc6831 --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py @@ -0,0 +1,61 @@ +from typing import Callable + +import torch + +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.recipes import mixtral_8x7b +from nemo.collections.llm.utils import Partial + +NAME = "mixtral_8x7b_16k" + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + recipe = mixtral_8x7b.pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn + ) + + trainer = mixtral_8x7b.trainer( + tensor_parallelism=2, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=4, + sequence_parallelism=True, + expert_parallelism=8, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x7b.model() + model.config.seq_length = 16384 + + recipe.model = model + recipe.trainer = trainer + + return recipe + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = mixtral_8x7b.finetune_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node + ) + + trainer = mixtral_8x7b.trainer( + tensor_parallelism=2, + pipeline_parallelism=2, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=2, + sequence_parallelism=True, + expert_parallelism=8, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x7b.model() + model.config.seq_length = 16384 + + recipe.model = model + recipe.trainer = trainer + + return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_4k.py b/nemo/collections/llm/recipes/mixtral_8x7b_4k.py deleted file mode 100644 index 5afa3cd072f6..000000000000 --- a/nemo/collections/llm/recipes/mixtral_8x7b_4k.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytorch_lightning as pl - -from nemo import lightning as nl -from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.gpt.data.api import squad -from nemo.collections.llm.gpt.model.llama import MixtralConfig8x7B, MixtralModel -from nemo.collections.llm.peft.api import gpt_lora -from nemo.collections.llm.recipes.log.default import default_log -from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing -from nemo.collections.llm.utils import Partial, factory - -NAME = "mixtral_8x7b_4k" - - -@factory(name=NAME) -def model() -> pl.LightningModule: - return MixtralModel(MixtralConfig8x7B(seq_length=4096)) - - -@factory(name=NAME) -def trainer(devices=8) -> nl.Trainer: - strategy = nl.MegatronStrategy( - tensor_model_parallel_size=8, - sequence_parallel=True, - ) - - return nl.Trainer( - devices=devices, - max_steps=100, - accelerator="gpu", - strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), - ) - - -@factory(name=NAME + "_hf") -def hf_resume() -> nl.AutoResume: - return nl.AutoResume(import_path="hf://mistralai/Mixtral-8x7B-v0.1") - - -@factory(name=NAME, for_task="llm.pretrain") -def pretrain_recipe() -> Partial: - return Partial( - pretrain, - model=model, - trainer=trainer, - data=squad, - log=default_log, - optim=distributed_fused_adam_with_cosine_annealing(), - ) - - -@factory(name=NAME, for_task="llm.finetune") -def finetune_recipe() -> Partial: - return Partial( - finetune, - model=model, - trainer=trainer, - data=squad, - log=default_log, - optim=distributed_fused_adam_with_cosine_annealing(), - peft=gpt_lora, - resume=hf_resume, - ) diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_64k.py b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py new file mode 100644 index 000000000000..503c83ecb66a --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py @@ -0,0 +1,61 @@ +from typing import Callable + +import torch + +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.recipes import mixtral_8x7b +from nemo.collections.llm.utils import Partial + +NAME = "mixtral_8x7b_64k" + + +def pretrain_recipe( + name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain +) -> Partial: + recipe = mixtral_8x7b.pretrain_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn + ) + + trainer = mixtral_8x7b.trainer( + tensor_parallelism=4, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=4, + sequence_parallelism=True, + expert_parallelism=8, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x7b.model() + model.config.seq_length = 65536 + + recipe.model = model + recipe.trainer = trainer + + return recipe + + +def finetune_recipe(name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int) -> Partial: + recipe = mixtral_8x7b.finetune_recipe( + name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node + ) + + trainer = mixtral_8x7b.trainer( + tensor_parallelism=2, + pipeline_parallelism=4, + pipeline_parallelism_type=torch.bfloat16, + virtual_pipeline_parallelism=8, + context_parallelism=2, + sequence_parallelism=True, + expert_parallelism=8, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ) + model = mixtral_8x7b.model() + model.config.seq_length = 65536 + + recipe.model = model + recipe.trainer = trainer + + return recipe diff --git a/nemo/collections/llm/recipes/precision/__init__.py b/nemo/collections/llm/recipes/precision/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/llm/recipes/precision/mixed_precision.py b/nemo/collections/llm/recipes/precision/mixed_precision.py new file mode 100644 index 000000000000..6a9cb64404ce --- /dev/null +++ b/nemo/collections/llm/recipes/precision/mixed_precision.py @@ -0,0 +1,26 @@ +import torch + +from nemo.collections.llm.utils import Config +from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision + + +def bf16_mixed_plugin() -> Config[MegatronMixedPrecision]: + return Config( + MegatronMixedPrecision, + precision="bf16-mixed", + params_dtype=torch.bfloat16, + pipeline_dtype=torch.bfloat16, + autocast_enabled=False, + grad_reduce_in_fp32=True, + ) + + +def fp16_mixed_plugin() -> Config[MegatronMixedPrecision]: + return Config( + MegatronMixedPrecision, + precision="16-mixed", + params_dtype=torch.half, + pipeline_dtype=torch.half, + autocast_enabled=False, + grad_reduce_in_fp32=False, + ) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py index 3d5d7effc9de..ef09c7ff068e 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py @@ -73,7 +73,7 @@ def get_prompt_template_example(special_tokens): def identify_start_index_of_subsequence(subsequence, sequence): - """ find the location of the small tensor in the large tensor. + """find the location of the small tensor in the large tensor. e.g. small = [1,3], large = [2,3,1,3], returns 2 small = [3,2], large = [2,3,1,3], returns -1 Args: @@ -100,7 +100,7 @@ def _mask_targets( label_start_ids, num_turn_start_tokens, ): - """ This function masks the tokens so the loss is computed only on the non-masked role's responses. + """This function masks the tokens so the loss is computed only on the non-masked role's responses. For 'TEXT_TO_VALUE' type, the loss is computed on the value attributes. Args: @@ -373,8 +373,9 @@ def collate_fn(self, batch): max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 8)) assert max_length <= self.max_seq_length - attention_mask = [self._create_attention_mask(max_length) for _ in batch] - attention_mask = torch.stack(attention_mask) + if not self.get_attention_mask_from_fusion: + attention_mask = [self._create_attention_mask(max_length) for _ in batch] + attention_mask = torch.stack(attention_mask) position_ids = [list(range(max_length)) for _ in batch] position_ids = torch.LongTensor(position_ids) input_ids = torch.LongTensor( @@ -389,7 +390,6 @@ def collate_fn(self, batch): processed_batch = { 'tokens': input_ids, 'labels': labels, - 'attention_mask': attention_mask, 'loss_mask': loss_mask, 'position_ids': position_ids, 'contexts': contexts, @@ -398,4 +398,7 @@ def collate_fn(self, batch): 'metadata': metadata, } + if not self.get_attention_mask_from_fusion: + processed_batch['attention_mask'] = attention_mask + return processed_batch diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 571d93120308..327a9990801b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1203,7 +1203,7 @@ def get_batch_on_this_context_parallel_rank(self, batch): if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() for key, val in batch.items(): - if val is not None: + if val is not None and key != "context_lengths": seq_dim = 1 if key != 'attention_mask' else 2 val = val.view( *val.shape[0:seq_dim], diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 9c2372ef38ca..08bc5501363c 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -301,7 +301,7 @@ def _build_dataset(self, data_cfg, is_train=True): index_mapping_dir=data_cfg.get('index_mapping_dir', None), prompt_template=data_cfg.get('prompt_template', None), ceil_to_power_2=data_cfg.get('ceil_to_power_2', False), - get_attention_mask_from_fusion=data_cfg.get('get_attention_mask_from_fusion', False), + get_attention_mask_from_fusion=data_cfg.get('get_attention_mask_from_fusion', True), global_sample_mapping=data_cfg.get('global_sample_mapping', False), virtual_tokens=self.virtual_tokens, tokens_to_generate=data_cfg.get( diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 4f9f04527038..29eea2d54664 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -177,6 +177,7 @@ def __init__( model_parallel_config = ModelParallelConfig() self._sequence_parallel = model_parallel_config.sequence_parallel model_parallel_config.sequence_parallel = False # SP is irrelevant for the lora linear layer + self.config = model_parallel_config if input_is_parallel: self.linear_in = RowParallelLinear( @@ -298,8 +299,14 @@ def forward(self, x): # this function also handles the backward pass correctly x = gather_from_sequence_parallel_region(x) + if self.config.cpu_offloading and self.config.cpu_offloading_activations: + x.activation_offloading = True x, _ = self.linear_in(x) # (@adithyare) ColumnLinear returns output and bias, we are ignoring the bias term. + x = self.activation(x) + + if self.config.cpu_offloading and self.config.cpu_offloading_activations: + x.activation_offloading = True x, _ = self.linear_out(x) if self._sequence_parallel and self.input_is_parallel: diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 5aaac6755601..601cb7a4d7e8 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -93,7 +93,7 @@ def parallel_lm_logits( tensor_model_parallel = parallel_state.get_tensor_model_parallel_world_size() > 1 # async grad allreduce can only be used when not using sequence parallelism - async_grad_allreduce = async_tensor_model_parallel_allreduce and tensor_model_parallel and not sequence_parallel + allreduce_dgrad = async_tensor_model_parallel_allreduce and tensor_model_parallel and not sequence_parallel # copy input_ to model parallel region if needed if async_tensor_model_parallel_allreduce or sequence_parallel: @@ -108,7 +108,7 @@ def parallel_lm_logits( weight=word_embeddings_weight, bias=bias, gradient_accumulation_fusion=gradient_accumulation_fusion, - async_grad_allreduce=async_grad_allreduce, + allreduce_dgrad=allreduce_dgrad, sequence_parallel=sequence_parallel, ) diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index 4cbadd87fe52..56496d56bc07 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -16,28 +16,8 @@ from dataclasses import MISSING, dataclass from typing import Dict, List, Optional -import nemo -from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer -from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer -from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer -from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer -from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer -from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer -from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import get_huggingface_pretrained_lm_models_list -from nemo.collections.nlp.modules.common.lm_utils import get_pretrained_lm_models_list -from nemo.collections.nlp.parts.nlp_overrides import HAVE_MEGATRON_CORE from nemo.utils import logging -try: - from nemo.collections.nlp.modules.common.megatron.megatron_utils import get_megatron_tokenizer - - HAVE_MEGATRON_CORE = True - -except (ImportError, ModuleNotFoundError): - HAVE_MEGATRON_CORE = False - - __all__ = ['get_tokenizer', 'get_tokenizer_list'] @@ -96,46 +76,61 @@ def get_tokenizer( model better learn word compositionality and become robust to segmentation errors. It has emperically been shown to improve inference time BLEU scores. """ + if special_tokens is None: special_tokens_dict = {} else: special_tokens_dict = special_tokens if 'megatron' in tokenizer_name: - if not HAVE_MEGATRON_CORE: + try: + from nemo.collections.nlp.modules.common.megatron.megatron_utils import ( + get_megatron_merges_file, + get_megatron_tokenizer, + get_megatron_vocab_file, + ) + except (ImportError, ModuleNotFoundError): raise ImportError( "Megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) if vocab_file is None: - vocab_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_vocab_file( - tokenizer_name - ) - merges_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_merges_file( - tokenizer_name - ) + vocab_file = get_megatron_vocab_file(tokenizer_name) + merges_file = get_megatron_merges_file(tokenizer_name) tokenizer_name = get_megatron_tokenizer(tokenizer_name) if tokenizer_name == 'sentencepiece': + from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer + logging.info("tokenizer_model: " + str(tokenizer_model)) - return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( + return SentencePieceTokenizer( model_path=tokenizer_model, special_tokens=special_tokens, legacy=True, chat_template=chat_template, ) elif tokenizer_name == 'tiktoken': - return nemo.collections.common.tokenizers.tiktoken_tokenizer.TiktokenTokenizer(vocab_file=vocab_file) + from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer + + return TiktokenTokenizer(vocab_file=vocab_file) elif tokenizer_name == 'word': + from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer + return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict) elif tokenizer_name == 'char': + from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer + return CharTokenizer(vocab_file=vocab_file, **special_tokens_dict) elif tokenizer_name == 'regex': + from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer + return RegExTokenizer().load_tokenizer(regex_file=tokenizer_model, vocab_file=vocab_file) logging.info( f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_name}, vocab_file: {vocab_file}, merges_files: {merges_file}, " f"special_tokens_dict: {special_tokens_dict}, and use_fast: {use_fast}" ) + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + return AutoTokenizer( pretrained_model_name=tokenizer_name, vocab_file=vocab_file, @@ -183,6 +178,8 @@ def get_nmt_tokenizer( raise ValueError("No Tokenizer path provided or file does not exist!") if library == 'huggingface': + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + logging.info(f'Getting HuggingFace AutoTokenizer with pretrained_model_name: {model_name}') return AutoTokenizer( pretrained_model_name=model_name, @@ -193,26 +190,32 @@ def get_nmt_tokenizer( trust_remote_code=trust_remote_code, ) elif library == 'sentencepiece': + from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer + logging.info(f'Getting SentencePiece with model: {tokenizer_model}') - return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( + return SentencePieceTokenizer( model_path=tokenizer_model, legacy=legacy, chat_template=chat_template, ) elif library == 'byte-level': + from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer + logging.info(f'Using byte-level tokenization') return ByteLevelTokenizer(special_tokens_dict) elif library == 'regex': + from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer + logging.info(f'Using regex tokenization') return RegExTokenizer().load_tokenizer(regex_file=tokenizer_model, vocab_file=vocab_file) elif library == 'megatron': if model_name == 'GPTSentencePieceTokenizer': + from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer + logging.info("tokenizer_model: ") logging.info(tokenizer_model) - return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( - model_path=tokenizer_model, legacy=legacy - ) + return SentencePieceTokenizer(model_path=tokenizer_model, legacy=legacy) if model_name in megatron_tokenizer_model_map: model_name = megatron_tokenizer_model_map[model_name] @@ -223,8 +226,12 @@ def get_nmt_tokenizer( tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file, chat_template=chat_template ) elif library == 'tabular': + from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer + return TabularTokenizer(vocab_file, delimiter=delimiter) elif library == 'tiktoken': + from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer + return TiktokenTokenizer(vocab_file=vocab_file) else: raise NotImplementedError( diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index e9ed34732c36..c8070225d25a 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -22,8 +22,7 @@ from einops import rearrange from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor -from nemo.collections.asr.parts.utils.activations import Snake -from nemo.collections.common.parts.utils import mask_sequence_tensor +from nemo.collections.common.parts.utils import ClampActivation, HalfSnake, Snake, mask_sequence_tensor from nemo.core.classes.common import typecheck from nemo.core.classes.module import NeuralModule from nemo.core.neural_types.elements import ( @@ -75,6 +74,8 @@ def __init__(self, activation: str = "elu", channels: int = 1): self.activation = torch.nn.LeakyReLU() elif activation == "snake": self.activation = Snake(channels) + elif activation == "half_snake": + self.activation = HalfSnake(channels) else: raise ValueError(f"Unknown activation {activation}") @@ -322,6 +323,185 @@ def forward(self, audio_real, audio_gen): return scores_real, scores_gen, fmaps_real, fmaps_gen +class DiscriminatorSTFT(NeuralModule): + """ + Discriminator network from EnCodec for Complex STFT input, but without dilations. + + Args: + filters: number of filters to use in Conv2d layers + lrelu_slope: Slope to use for activations. Leaky relu with slope of 0.1 or 0.2 is recommended for the + stability of the feature matching loss + """ + + def __init__(self, filters: int = 32, lrelu_slope: float = 0.1): + super().__init__() + + self.activation = nn.LeakyReLU(lrelu_slope) + self.conv_layers = nn.ModuleList( + [ + Conv2dNorm(2, filters, kernel_size=(3, 9)), + Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), + Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), + Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), + Conv2dNorm(filters, filters, kernel_size=(3, 3)), + ] + ) + self.conv_post = Conv2dNorm(filters, 1, kernel_size=(3, 3)) + + @property + def input_types(self): + return { + "spec": NeuralType(('B', 'C', 'T_spec', 'D'), VoidType()), + } + + @property + def output_types(self): + return { + "scores": NeuralType(('B', 'C', 'T_spec'), VoidType()), + "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())], + } + + @typecheck() + def forward(self, spec): + fmap = [] + + # [batch, 2, T_spec, fft] + out = spec + for conv in self.conv_layers: + # [batch, filters, T_spec, fft // strides] + out = conv(inputs=out) + out = self.activation(out) + fmap.append(out) + # [batch, 1, T_spec, fft // 8] + scores = self.conv_post(inputs=out) + fmap.append(scores) + scores = rearrange(scores, "B 1 T C -> B C T") + + return scores, fmap + + +class MultiBandDiscriminatorSTFT(NeuralModule): + """ + Multi-band STFT discriminator proposed in DAC (https://arxiv.org/abs/2306.06546). + + Computes the complex STFT for a given resolution and splits it into sub-bands, + which are given to separate discriminator networks. + + Args: + resolution: STFT resolution, provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length) + stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). + The floats are in the range [0, 1] representing the fraction of all stft bands. + For example for n_fft=1024, the stft output has 513 dimensions. + For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. + """ + + def __init__(self, resolution: Tuple[int], stft_bands: Iterable[Tuple[int]]): + super().__init__() + + self.n_fft, self.hop_length, self.win_length = resolution + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.discriminators = nn.ModuleList([DiscriminatorSTFT() for _ in stft_bands]) + n_stft = self.n_fft // 2 + 1 + self.stft_bands = [(int(band[0] * n_stft), int(band[1] * n_stft)) for band in stft_bands] + + def compute_stft(self, audio): + # [B, fft, T_spec] + fft = torch.stft( + audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + normalized=True, + center=True, + return_complex=True, + ) + fft = rearrange(fft, "B fft T -> B T fft") + # [batch, 2, T_spec, fft] + out = torch.stack([fft.real, fft.imag], dim=1) + return out + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def output_types(self): + return { + "scores_list": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "fmaps_list": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + } + + @typecheck() + def forward(self, audio): + scores_list = [] + fmap_list = [] + spec = self.compute_stft(audio) + for band, disc in zip(self.stft_bands, self.discriminators): + spec_band = spec[:, :, :, band[0] : band[1]] + score, fmap = disc(spec=spec_band) + scores_list.append(score) + fmap_list.append(fmap) + + return scores_list, fmap_list + + +class MultiResolutionDiscriminatorSTFT(NeuralModule): + """ + Multi-resolution discriminator which creates a multi-band discriminator for each input resolution. + + Args: + resolutions: List of STFT resolutions, each resolution provided as a tuple of 3 integers ordered + (num_fft, hop_length, window_length) + stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). + The floats are in the range [0, 1] representing the fraction of all stft bands. + For example for n_fft=1024, the stft output has 513 dimensions. + For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. + """ + + def __init__(self, resolutions: Iterable[Tuple[int]], stft_bands: Iterable[Tuple[int]]): + super().__init__() + self.discriminators = nn.ModuleList( + [MultiBandDiscriminatorSTFT(resolution=resolution, stft_bands=stft_bands) for resolution in resolutions] + ) + + @property + def input_types(self): + return { + "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def output_types(self): + return { + "scores_real": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "scores_gen": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "fmaps_real": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + } + + @typecheck() + def forward(self, audio_real, audio_gen): + scores_real = [] + scores_gen = [] + fmaps_real = [] + fmaps_gen = [] + + for disc in self.discriminators: + score_real_i, fmap_real_i = disc(audio=audio_real) + scores_real = scores_real + score_real_i + fmaps_real = fmaps_real + fmap_real_i + + score_gen_i, fmap_gen_i = disc(audio=audio_gen) + scores_gen = scores_gen + score_gen_i + fmaps_gen = fmaps_gen + fmap_gen_i + + return scores_real, scores_gen, fmaps_real, fmaps_gen + + class Discriminator(NeuralModule): """ Wrapper class which takes a list of discriminators and aggregates the results across them. @@ -868,6 +1048,120 @@ def forward(self, inputs, input_len): return out +class HiFiGANEncoder(NeuralModule): + """ + Audio encoder created by inverting the HiFi-GAN decoder. + + Args: + encoded_dim: Dimension of encoder output. + down_sample_rates: Rate to upsample for each decoder block. The product of the downsample rates will + determine the output token rate. For example 2 * 2 * 8 * 8 = 256 samples per token. + base_channels: Number of filters in the first convolution. The number of channels will be doubled after each + downsample layer. + in_kernel_size: Kernel size of the input convolution. + out_kernel_size: Kernel size of the output convolution. + resblock_kernel_sizes: List of kernel sizes to use in each residual block. + resblock_dilation_sizes: List of dilations to use in each residual block. + activation: Activation to use in residual and downsample layers, defaults to leaky relu. + """ + + def __init__( + self, + encoded_dim: int, + down_sample_rates: Iterable[int] = (2, 2, 8, 8), + base_channels: int = 32, + in_kernel_size: int = 7, + out_kernel_size: int = 7, + resblock_kernel_sizes: Iterable[int] = (3, 7, 11), + resblock_dilation_sizes: Iterable[int] = (1, 3, 5), + activation: str = "lrelu", + ): + assert in_kernel_size > 0 + assert out_kernel_size > 0 + + super().__init__() + + self.down_sample_rates = down_sample_rates + self.pre_conv = Conv1dNorm(in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size) + + in_channels = base_channels + self.activations = nn.ModuleList([]) + self.down_sample_conv_layers = nn.ModuleList([]) + self.res_layers = nn.ModuleList([]) + for i, down_sample_rate in enumerate(self.down_sample_rates): + res_layer = HiFiGANResLayer( + channels=in_channels, + kernel_sizes=resblock_kernel_sizes, + dilations=resblock_dilation_sizes, + activation=activation, + ) + self.res_layers.append(res_layer) + + act = CodecActivation(activation, channels=in_channels) + self.activations.append(act) + + out_channels = 2 * in_channels + kernel_size = 2 * down_sample_rate + + padding = get_down_sample_padding(kernel_size=kernel_size, stride=down_sample_rate) + down_sample_conv = Conv1dNorm( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=down_sample_rate, + padding=padding, + ) + in_channels = out_channels + self.down_sample_conv_layers.append(down_sample_conv) + + self.post_activation = CodecActivation(activation, channels=in_channels) + self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size) + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + self.post_conv.remove_weight_norm() + for res_layer in self.res_layers: + res_layer.remove_weight_norm() + for down_sample_conv in self.down_sample_conv_layers: + down_sample_conv.remove_weight_norm() + + @typecheck() + def forward(self, audio, audio_len): + encoded_len = audio_len + audio = rearrange(audio, "B T -> B 1 T") + # [B, C, T_audio] + out = self.pre_conv(inputs=audio, input_len=encoded_len) + for act, res_layer, down_sample_conv, down_sample_rate in zip( + self.activations, self.res_layers, self.down_sample_conv_layers, self.down_sample_rates + ): + # [B, C, T] + out = res_layer(inputs=out, input_len=encoded_len) + out = act(out) + + encoded_len = encoded_len // down_sample_rate + # [B, 2 * C, T / down_sample_rate] + out = down_sample_conv(inputs=out, input_len=encoded_len) + + out = self.post_activation(out) + # [B, encoded_dim, T_encoded] + encoded = self.post_conv(inputs=out, input_len=encoded_len) + return encoded, encoded_len + + class HiFiGANDecoder(NeuralModule): """ Codec decoder using the HiFi-GAN generator architecture. @@ -876,8 +1170,9 @@ class HiFiGANDecoder(NeuralModule): Args: input_dim: Input dimension. - up_sample_rates: Rate to upsample for each decoder block. The product of the upsample rates will - determine the output frame rate. For example 8 * 8 * 2 * 2 = 256 samples per token. + up_sample_rates: Rate to upsample for each decoder block. The product of the upsample rates should be the same + as the overall downsample rate for your encoder. For example, a symmetric encoder/decoder can be created + with encoder downsample rates [2, 2, 8, 8] and decoder upsample rates [8, 8, 2, 2]. base_channels: Number of filters in the first convolution. The number of channels will be cut in half after each upsample layer. in_kernel_size: Kernel size of the input convolution. @@ -885,6 +1180,8 @@ class HiFiGANDecoder(NeuralModule): resblock_kernel_sizes: List of kernel sizes to use in each residual block. resblock_dilation_sizes: List of dilations to use in each residual block. activation: Activation to use in residual and upsample layers, defaults to leaky relu. + output_activation: Activation to apply to output. To produce a valid audio signal, it should output values in + the range [-1.0, 1.0]. Supports "tanh" and "clamp". """ def __init__( @@ -897,6 +1194,7 @@ def __init__( resblock_kernel_sizes: Iterable[int] = (3, 7, 11), resblock_dilation_sizes: Iterable[int] = (1, 3, 5), activation: str = "lrelu", + output_activation: str = "tanh", ): assert in_kernel_size > 0 assert out_kernel_size > 0 @@ -933,7 +1231,12 @@ def __init__( self.post_activation = CodecActivation(activation, channels=in_channels) self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size) - self.out_activation = nn.Tanh() + if output_activation == "tanh": + self.out_activation = nn.Tanh() + elif output_activation == "clamp": + self.out_activation = ClampActivation() + else: + raise ValueError(f"Invalid audio output activation {output_activation}") @property def input_types(self): diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index 9feb70cc90a1..c058da52a97a 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -35,8 +35,6 @@ class McoreDistributedOptimizer(torch.optim.Optimizer): def __init__(self, optim): self.defaults = {} self.mcore_optimizer = optim - self.param_groups = self.mcore_optimizer.param_groups - self.state = self.mcore_optimizer.state def zero_grad(self, set_to_none: bool = True): """We only need to zero the model related parameters, i.e., @@ -76,12 +74,39 @@ def step(self, closure): return loss + # Promote state so it can be retrieved or set via + # "optimizer_instance.state" + def _get_state(self): + if hasattr(self, 'mcore_optimizer'): + return self.mcore_optimizer.state + else: + return [] + + def _set_state(self, value): + self.mcore_optimizer.state = value + + state = property(_get_state, _set_state) + def save_parameter_state(self, filename: str): self.mcore_optimizer.save_parameter_state(filename) def load_parameter_state(self, filename: str): self.mcore_optimizer.load_parameter_state(filename) + # Promote param_groups so it can be retrieved or set via + # "optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + if hasattr(self, 'mcore_optimizer'): + return self.mcore_optimizer.param_groups + else: + return [] + + def _set_param_groups(self, value): + self.mcore_optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + def finish_param_sync(self, model_index): self.mcore_optimizer.finish_param_sync(model_index) diff --git a/nemo/core/utils/k2_utils.py b/nemo/core/utils/k2_utils.py index 3dff6a35d3e3..3e7c2a6f5a70 100644 --- a/nemo/core/utils/k2_utils.py +++ b/nemo/core/utils/k2_utils.py @@ -16,7 +16,7 @@ K2_INSTALLATION_MESSAGE = ( "Could not import `k2`.\n" "Please install k2 in one of the following ways:\n" - "1) (recommended) Run `bash scripts/speech_recognition/k2/setup.sh`\n" + "1) (recommended) Run `bash scripts/installers/install_k2.sh`\n" "2) Use any approach from https://k2-fsa.github.io/k2/installation/index.html " "if your your cuda and pytorch versions are supported.\n" "It is advised to always install k2 using setup.sh only, " diff --git a/nemo/deploy/multimodal/query_multimodal.py b/nemo/deploy/multimodal/query_multimodal.py index 1c01c6861048..63e6a3e8c3a6 100644 --- a/nemo/deploy/multimodal/query_multimodal.py +++ b/nemo/deploy/multimodal/query_multimodal.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import soundfile as sf from PIL import Image from nemo.deploy.utils import str_list2numpy @@ -71,6 +72,11 @@ def setup_media(self, input_media): elif self.model_type == "neva" or self.model_type == "vila": media = Image.open(input_media).convert('RGB') return np.expand_dims(np.array(media), axis=0) + elif self.model_type == "salm": + waveform, sample_rate = sf.read(input_media, dtype=np.float32) + input_signal = np.array([waveform], dtype=np.float32) + input_signal_length = np.array([[len(waveform)]], dtype=np.int32) + return {"input_signal": input_signal, "input_signal_length": input_signal_length} else: raise RuntimeError(f"Invalid model type {self.model_type}") @@ -105,8 +111,10 @@ def query( inputs = {"input_text": prompts} media = self.setup_media(input_media) - - inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0) + if isinstance(media, dict): + inputs.update(media) + else: + inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0) if batch_size is not None: inputs["batch_size"] = np.full(prompts.shape, batch_size, dtype=np.int_) diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py index 8ee3fa1c05e7..53c598be47c6 100644 --- a/nemo/export/multimodal/build.py +++ b/nemo/export/multimodal/build.py @@ -23,9 +23,12 @@ import tensorrt as trt import torch import yaml +from omegaconf import OmegaConf from tensorrt_llm.builder import Builder from transformers import AutoModel +from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule +from nemo.core.classes.common import typecheck from nemo.export.tensorrt_llm import TensorRTLLM from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import load_nemo_model @@ -76,6 +79,32 @@ def export_visual_wrapper_onnx( ) +def export_perception_wrapper_onnx( + perception_wrapper, + input, + output_dir, + input_names=['processed_signal', 'processed_signal_length'], + output_names=['encoded', 'encoded_length'], + dynamic_axes={ + 'processed_signal': {0: 'batch', 2: 'time'}, + 'processed_signal_length': {0: 'batch'}, + 'encoded': {0: 'batch', 1: 'time'}, + 'encoded_length': {0: 'batch'}, + }, +): + logger.log(trt.Logger.INFO, "Exporting onnx") + os.makedirs(f'{output_dir}/onnx', exist_ok=True) + torch.onnx.export( + perception_wrapper, + input, + f'{output_dir}/onnx/perception_encoder.onnx', + opset_version=17, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + + def build_trt_engine( model_type, input_sizes, @@ -85,8 +114,8 @@ def build_trt_engine( image_size=None, num_frames=None, nemo_config=None, + part_name='visual_encoder', ): - part_name = 'visual_encoder' onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) engine_file = '%s/%s.engine' % (output_dir, part_name) config_file = '%s/%s' % (output_dir, "config.json") @@ -131,6 +160,10 @@ def build_trt_engine( # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images, # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]). + # or a list of three list of lists + # (e.g., [{input1: min_shape, input2: min_shape, }, \ + # {input1: opt_shape, input2: opt_shape}, \ + # {input1: max_shape, input2: max_shape}] ) assert isinstance(input_sizes, list), "input_sizes must be a list" if isinstance(input_sizes[0], int): logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}") @@ -139,10 +172,23 @@ def build_trt_engine( elif len(input_sizes) == 3 and isinstance(input_sizes[0], list): min_size, opt_size, max_size = input_sizes logger.log(trt.Logger.INFO, f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}") + elif len(input_sizes) == 3 and isinstance(input_sizes[0], dict): + logger.log(trt.Logger.INFO, f"Processed min/opt/max input sizes {input_sizes}") else: raise ValueError(f"invalid input sizes: {input_sizes}") - profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size]) + if isinstance(input_sizes[0], dict): + for i in range(network.num_inputs): + inputT = network.get_input(i) + input_name = inputT.name + min_size = input_sizes[0][input_name] + opt_size = input_sizes[1][input_name] + max_size = input_sizes[2][input_name] + logger.log(trt.Logger.INFO, f"{input_name} min/opt/max input sizes {min_size}/{opt_size}/{max_size}") + profile.set_shape(input_name, min_size, opt_size, max_size) + else: + profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size]) + config.add_optimization_profile(profile) t0 = time() @@ -367,6 +413,76 @@ def forward(self, images): ) +def build_perception_engine( + model_dir: str, + perception_checkpoint_path: str, + model_type: str = "salm", + max_batch_size: int = 1, +): + assert model_type == "salm", f"Invalid model type {model_type}" + + def load_perception_model(perception_checkpoint_path): + weights = "model_weights.ckpt" + perception_state_dict = torch.load(os.path.join(perception_checkpoint_path, weights)) + config = "model_config.yaml" + config = OmegaConf.load(os.path.join(perception_checkpoint_path, config)) + perception = AudioPerceptionModule(cfg=config) + perception.load_state_dict(perception_state_dict) + perception.eval() + return perception + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + # load perception model + perception_model = load_perception_model(perception_checkpoint_path) + feature_extractor = perception_model.preprocessor + input_signal = torch.randn(1, 1000, dtype=torch.float32) + input_signal_length = torch.tensor([1000], dtype=torch.int32) + + processed_signal, processed_signal_length = feature_extractor( + input_signal=input_signal, length=input_signal_length + ) + processed_signal_length = processed_signal_length.to(torch.int32) + dump_path = model_dir + "/feature_extractor.ts" # dump the feature extractor as torchscript + feature_extractor.export(dump_path, (input_signal, input_signal_length)) + + class PerceptionWrapper(torch.nn.Module): + def __init__(self, encoder, modality_adapter, proj): + super().__init__() + self.encoder = encoder + self.modality_adapter = modality_adapter + self.proj = proj + + @typecheck.disable_checks() + def forward(self, processed_signal, processed_signal_length): + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + encoded, encoded_len = self.modality_adapter(audio_signal=encoded, length=encoded_len) + # b, c, t -> b, t, c + encoded = self.proj(encoded.transpose(1, 2)) + encoded_len = encoded_len.to(torch.int32) + return encoded, encoded_len + + perception = PerceptionWrapper(perception_model.encoder, perception_model.modality_adapter, perception_model.proj) + export_perception_wrapper_onnx(perception, (processed_signal, processed_signal_length), model_dir) + # export the onnx perception model to tensorrt engine + # 512 -> 5.12 sec, 3072 -> 30.72 sec + opt_batch_size = max(1, max_batch_size // 2) + shapes = [ + {"processed_signal": [1, 80, 64], "processed_signal_length": [1]}, + {"processed_signal": [opt_batch_size, 80, 512], "processed_signal_length": [opt_batch_size]}, + {"processed_signal": [max_batch_size, 80, 3072], "processed_signal_length": [max_batch_size]}, + ] + build_trt_engine( + model_type, + shapes, + model_dir, + max_batch_size, + dtype=torch.float16, + nemo_config=None, + part_name='perception_encoder', + ) + + def build_visual_engine( model_dir: str, visual_checkpoint_path: str, diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py index 149df995c77a..2cde46ca41fa 100644 --- a/nemo/export/multimodal/run.py +++ b/nemo/export/multimodal/run.py @@ -25,6 +25,7 @@ import einops import numpy as np +import soundfile as sf import tensorrt as trt import tensorrt_llm import tensorrt_llm.profiler as profiler @@ -32,7 +33,7 @@ import yaml from PIL import Image from tensorrt_llm import logger -from tensorrt_llm._utils import str_dtype_to_trt +from tensorrt_llm._utils import str_dtype_to_trt, torch_dtype_to_trt from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo from torch.nn import functional as F from torchvision import transforms @@ -54,7 +55,8 @@ def trt_dtype_to_torch(dtype): class MultimodalModelRunner: - def __init__(self, visual_engine_dir, llm_engine_dir): + def __init__(self, visual_engine_dir, llm_engine_dir, modality='vision'): + self.modality = modality self.runtime_rank = tensorrt_llm.mpi_rank() device_id = self.runtime_rank % torch.cuda.device_count() torch.cuda.set_device(device_id) @@ -68,13 +70,15 @@ def __init__(self, visual_engine_dir, llm_engine_dir): config = json.load(f) self.model_type = config['builder_config']['model_type'] self.vision_precision = config['builder_config']['precision'] + self.modality_precision = config['builder_config']['precision'] self.num_frames = config['builder_config'].get('num_frames', None) self.image_size = config['builder_config'].get('image_size', None) self.profiling_iterations = 20 - self.init_image_encoder(visual_engine_dir) + if modality == 'vision': + self.init_image_encoder(visual_engine_dir) self.init_tokenizer(llm_engine_dir) self.init_llm(llm_engine_dir) if self.model_type == 'lita' or self.model_type == 'vila' or self.model_type == 'vita': @@ -242,10 +246,10 @@ def insert_tokens_by_index(self, input_ids, num_frames): def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, batch_size): if not warmup: - profiler.start("Vision") + profiler.start(self.modality.capitalize()) if not warmup: - profiler.stop("Vision") + profiler.stop(self.modality.capitalize()) if self.model_type == 'vila': visual_features, visual_atts = self.get_visual_features(image, attention_mask) @@ -848,7 +852,7 @@ def print_result(self, input_text, output_text, batch_size, num_beams, run_profi if run_profiling: msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(name) / self.profiling_iterations logger.info('Latencies per batch (msec)') - logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision'))) + logger.info(f'TRT {self.modality} encoder: %.1f' % (msec_per_batch(self.modality.capitalize()))) logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM'))) logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate'))) @@ -864,3 +868,278 @@ def load_test_media(self, input_media): raise RuntimeError(f"Invalid model type {self.model_type}") return media + + +class SpeechllmModelRunner(MultimodalModelRunner): + def __init__(self, perception_engine_dir, llm_engine_dir, modality): + """ + perception_engine_dir: path to the perception engine directory + it should contain: + config.json nemo_config.yaml + perception_encoder.engine : tensorrt engine + feature_extractor.ts : torchscript model + llm_engine_dir: path to the LLM engine directory + """ + super().__init__(perception_engine_dir, llm_engine_dir, modality) + assert self.model_type == 'salm' + # init preprocessor + feature_extractor_path = os.path.join(perception_engine_dir, 'feature_extractor.ts') + self.feature_extractor = self.init_speech_preprocessor(feature_extractor_path) + self.init_modality_encoder(perception_engine_dir) + + def init_modality_encoder(self, engine_dir): + """ + Initialize the modality encoder session from the prebuilt engine directory + Args: + engine_dir: str, path to the engine directory + """ + # find file with .engine extension + engine_file = None + for file in os.listdir(engine_dir): + if file.endswith('.engine'): + engine_file = file + break + assert engine_file is not None, f"Engine file not found in {engine_dir}" + encoder_path = os.path.join(engine_dir, engine_file) + logger.info(f'Loading engine from {encoder_path}') + with open(encoder_path, 'rb') as f: + engine_buffer = f.read() + logger.info(f'Creating session from engine {encoder_path}') + self.modality_encoder_session = Session.from_serialized_engine(engine_buffer) + + def init_speech_preprocessor(self, feature_extractor_path): + feature_extractor = torch.jit.load(feature_extractor_path) + feature_extractor.eval() + return feature_extractor + + def process_audio(self, input_signal, input_signal_length): + """ + Args: + input_signal: audio signal in numpy array + input_signal_length: length of the audio signal in numpy array + + Returns: + processed_signal: torch.tensor [B, 80, T] + processed_signal_length [B] + """ + input_signal = torch.tensor(input_signal, dtype=torch.float32) + input_signal_length = torch.tensor(input_signal_length, dtype=torch.int32) + processed_signal, processed_signal_length = self.feature_extractor(input_signal, input_signal_length) + return processed_signal, processed_signal_length + + def setup_inputs(self, input_text, input_media, batch_size): + """ + Args: + input_text: str or List[str] or None + input_media: Tuple[np.array, np.array] + input_signal: audio signal in numpy array [b, -1] + input_signal_length: length of the audio signal in numpy array [b] + batch_size: int + + """ + input_signal, input_signal_length = input_media + processed_signal, processed_signal_length = self.process_audio(input_signal, input_signal_length) + processed_signal = processed_signal.to(self.device) + processed_signal_length = processed_signal_length.to(self.device) + if input_text is None: + input_text = "Q: what's the transcription of the audio? A:" + + if isinstance(input_text, str): + input_text = [input_text] * batch_size + + assert len(input_text) == batch_size + pre_prompt = [''] * batch_size + post_prompt = input_text + decoder_input_ids = None + attention_mask = None + return ( + input_text, + pre_prompt, + post_prompt, + processed_signal, + processed_signal_length, + decoder_input_ids, + attention_mask, + ) + + def load_test_media(self, input_media_path): + """ + Args: + input_media_path: str, path to the audio file + Returns: + input_signal: np.array [1, -1] + input_signal_length: np.array [1] + """ + waveform, sample_rate = sf.read(input_media_path, dtype=np.float32) + input_signal = np.array([waveform], dtype=np.float32) + input_signal_length = np.array([len(waveform)], dtype=np.int32) + return input_signal, input_signal_length + + def get_modality_encoder_features(self, modality_features, attention_mask): + """ + Do inference on the modality encoder engine + Args: + modality_features: dict {'input1': torch.tensor, 'input2': torch.tensor, ..} + attention_mask: None + Returns: + """ + + if attention_mask is not None: + modality_features['attention_mask'] = attention_mask + + tensor_info = [] + for key, tensor in modality_features.items(): + tensor_info.append(TensorInfo(key, torch_dtype_to_trt(tensor.dtype), tensor.shape)) + + output_info = self.modality_encoder_session.infer_shapes(tensor_info) + + outputs = { + t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=self.device) + for t in output_info + } + + ok = self.modality_encoder_session.run(modality_features, outputs, self.stream.cuda_stream) + assert ok, "Runtime execution failed for vision encoder session" + self.stream.synchronize() + + return outputs + + def preprocess(self, warmup, pre_prompt, post_prompt, processed_features, attention_mask, batch_size): + """ + Args: + warmup: bool + pre_prompt: List[str] + post_prompt: List[str] + processed_features: Tuple[torch.tensor, torch.tensor] + processed_signal: torch.tensor [B, 80, T] + processed_signal_length: torch.tensor [B] + attention_mask: None + batch_size: int + Returns: + input_ids: torch.tensor [B, L] + input_lengths: torch.tensor [B] + ptuning_args: List[torch.tensor] + encoded_features: torch.tensor [B, L, D] + """ + if not warmup: + profiler.start(self.modality.capitalize()) + + if not warmup: + profiler.stop(self.modality.capitalize()) + + assert self.model_type == 'salm', f"Invalid model type {self.model_type}" + + processed_features = { + "processed_signal": processed_features[0], + "processed_signal_length": processed_features[1].to(torch.int32), + } + encoded_outputs = self.get_modality_encoder_features(processed_features, attention_mask) + encoded_features, encoded_length = encoded_outputs['encoded'], encoded_outputs['encoded_length'] + pre_input_ids = self.tokenizer(pre_prompt).input_ids + post_input_ids = self.tokenizer(post_prompt).input_ids + input_lengths = [] + input_ids = [] + encoded_length = encoded_length.cpu().numpy() + fake_id_start = self.model.vocab_size + for i in range(batch_size): + feat_len = encoded_length[i] + feat_fake_ids = np.arange(fake_id_start, fake_id_start + feat_len) + cur_input_ids = np.concatenate([pre_input_ids[i], feat_fake_ids, post_input_ids[i]]) + fake_id_start += feat_len + input_lengths.append(len(cur_input_ids)) + input_ids.append(cur_input_ids) + + max_length = max(input_lengths) + # convert input_ids to torch tensor with padding + input_ids = [ + np.pad(ids, (0, max_length - len(ids)), 'constant', constant_values=self.tokenizer.pad_token_id) + for ids in input_ids + ] + input_ids = torch.tensor(input_ids, dtype=torch.int32) + input_lengths = torch.tensor(input_lengths, dtype=torch.int32) + ptuning_args = self.ptuning_setup(encoded_features, input_ids, input_lengths) + + return input_ids, input_lengths, ptuning_args, encoded_features + + def run( + self, + input_text, + input_media=None, + max_new_tokens: int = 30, + batch_size: int = 1, + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + num_beams: int = 1, + run_profiling=False, + check_accuracy=False, + input_signal=None, + input_signal_length=None, + ): + """ + Args: + input_text: str or List[str] or None + input_media: Tuple[np.array, np.array] or None + input_signal: audio signal in numpy array [b, -1] + input_signal_length: length of the audio signal in numpy array [b] + max_new_tokens: int + batch_size: int + top_k: int + top_p: float + temperature: float + repetition_penalty: float + num_beams: int + run_profiling: bool + check_accuracy: bool + """ + if input_media is None: + assert input_signal is not None and input_signal_length is not None + input_media = (input_signal, input_signal_length) + + ( + input_text, + pre_prompt, + post_prompt, + processed_signal, + processed_signal_length, + decoder_input_ids, + attention_mask, + ) = self.setup_inputs(input_text, input_media, batch_size) + processed_media = (processed_signal, processed_signal_length) + + self.generate( + pre_prompt, + post_prompt, + processed_media, + decoder_input_ids, + max_new_tokens, + attention_mask=attention_mask, + warmup=True, + batch_size=batch_size, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + num_beams=num_beams, + ) + num_iters = self.profiling_iterations if run_profiling else 1 + for _ in range(num_iters): + output_text = self.generate( + pre_prompt, + post_prompt, + processed_media, + decoder_input_ids, + max_new_tokens, + attention_mask=attention_mask, + warmup=False, + batch_size=batch_size, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + num_beams=num_beams, + ) + if self.runtime_rank == 0: + self.print_result(input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy) + return output_text diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 3c73da1c0731..2a89b76cc099 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -100,6 +100,7 @@ def __init__( use_python_runtime: bool = True, enable_chunked_context: bool = None, max_tokens_in_paged_kv_cache: int = None, + multi_block_mode: bool = False, ): """ Args: @@ -107,6 +108,7 @@ def __init__( lora_ckpt_list (List[str]): lora checkpoint paths. load_model (bool): load TensorRT-LLM model if the engine files exist in the model_dir. use_python_runtime (bool): whether to use python or c++ runtime. + multi_block_mode (bool): enable faster decoding in multihead attention. Required for long context. Only available when using c++ runtime """ if use_python_runtime: @@ -122,6 +124,7 @@ def __init__( self.use_python_runtime = use_python_runtime self.enable_chunked_context = enable_chunked_context if enable_chunked_context is not None else False self.max_tokens_in_paged_kv_cache = max_tokens_in_paged_kv_cache + self.multi_block_mode = multi_block_mode self.model = None self.tokenizer = None self.n_gpus = None @@ -157,7 +160,6 @@ def export( paged_context_fmha: bool = False, dtype: str = "bfloat16", load_model: bool = True, - enable_multi_block_mode: bool = False, use_lora_plugin: str = None, lora_target_modules: List[str] = None, max_lora_rank: int = 64, @@ -192,7 +194,6 @@ def export( remove_input_padding (bool): enables removing input padding or not. dtype (str): Floating point type for model weights (Supports BFloat16/Float16). load_model (bool): load TensorRT-LLM model after the export. - enable_multi_block_mode (bool): enable faster decoding in multihead attention. Required for long context. use_lora_plugin (str): use dynamic lora or not. lora_target_modules (List[str]): list of the target lora modules. max_lora_rank (int): maximum lora rank. @@ -288,7 +289,6 @@ def export( use_parallel_embedding=use_parallel_embedding, paged_kv_cache=paged_kv_cache, remove_input_padding=remove_input_padding, - enable_multi_block_mode=enable_multi_block_mode, use_lora_plugin=use_lora_plugin, lora_target_modules=lora_target_modules, max_lora_rank=max_lora_rank, @@ -340,7 +340,6 @@ def export( max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, max_prompt_embedding_table_size=max_prompt_embedding_table_size, - enable_multi_block_mode=enable_multi_block_mode, paged_kv_cache=paged_kv_cache, remove_input_padding=remove_input_padding, paged_context_fmha=paged_context_fmha, @@ -960,6 +959,7 @@ def _load(self): use_python_runtime=self.use_python_runtime, enable_chunked_context=self.enable_chunked_context, max_tokens_in_paged_kv_cache=self.max_tokens_in_paged_kv_cache, + multi_block_mode=self.multi_block_mode, ) self._load_prompt_tables() except Exception as error: diff --git a/nemo/export/tensorrt_mm_exporter.py b/nemo/export/tensorrt_mm_exporter.py index b0536a55f95f..d4da0ac34b1c 100644 --- a/nemo/export/tensorrt_mm_exporter.py +++ b/nemo/export/tensorrt_mm_exporter.py @@ -21,8 +21,8 @@ import wrapt from nemo.deploy import ITritonDeployable -from nemo.export.multimodal.build import build_trtllm_engine, build_visual_engine -from nemo.export.multimodal.run import MultimodalModelRunner +from nemo.export.multimodal.build import build_perception_engine, build_trtllm_engine, build_visual_engine +from nemo.export.multimodal.run import MultimodalModelRunner, SpeechllmModelRunner use_deploy = True try: @@ -74,9 +74,13 @@ def __init__( self, model_dir: str, load_model: bool = True, + modality: str = "vision", ): self.model_dir = model_dir self.runner = None + # vision modality is for image and video + assert modality in ["vision", "audio"] + self.modality = modality if load_model: self._load() @@ -128,8 +132,12 @@ def export( dtype=dtype, ) - visual_dir = os.path.join(self.model_dir, "visual_engine") - build_visual_engine(visual_dir, visual_checkpoint_path, model_type, vision_max_batch_size) + if model_type == "salm": + perception_dir = os.path.join(self.model_dir, "perception_engine") + build_perception_engine(perception_dir, visual_checkpoint_path, model_type, vision_max_batch_size) + else: + visual_dir = os.path.join(self.model_dir, "visual_engine") + build_visual_engine(visual_dir, visual_checkpoint_path, model_type, vision_max_batch_size) if load_model: self._load() @@ -164,19 +172,32 @@ def forward( num_beams, ) + def get_input_media_tensors(self): + if self.modality == "vision": + return [Tensor(name="input_media", shape=(-1, -1, -1, 3), dtype=np.uint8)] + elif self.modality == "audio": + return [ + Tensor(name="input_signal", shape=(-1,), dtype=np.single), + Tensor(name="input_signal_length", shape=(1,), dtype=np.intc), + ] + return [] + @property def get_triton_input(self): inputs = ( - Tensor(name="input_text", shape=(-1,), dtype=bytes), - Tensor(name="input_media", shape=(-1, -1, -1, 3), dtype=np.uint8), - Tensor(name="batch_size", shape=(-1,), dtype=np.int_, optional=True), - Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), - Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), - Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), - Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), - Tensor(name="repetition_penalty", shape=(-1,), dtype=np.single, optional=True), - Tensor(name="num_beams", shape=(-1,), dtype=np.int_, optional=True), + [Tensor(name="input_text", shape=(-1,), dtype=bytes)] + + self.get_input_media_tensors() + + [ + Tensor(name="batch_size", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="repetition_penalty", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="num_beams", shape=(-1,), dtype=np.int_, optional=True), + ] ) + inputs = tuple(inputs) return inputs @property @@ -198,6 +219,9 @@ def triton_infer_fn(self, **inputs: np.ndarray): infer_input["input_image"] = ndarray2img(inputs.pop("input_media")[0])[0] elif self.runner.model_type in video_model_list: infer_input["input_image"] = inputs.pop("input_media")[0] + elif self.runner.model_type == "salm": + infer_input["input_signal"] = inputs.pop("input_signal") + infer_input["input_signal_length"] = inputs.pop("input_signal_length")[:, 0] if "batch_size" in inputs: infer_input["batch_size"] = inputs.pop("batch_size")[0][0] if "max_output_len" in inputs: @@ -223,5 +247,9 @@ def triton_infer_fn(self, **inputs: np.ndarray): def _load(self): llm_dir = os.path.join(self.model_dir, "llm_engine") - visual_dir = os.path.join(self.model_dir, "visual_engine") - self.runner = MultimodalModelRunner(visual_dir, llm_dir) + if self.modality == "vision": + visual_dir = os.path.join(self.model_dir, "visual_engine") + self.runner = MultimodalModelRunner(visual_dir, llm_dir, self.modality) + elif self.modality == "audio": + perception_dir = os.path.join(self.model_dir, "perception_engine") + self.runner = SpeechllmModelRunner(perception_dir, llm_dir, self.modality) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index 60d50316e9ed..337a0a4e4e77 100755 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -22,8 +22,6 @@ from tensorrt_llm._utils import pad_vocab_size from tensorrt_llm.functional import non_gated_version from tensorrt_llm.layers import MoeConfig -from tensorrt_llm.models.gpt.config import GPTConfig -from tensorrt_llm.models.llama.config import LLaMAConfig from tensorrt_llm.models.modeling_utils import PretrainedConfig from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import ( @@ -36,12 +34,16 @@ def get_config(decoder_type, config): - if decoder_type == "llama": - return LLaMAConfig(**config) - elif decoder_type == "gpt" or decoder_type == "gptnext": - return GPTConfig(**config) - else: - return PretrainedConfig(**config) + DECODER_CONFIG = { + "llama": tensorrt_llm.models.llama.config.LLaMAConfig, + "gpt": tensorrt_llm.models.gpt.config.GPTConfig, + "gptnext": tensorrt_llm.models.gpt.config.GPTConfig, + "falcon": tensorrt_llm.models.falcon.config.FalconConfig, + "gemma": tensorrt_llm.models.GemmaConfig, + } + config_cls = DECODER_CONFIG[decoder_type] if decoder_type in DECODER_CONFIG else PretrainedConfig + + return config_cls(**config) def prompt_convert(prompt_config, prompt_weights): diff --git a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py index 921c6535a57a..48127a507a58 100644 --- a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +++ b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py @@ -36,7 +36,6 @@ def qnemo_to_tensorrt_llm( use_parallel_embedding: bool = False, paged_kv_cache: bool = True, remove_input_padding: bool = True, - enable_multi_block_mode: bool = False, use_lora_plugin: Optional[str] = None, lora_target_modules: Optional[List[str]] = None, max_lora_rank: int = 64, @@ -93,7 +92,6 @@ def qnemo_to_tensorrt_llm( build_cmd += f"--nccl_plugin {config.dtype} " build_cmd += f"--paged_kv_cache {'enable' if paged_kv_cache else 'disable'} " build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} " - build_cmd += f"--multi_block_mode {'enable' if enable_multi_block_mode else 'disable'} " build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} " if use_fused_mlp: diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 1544fdf032d8..e37c3ba1c845 100755 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -41,11 +41,9 @@ def build_and_save_engine( max_lora_rank=64, lora_target_modules=None, max_prompt_embedding_table_size=0, - enable_multi_block_mode: bool = False, paged_kv_cache: bool = True, remove_input_padding: bool = True, paged_context_fmha: bool = False, - use_custom_all_reduce: bool = True, use_refit: bool = False, max_num_tokens: int = None, max_seq_len: int = None, @@ -66,8 +64,6 @@ def build_and_save_engine( plugin_config = PluginConfig() plugin_config.gpt_attention_plugin = gpt_attention_plugin plugin_config.gemm_plugin = gemm_plugin - plugin_config.set_nccl_plugin(use_custom_all_reduce=use_custom_all_reduce) - plugin_config.multi_block_mode = enable_multi_block_mode if paged_kv_cache: plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block) else: diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index 14ad0be699bb..852eddc6a468 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -144,6 +144,7 @@ def _load( use_python_runtime: bool = True, enable_chunked_context: bool = False, max_tokens_in_paged_kv_cache: int = None, + multi_block_mode: bool = False, ): """The impl of `load` API for on a single GPU worker.""" try: @@ -164,6 +165,11 @@ def _load( runtime_rank = tensorrt_llm.mpi_rank() if use_python_runtime: + if enable_chunked_context: + logging.warning("enable_chunked_context is disabled when using python runtime") + if multi_block_mode: + logging.warning("multi_block_mode is disabled when using python runtime") + decoder = ModelRunner.from_dir( engine_dir=engine_dir, lora_dir=lora_ckpt_list, @@ -183,6 +189,7 @@ def _load( max_beam_width=max_beam_width, enable_chunked_context=enable_chunked_context, max_tokens_in_paged_kv_cache=max_tokens_in_paged_kv_cache, + multi_block_mode=multi_block_mode, debug_mode=False, ) @@ -296,6 +303,7 @@ def load( use_python_runtime: bool = True, enable_chunked_context: bool = False, max_tokens_in_paged_kv_cache: int = None, + multi_block_mode: bool = False, ) -> TensorrtLLMHostContext: """Loaded the compiled LLM model and run it. @@ -315,6 +323,7 @@ def load( use_python_runtime, enable_chunked_context, max_tokens_in_paged_kv_cache, + multi_block_mode, ) executor = None elif tensorrt_llm.mpi_world_size() > 1: diff --git a/nemo/lightning/README.md b/nemo/lightning/README.md new file mode 100644 index 000000000000..7b9266d3fa30 --- /dev/null +++ b/nemo/lightning/README.md @@ -0,0 +1,13 @@ +# NeMo Lightning + +The NeMo Lightning directory provides custom PyTorch Lightning-compatible objects for seamlessly training NeMo 2.0 models using PTL. NeMo 2.0 models +are implemented using [Megatron Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). NeMo Lightning provides the bridge between higher-level, object-oriented PTL APIs and lower-level Megatron APIs. +For detailed tutorials and documentation on NeMo 2.0, refer to the [docs](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo_2.0/index.html). + +Some of the helpful classes provided here include: +- [`Trainer`](./pytorch/trainer.py): A lightweight wrapper around PTL's `Trainer` object which provides some additional support for capturing the arguments used to initialized the trainer. More information on NeMo 2's serialization mechanisms is available [here](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo_2.0/design/serialization.html). +- [`MegatronStrategy`](./pytorch/strategies.py): A PTL strategy that enables training of Megatron models on NVIDIA GPUs. +- [`MegatronParallel`](./megatron_parallel.py): Class which sets up and manages Megatron's distributed model parallelism. +- [`MegatronMixedPrecision`](./pytorch/plugins/mixed_precision.py): A specialized precision plugin for training Megatron-based models in PTL. + +More information on `MegatronStrategy`, `MegatronParallel`, and `MegatronMixedPrecision` can be found in [this document](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo_2.0/design/megatron.html). diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py index 4d31f020c44a..4315b3211bf7 100644 --- a/nemo/lightning/io/api.py +++ b/nemo/lightning/io/api.py @@ -1,61 +1,13 @@ -import json from pathlib import Path -from pydoc import locate from typing import Any, Callable, Optional, Type, TypeVar import fiddle as fdl import pytorch_lightning as pl from fiddle._src.experimental import serialization -from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, track_io +from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, load from nemo.lightning.io.pl import TrainerContext -CkptType = TypeVar("CkptType") - - -def load(path: Path, output_type: Type[CkptType] = Any) -> CkptType: - """ - Loads a configuration from a pickle file and constructs an object of the specified type. - - Args: - path (Path): The path to the pickle file or directory containing 'io.pkl'. - output_type (Type[CkptType]): The type of the object to be constructed from the loaded data. - - Returns - ------- - CkptType: An instance of the specified type constructed from the loaded configuration. - - Raises - ------ - FileNotFoundError: If the specified file does not exist. - - Example: - loaded_model = load("/path/to/model", output_type=MyModel) - """ - del output_type # Just for type-hint - - _path = Path(path) - if hasattr(_path, 'is_dir') and _path.is_dir(): - _path = Path(_path) / "io.json" - elif hasattr(_path, 'isdir') and _path.isdir: - _path = Path(_path) / "io.json" - - if not _path.is_file(): - raise FileNotFoundError(f"No such file: '{_path}'") - - ## add IO functionality to custom objects present in the json file - with open(_path) as f: - j = json.load(f) - for obj, val in j["objects"].items(): - clss = ".".join([val["type"]["module"], val["type"]["name"]]) - if not serialization.find_node_traverser(locate(clss)): - track_io(locate(clss)) - - with open(_path, "rb") as f: - config = serialization.load_json(f.read()) - - return fdl.build(config) - def load_context(path: Path) -> TrainerContext: """ diff --git a/nemo/lightning/io/artifact/base.py b/nemo/lightning/io/artifact/base.py index 9119b2474b17..a997df42f843 100644 --- a/nemo/lightning/io/artifact/base.py +++ b/nemo/lightning/io/artifact/base.py @@ -11,7 +11,7 @@ def __init__(self, attr: str, required: bool = True): self.required = required @abstractmethod - def dump(self, value: ValueT, path: Path) -> ValueT: + def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT: pass @abstractmethod diff --git a/nemo/lightning/io/artifact/file.py b/nemo/lightning/io/artifact/file.py index 0bd4f48dc17f..76bd0c6003a6 100644 --- a/nemo/lightning/io/artifact/file.py +++ b/nemo/lightning/io/artifact/file.py @@ -6,8 +6,8 @@ class PathArtifact(Artifact[Path]): - def dump(self, value: Path, path: Path) -> Path: - new_value = copy_file(value, path) + def dump(self, value: Path, absolute_dir: Path, relative_dir: Path) -> Path: + new_value = copy_file(value, absolute_dir, relative_dir) return new_value def load(self, path: Path) -> Path: @@ -15,15 +15,16 @@ def load(self, path: Path) -> Path: class FileArtifact(Artifact[str]): - def dump(self, value: str, path: Path) -> str: - new_value = copy_file(value, path) + def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str: + new_value = copy_file(value, absolute_dir, relative_dir) return str(new_value) def load(self, path: str) -> str: return path -def copy_file(src: Union[Path, str], dst: Union[Path, str]): - output = Path(dst) / Path(src).name +def copy_file(src: Union[Path, str], path: Union[Path, str], relative_dst: Union[Path, str]): + relative_path = Path(relative_dst) / Path(src).name + output = Path(path) / relative_path shutil.copy2(src, output) - return output + return relative_path diff --git a/nemo/lightning/io/artifact/pickle.py b/nemo/lightning/io/artifact/pickle.py index 31ed7e36ac93..61a9c82237fc 100644 --- a/nemo/lightning/io/artifact/pickle.py +++ b/nemo/lightning/io/artifact/pickle.py @@ -7,12 +7,12 @@ class PickleArtifact(Artifact[Any]): - def dump(self, value: Any, path: Path) -> Path: - file = self.file_path(path) - with open(file, "wb") as f: + def dump(self, absolute_dir: Path, relative_dir: Path) -> Path: + relative_file = self.file_path(relative_dir) + with open(Path(absolute_dir) / relative_file, "wb") as f: dump(value, f) - return file + return relative_file def load(self, path: Path) -> Any: with open(self.file_path(path), "rb") as f: diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index 69368599682e..512f3bc4f12e 100644 --- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -145,6 +145,7 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] = pl.Trainer: The trainer configured with the model and strategy. """ from nemo.lightning import MegatronStrategy, Trainer + from nemo.lightning._strategy_lib import megatron_lazy_init_context _trainer = trainer or Trainer( devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False) @@ -155,7 +156,7 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] = if not model.state_dict(): _trainer.strategy.lazy_init = True - with _trainer.init_module(): + with _trainer.init_module(), megatron_lazy_init_context(model.config): model.configure_model() return _trainer diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index d0d4d0243ff7..e249e2e318b6 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -1,5 +1,6 @@ import functools import inspect +import json import shutil import threading import types @@ -7,11 +8,13 @@ from copy import deepcopy from dataclasses import is_dataclass from pathlib import Path +from pydoc import locate from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union import fiddle as fdl import fiddle._src.experimental.dataclasses as fdl_dc -from cloudpickle import dump, load +from cloudpickle import dump +from cloudpickle import load as pickle_load from fiddle._src.experimental import serialization from typing_extensions import Self @@ -21,6 +24,7 @@ from nemo.lightning.io.fdl_torch import enable as _enable_ext ConnT = TypeVar('ConnT', bound=ModelConnector) +CkptType = TypeVar("CkptType") _enable_ext() @@ -136,21 +140,24 @@ def io_dump(self, output: Path): will be stored. """ output_path = Path(output) - artifacts_dir = output_path / "artifacts" + local_artifacts_dir = "artifacts" + artifacts_dir = output_path / local_artifacts_dir artifacts_dir.mkdir(parents=True, exist_ok=True) # Store artifacts directory in thread-local storage - _thread_local.artifacts_dir = artifacts_dir + _thread_local.local_artifacts_dir = local_artifacts_dir + _thread_local.output_path = output_path config_path = output_path / "io.json" with open(config_path, "w") as f: io = deepcopy(self.__io__) - _artifact_transform(io, artifacts_dir) + _artifact_transform_save(io, output_path, local_artifacts_dir) json = serialization.dump_json(io) f.write(json) # Clear thread-local storage after io_dump is complete - del _thread_local.artifacts_dir + del _thread_local.local_artifacts_dir + del _thread_local.output_path # Check if artifacts directory is empty and delete if so if not any(artifacts_dir.iterdir()): @@ -293,13 +300,8 @@ def import_ckpt(self, path: str, overwrite: bool = False, base_path: Optional[Pa """ connector = self._get_connector(path) ckpt_path: Path = connector.local_path(base_path=base_path) - # If already in multiproc environment (e.g. due to torchrun invocation) run only on RANK = 0 - from nemo.utils.get_rank import is_global_rank_zero - - if is_global_rank_zero(): - ckpt_path = connector(ckpt_path, overwrite=overwrite) - connector.on_import_ckpt(self) - + ckpt_path = connector(ckpt_path, overwrite=overwrite) + connector.on_import_ckpt(self) return ckpt_path @classmethod @@ -481,23 +483,28 @@ def _io_flatten_object(instance): try: serialization.dump_json(instance.__io__) except (serialization.UnserializableValueError, AttributeError) as e: - if not hasattr(_thread_local, "artifacts_dir"): + if not hasattr(_thread_local, "local_artifacts_dir") or not hasattr(_thread_local, "output_path"): raise e - artifact_dir = _thread_local.artifacts_dir - artifact_path = artifact_dir / f"{uuid.uuid4()}" + local_artifact_path = Path(_thread_local.local_artifacts_dir) / f"{uuid.uuid4()}" + output_path = _thread_local.output_path + artifact_path = output_path / local_artifact_path with open(artifact_path, "wb") as f: dump(getattr(instance, "__io__", instance), f) - return (str(artifact_path),), None + return (str(local_artifact_path),), None return instance.__io__.__flatten__() def _io_unflatten_object(values, metadata): + + assert hasattr(_thread_local, "output_dir") + output_dir = _thread_local.output_dir + if len(values) == 1: pickle_path = values[0] - with open(pickle_path, "rb") as f: - return load(f) + with open(Path(output_dir) / pickle_path, "rb") as f: + return pickle_load(f) return fdl.Config.__unflatten__(values, metadata) @@ -511,19 +518,82 @@ def _io_path_elements_fn(x): return x.__io__.__path_elements__() -def _artifact_transform(cfg: fdl.Config, output_path: Path): +def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "artifacts"): for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []): current_val = getattr(cfg, artifact.attr) if current_val is None: if artifact.required: raise ValueError(f"Artifact '{artifact.attr}' is required but not provided") continue - new_val = artifact.dump(current_val, output_path) + ## dump artifact and return the relative path + new_val = artifact.dump(current_val, output_path, relative_dir) + setattr(cfg, artifact.attr, new_val) + + for attr in dir(cfg): + try: + if isinstance(getattr(cfg, attr), fdl.Config): + _artifact_transform_save(getattr(cfg, attr), output_path=output_path, relative_dir=relative_dir) + except ValueError: + pass + + +def _artifact_transform_load(cfg: fdl.Config, path: Path): + for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []): + current_val = getattr(cfg, artifact.attr) + ## replace local path with absolute one + new_val = str(Path(path) / current_val) setattr(cfg, artifact.attr, new_val) for attr in dir(cfg): try: if isinstance(getattr(cfg, attr), fdl.Config): - _artifact_transform(getattr(cfg, attr), output_path=output_path) + _artifact_transform_load(getattr(cfg, attr), path=path) except ValueError: pass + + +def load(path: Path, output_type: Type[CkptType] = Any) -> CkptType: + """ + Loads a configuration from a pickle file and constructs an object of the specified type. + + Args: + path (Path): The path to the pickle file or directory containing 'io.pkl'. + output_type (Type[CkptType]): The type of the object to be constructed from the loaded data. + + Returns + ------- + CkptType: An instance of the specified type constructed from the loaded configuration. + + Raises + ------ + FileNotFoundError: If the specified file does not exist. + + Example: + loaded_model = load("/path/to/model", output_type=MyModel) + """ + del output_type # Just for type-hint + + _path = Path(path) + _thread_local.output_dir = _path + + if hasattr(_path, 'is_dir') and _path.is_dir(): + _path = Path(_path) / "io.json" + elif hasattr(_path, 'isdir') and _path.isdir: + _path = Path(_path) / "io.json" + + if not _path.is_file(): + raise FileNotFoundError(f"No such file: '{_path}'") + + ## add IO functionality to custom objects present in the json file + with open(_path) as f: + j = json.load(f) + for obj, val in j["objects"].items(): + clss = ".".join([val["type"]["module"], val["type"]["name"]]) + if not serialization.find_node_traverser(locate(clss)): + track_io(locate(clss)) + + with open(_path, "rb") as f: + config = serialization.load_json(f.read()) + _artifact_transform_load(config, path) + + return fdl.build(config) diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index d0749fbeead7..f43d24792c1a 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -126,13 +126,22 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure) self.validated_consistency = True - return dist_checkpointing.save( - sharded_state_dict=checkpoint, - checkpoint_dir=checkpoint_dir, - sharded_strategy=self.save_sharded_strategy, - validate_access_integrity=validate_sharding_integrity, - async_sharded_save=self.async_save, - ) + + try: + return dist_checkpointing.save( + sharded_state_dict=checkpoint, + checkpoint_dir=checkpoint_dir, + sharded_strategy=self.save_sharded_strategy, + validate_access_integrity=validate_sharding_integrity, + async_sharded_save=self.async_save, + ) + except: + logging.error(f"Failed to save checkpoint to {checkpoint_dir}") + # Do cleanup. + import shutil + + shutil.rmtree(checkpoint_dir) + raise @override def load_checkpoint( diff --git a/nemo/lightning/io/state.py b/nemo/lightning/io/state.py index 9fd81a960358..18e0865171c7 100644 --- a/nemo/lightning/io/state.py +++ b/nemo/lightning/io/state.py @@ -255,7 +255,6 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX: if multiple_sources: for target_index, target_match in np.ndenumerate(target_matches): source_match = source_matches[target_index] - if accepts_var_args: source_values = [source_dict[k] for k in source_match] target_dict[target_match] = self.call_transform(ctx, *source_values) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 56146498b539..dd10a726e67a 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -231,7 +231,24 @@ def forward( pipeline = self.pipeline - use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch' + # FIXME: cleanup the following code block which is here for backwards compatibility with nemo1. The "batch" + # sampler is a nemo1 sampler. It requires some custom code here to use (if use_global_batch_sampler). + # by default we shouldn't use this "batch" sampler probably. + if getattr(self.trainer, "datamodule", None) is not None: + use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch' + elif getattr(self.trainer, "predict_dataloaders", None) is not None: + from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( # noqa: I001 + MegatronPretrainingBatchSampler, + ) + + # The batch_sampler gets injected into the dataloader by the data_sampler. When doing predict without a + # datamodule we can look inside the dataloader's batch_sampler to see if it is the nemo1 style sampler + # that we need to handle specially below. + use_global_batch_sampler = isinstance( + self.trainer.predict_dataloaders.batch_sampler, MegatronPretrainingBatchSampler + ) + else: + raise ValueError("Unsure how to check for nemo1 global_batch_sampler status. TODO maybe default to False?") if use_global_batch_sampler: from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index 5ba2c39f9cff..e5cd45181cc7 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -30,12 +30,16 @@ class NeMoLogger(IOMixin): log_global_rank_0_only (bool): Log only on global rank 0. files_to_copy (Optional[List[str]]): List of files to copy to log directory. update_logger_directory (bool): Whether to update logger directory to write to `exp_dir`. - If True, the `save_dir` passed to the logger will be treated as a relative path and - the logger will be reconfigured to write to `exp_dir / save_dir`. This ensures that - all output from an experiment is written to a common directory. If False, the logger's - save_dir will not be overwritten. This argument applies only to TensorBoardLogger and - WandbLogger instances. + If True, the `save_dir` passed to the logger will be reconfigured to write to `exp_dir / save_dir`. + This ensures that all output from an experiment is written to a common directory. + If False, the logger's save_dir will not be overwritten. + This argument applies only to TensorBoardLogger and WandbLogger instances. ckpt (Optional[ModelCheckpoint]): Model checkpoint callback. + tensorboard: (Optional[TensorBoardLogger]): A PyTorch Lightning TensorBoardLogger instance + to add to the trainer. + wandb (Optional[WandbLogger]): A PyTorch Lightning WandBLogger instance + to add to the trainer. + extra_loggers(Optional[List[Logger]]): Any additional loggers to add to the trainer. """ name: str = "default" @@ -55,7 +59,7 @@ class NeMoLogger(IOMixin): def __post_init__(self): if self.log_local_rank_0_only is True and self.log_global_rank_0_only is True: raise ValueError( - f"Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. Please set either one or neither." + "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. Please set either one or neither." ) def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = False, task_config=None): @@ -69,7 +73,6 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = AppState: The application state with updated log directory and other settings. """ from nemo.constants import NEMO_ENV_VARNAME_VERSION - from nemo.utils.exp_manager import check_explicit_log_dir from nemo.utils.get_rank import is_global_rank_zero self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) @@ -96,7 +99,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = # Default dir to ./nemo_experiments if None was passed _dir = self.dir if self.dir is None: - _dir = str(Path.cwd() / 'nemo_experiments') + _dir = str(Path.cwd() / "nemo_experiments") if not self.name: self.name = "default" @@ -110,7 +113,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = version = None elif is_global_rank_zero(): if self.use_datetime_version: - version = time.strftime('%Y-%m-%d_%H-%M-%S') + version = time.strftime("%Y-%m-%d_%H-%M-%S") if version: if is_global_rank_zero(): os.environ[NEMO_ENV_VARNAME_VERSION] = version @@ -126,7 +129,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = app_state.cmd_args = sys.argv os.makedirs(log_dir, exist_ok=True) # Cannot limit creation to global zero as all ranks write to own log file - logging.info(f'Experiments will be logged at {log_dir}') + logging.info(f"Experiments will be logged at {log_dir}") if task_config and is_global_rank_zero(): self._handle_task_config(task_config, log_dir) @@ -153,8 +156,7 @@ def _setup_trainer_loggers(self, trainer, dir, version): for logger in trainer.loggers: if isinstance(logger, TensorBoardLogger): logger._version = version or "" - logger._root_dir = Path(dir) / logger.save_dir - trainer.logger._name = self.name + logger._root_dir = Path(dir) / os.path.relpath(logger.save_dir) logging.warning( f'"update_logger_directory" is True. Overwriting tensorboard logger "save_dir" to {logger._root_dir}' ) @@ -162,8 +164,6 @@ def _setup_trainer_loggers(self, trainer, dir, version): logger._id = version or "" logger._save_dir = Path(dir) / logger.save_dir logger._wandb_init["dir"] = Path(dir) / logger.save_dir - logger._wandb_init["name"] = self.name - logger._name = self.name logging.warning( f'"update_logger_directory" is True. Overwriting wandb logger "save_dir" to {logger._save_dir}' ) @@ -207,8 +207,8 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None): if callback.dirpath is None: callback.dirpath = Path(log_dir / "checkpoints") if callback.filename is None: - callback.filename = f'{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}' - ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + '-last' + callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}" + ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + "-last" def _handle_task_config(self, task_config, log_dir): try: @@ -219,7 +219,7 @@ def _handle_task_config(self, task_config, log_dir): with open(log_dir / "task.json", "w") as f: f.write(task_json) except Exception as e: - logging.warning(f'Saving task config failed: {e}. Skipping saving') + logging.warning(f"Saving task config failed: {e}. Skipping saving") def _setup_file_logging(self, log_dir): """Set up file logging based on rank settings.""" @@ -229,7 +229,7 @@ def _setup_file_logging(self, log_dir): # This is set if the env var NEMO_TESTING is set to True. nemo_testing = get_envbool(NEMO_ENV_VARNAME_TESTING, False) - log_file = log_dir / f'nemo_log_globalrank-{self.global_rank}_localrank-{self.local_rank}.txt' + log_file = log_dir / f"nemo_log_globalrank-{self.global_rank}_localrank-{self.local_rank}.txt" if self.log_local_rank_0_only and not nemo_testing and self.local_rank == 0: logging.add_file_handler(log_file) diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py index 00637c9d57d4..dd2908e6f5e6 100644 --- a/nemo/lightning/pytorch/callbacks/__init__.py +++ b/nemo/lightning/pytorch/callbacks/__init__.py @@ -1,3 +1,6 @@ +from nemo.lightning.pytorch.callbacks.ddp_parity_checker import DdpParityChecker +from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform from nemo.lightning.pytorch.callbacks.nsys import NsysCallback @@ -6,8 +9,8 @@ from nemo.lightning.pytorch.callbacks.progress_bar import MegatronProgressBar from nemo.lightning.pytorch.callbacks.progress_printer import ProgressPrinter - __all__ = [ + "MemoryProfileCallback", "ModelCheckpoint", "ModelTransform", "PEFT", @@ -15,4 +18,6 @@ "MegatronProgressBar", "ProgressPrinter", "PreemptionCallback", + "DdpParityChecker", + "GarbageCollectionCallback", ] diff --git a/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py b/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py new file mode 100644 index 000000000000..b5c2127433d7 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py @@ -0,0 +1,74 @@ +from functools import cache + +import torch +from megatron.core.utils import check_param_hashes_across_dp_replicas +from pytorch_lightning.callbacks.callback import Callback + +from nemo.lightning import io +from nemo.utils import logging + + +@cache +def pl_has_dist_opt_with_ovelap(trainer): + optim_config = getattr(getattr(trainer.strategy.model, 'optim', None), 'config', None) + if not getattr(optim_config, 'use_distributed_optimizer', False): + return False + if not getattr(optim_config, 'overlap_param_gather', False): + return False + return True + + +def pl_check_param_hashes_across_dp_replicas(trainer): + if pl_has_dist_opt_with_ovelap(trainer): + for opt in self.optimizers: + opt.disable_pre_hook() + import megatron.core.parallel_state as mp + + res = check_param_hashes_across_dp_replicas([trainer.strategy.model]) + torch.distributed.barrier() + + all_res = [False for _ in range(mp.get_data_parallel_world_size())] + + torch.distributed.all_gather_object(all_res, res, group=mp.get_data_parallel_group_gloo()) + + if pl_has_dist_opt_with_ovelap(trainer): + for opt in self.optimizers: + opt.enable_pre_hook() + return all(all_res) + + +class DdpParityChecker(Callback, io.IOMixin): + """ + This callback enables weight parity checkping across DDP replicas with Mcore models. + + User can specify their desired interval for weights to be checked via the `interval` parameter. + + Args: + dir (Optional[str]): Directory to store the memory profile dump + + Example: + >>> callback = DdpParityChecker(interval=10) + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__(self, interval: int = 0): + """ + interval (int): How frequently to check DDP weights for errors. Default to 0 (off). + """ + assert interval > 0, "Expected interval to be > 0. A zero interval makes DdpParityChecker a no-op." + self.interval = interval + self.step = 0 + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, unused=0) -> None: + if self.step == self.interval - 1: + if pl_check_param_hashes_across_dp_replicas(trainer): + logging.info(f"DDP Param parity check passed for batch-id= {batch_idx}") + else: + trainer.should_stop = True + trainer.limit_val_batches = 0 + logging.info(f"DDP Param parity check FAILED for batch-id= {batch_idx}") + self.step = (self.step + 1) % self.interval + + def on_train_end(self, trainer, pl_module) -> None: + pl_check_param_hashes_across_dp_replicas(trainer) + logging.info("DDP Param parity check passed at end of training.") diff --git a/nemo/lightning/pytorch/callbacks/garbage_collection.py b/nemo/lightning/pytorch/callbacks/garbage_collection.py new file mode 100644 index 000000000000..a2b2bb6498a3 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/garbage_collection.py @@ -0,0 +1,68 @@ +import gc +from typing import Any + +import pytorch_lightning as pl +from nemo.utils import logging + + +class GarbageCollectionCallback(pl.Callback): + """Callback for synchronized manual Garbage Collection. This is required for distributed training + as all processes on different rank need to synchronize to garbage collect at the same time, without which + one process might hog or straggle all the rest of the processes. + + Migration from NeMo 1.0: + When mitrating from NeMo1, + - gc_interval = 0 implied no GC, simply do not add this callback to the trainer + - gc_interval > 0, this config is maps => gc_interval_train + + - env-var:NEMO_MANUAL_GC_IN_VALIDATION=0 or doesn't exist => Set gc_interval_val to a very high value that it does not practically run. + - env-var:NEMO_MANUAL_GC_IN_VALIDATION=1 => Set gc_interval_val to the same value as gc_interval + + Moving from boolean flag (NEMO_MANUAL_GC_IN_VALIDATION) to integer is to allow user to set a specific value based on the size of the + validation datasets. + + Note: This callback does not run gc at the start or the end of training or validation. + """ + + def __init__(self, gc_interval_train, gc_interval_val) -> None: + """_summary_ + + Args: + gc_interval (int, mandatory): Number of global train steps at which garbage collection is done. + gc_interval_val (int, mandatory): Number of global validation steps at which garbage collection is done. + """ + assert gc_interval_train > 0, "gc_interval_train should be an integer value larger than 0." + assert gc_interval_val > 0, "gc_interval_val should be an integer value larger than 0." + + super().__init__() + self.gc_interval_train = gc_interval_train + self.gc_interval_val = gc_interval_val + # As garbage collection is manually controlled, disable automatic garbage collector. + gc.disable() + # This counter is required as pl does not have a native way to track the validation step counter. + self.validation_global_step = 0 + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: pl.utilities.types.STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + if trainer.global_step % self.gc_interval_train == 0: + logging.info(f"Running garbage collection at train global_step: {trainer.global_step}") + gc.collect() + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: pl.utilities.types.STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + self.validation_global_step += 1 + if self.validation_global_step % self.gc_interval_val == 0: + logging.info(f"Running garbage collection at validation step: {self.validation_global_step}") + gc.collect() diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py new file mode 100644 index 000000000000..089479637f61 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -0,0 +1,78 @@ +import os + +import torch +from pytorch_lightning.callbacks.callback import Callback +from torch.utils.viz._cycles import warn_tensor_cycles + +from nemo.lightning import io +from nemo.utils import logging +from nemo.utils.get_rank import get_rank + + +class MemoryProfileCallback(Callback, io.IOMixin): + """ + This callback enables recording a timeline of memory allocations during training. + The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz + + More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/). + + Args: + dir (Optional[str]): Directory to store the memory profile dump + warn_cycles (Optional[bool]): Whether to enable [reference cycle detection](https://pytorch.org/blog/understanding-gpu-memory-2/) + rank (Optional[list[int]]): List of ranks to collect snapshot on, defaults to all if list is empty + + Example: + >>> callback = MemoryProfileCallback(dir="/mem_profile", ranks=[0]) + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__(self, dir: str = "/mem_profile", warn_cycles=True, ranks=[]): + + self.dir = dir + self.ranks = ranks + + os.makedirs(self.dir, exist_ok=True) + logging.info(f"Torch memory profiles will be written to: {self.dir}") + + if warn_cycles: + logging.info("Enabling reference cycle detector") + warn_tensor_cycles() + + def enable_on_rank(self) -> bool: + if not self.ranks: + return True + return get_rank() in self.ranks + + def setup(self, trainer, pl_module, stage) -> None: + """PyTorch Lightning hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end + We use it here to start recording the memory profiler. + """ + + if trainer.max_steps > 1000: + logging.warning( + f"Memory profiling creates snapshots during the entire training process, \ + where every iteration increases the size of the snapshot. \ + Try reducing trainer.max_steps to avoid running into issues" + ) + + if torch.distributed.is_initialized() and self.enable_on_rank(): + torch.cuda.memory._record_memory_history(max_entries=100000) + + def on_train_end(self, trainer, pl_module) -> None: + """PyTorch Lightning hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end + We use it here to finish memory profiling and write the snapshot. + """ + + logging.info( + f"on_train_batch_end rank: {get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}" + ) + + if torch.distributed.is_initialized() and self.enable_on_rank(): + rank = get_rank() + _snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle" + logging.info(f"Writing memory profile snapshot to {_snapshot_path}") + torch.cuda.memory._dump_snapshot(f"{_snapshot_path}") + torch.cuda.memory._record_memory_history(enabled=None) + logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}") diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 13a0caa98f0c..bacb7cb0af5c 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -26,8 +26,10 @@ def __init__( dataloader_type: Literal["single", "cyclic", "batch"] = "single", init_consumed_samples: int = 0, init_global_step: int = 0, + output_log: bool = True, ): self.seq_len = seq_len + self.output_log = output_log self.micro_batch_size = micro_batch_size self.global_batch_size = global_batch_size self.rampup_batch_size = rampup_batch_size @@ -95,12 +97,14 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul self.prev_global_batch_size = self.current_global_batch_size consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step) - pl_module.log( - 'consumed_samples', - consumed_samples, - prog_bar=True, - batch_size=1, - ) + if self.output_log: + # You may need to turn off logging, for example when doing trainer.predict(model, data) + pl_module.log( + 'consumed_samples', + consumed_samples, + prog_bar=True, + batch_size=1, + ) self.prev_consumed_samples = consumed_samples @@ -108,12 +112,14 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul consumed_samples=consumed_samples, consistency_check=False, ) - pl_module.log( - "global_batch_size", - self.current_global_batch_size, - prog_bar=True, - batch_size=1, - ) + if self.output_log: + # You may need to turn off logging, for example when doing trainer.predict(model, data) + pl_module.log( + "global_batch_size", + self.current_global_batch_size, + prog_bar=True, + batch_size=1, + ) self.if_first_step = 1 @property diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 668b088a4864..d6ef18770fa4 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -484,12 +484,10 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP ) if self.log_train_loss: - # p2p now, broadcast later at ckpt + # p2p now, broadcast later at ckpt. only with pp, some ranks will log 0.0 + # WHICH IS OK because we broadcast later at checkpoint time _strategy_lib._sync_from_last_pipeline_stage(out, broadcast=False) - if torch.distributed.get_rank() == 0: - self.lightning_module.log( - 'reduced_train_loss', out, prog_bar=True, rank_zero_only=True, batch_size=1 - ) + self.lightning_module.log('reduced_train_loss', out, prog_bar=True, batch_size=1, sync_dist=False) return out diff --git a/nemo/lightning/run/__init__.py b/nemo/lightning/run/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/lightning/run/plugins.py b/nemo/lightning/run/plugins.py new file mode 100644 index 000000000000..0f6a76d4799f --- /dev/null +++ b/nemo/lightning/run/plugins.py @@ -0,0 +1,165 @@ +import copy +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Optional + +import nemo_run as run +import yaml +from nemo_run.core.serialization.yaml import YamlSerializer +from pytorch_lightning import Callback +from pytorch_lightning.loggers import WandbLogger + +from nemo.lightning.pytorch.callbacks import NsysCallback, PreemptionCallback +from nemo.utils import logging + +# This file contains plugins based on NeMo-Run's run.Plugin API. +# Plugins operate both on a configured task and an executor at the same time, and are specific to NeMo-Run. +# If you are adding functionality that goes directly into the Pytorch Lightning trainer, you may consider adding a callback instead of a plugin. + + +def _merge_callbacks(partial: run.Partial, callbacks: list[run.Config[Callback]]): + if hasattr(partial, "trainer"): + if hasattr(partial.trainer, "callbacks"): + for callback in callbacks: + if callback not in partial.trainer.callbacks: + partial.trainer.callbacks.append(callback) + else: + partial.trainer.callbacks = copy.deepcopy(callbacks) + + +@dataclass(kw_only=True) +class PreemptionPlugin(run.Plugin): + """ + A plugin for setting up Preemption callback and preemption signals. + + Args: + preempt_time (int): The time, in seconds, before the task's time limit at which the executor + will send a SIGTERM preemption signal. This allows tasks to be gracefully + stopped before reaching their time limit, reducing waste and + promoting fair resource usage. The default value is 300 seconds (5 minutes). + This is only supported for ``run.SlurmExecutor``. + callbacks (list[run.Config[Callback]]): A list of callback configurations that the plugin + will merge with the task's existing callbacks. + By default, the list includes NeMo's preemption callback. + """ + + preempt_time: int = 300 + callbacks: list[run.Config[Callback]] = field(default_factory=lambda: [run.Config(PreemptionCallback)]) + + def setup(self, task: run.Partial | run.Script, executor: run.Executor): + if isinstance(task, run.Script): + logging.warning( + f"The {self.__class__.__name__} will have no effect on the task as it's an instance of run.Script" + ) + return + + if isinstance(executor, run.SlurmExecutor): + # Sends a SIGTERM self.preempt_time seconds before hitting time limit + logging.info( + f"{self.__class__.__name__} will send a SIGTERM {self.preempt_time} seconds before the job's time limit for your Slurm executor." + ) + executor.signal = f"TERM@{self.preempt_time}" + + _merge_callbacks(task, callbacks=self.callbacks) + + +@dataclass(kw_only=True) +class NsysPlugin(run.Plugin): + """ + A plugin for nsys profiling. + + The NsysPlugin allows you to profile your run using nsys. + You can specify when to start and end the profiling, on which ranks to run the profiling, + and what to trace during profiling. + + Args: + start_step (int): The step at which to start the nsys profiling. + end_step (int): The step at which to end the nsys profiling. + ranks (Optional[list[int]]): The ranks on which to run the nsys profiling. If not specified, + profiling will be run on rank 0. + nsys_trace (Optional[list[str]]): The events to trace during profiling. If not specified, + 'nvtx' and 'cuda' events will be traced. + """ + + start_step: int + end_step: int + ranks: Optional[list[int]] = None + nsys_trace: Optional[list[str]] = None + + def setup(self, task: run.Partial | run.Script, executor: run.Executor): + if isinstance(task, run.Partial): + nsys_callback = run.Config( + NsysCallback, + start_step=self.start_step, + end_step=self.end_step, + ranks=self.ranks or [0], + ) + callbacks: list[run.Config[Callback]] = [nsys_callback] # type: ignore + _merge_callbacks(task, callbacks=callbacks) + + launcher = executor.get_launcher() + launcher.nsys_profile = True + launcher.nsys_trace = self.nsys_trace or ["nvtx", "cuda"] + + +@dataclass(kw_only=True) +class WandbPlugin(run.Plugin): + """ + A plugin for setting up Weights & Biases. + + This plugin sets a ``WandbLogger`` to ``NeMoLogger``'s ``wandb`` arg, + which in turn initializes the Pytorch Lightning `WandbLogger `_. + + This plugin is only activated if the ``WANDB_API_KEY`` environment variable is set. + The ``WANDB_API_KEY`` environment variables will also be set in the executor's environment variables. + Follow https://docs.wandb.ai/quickstart to retrieve your ``WANDB_API_KEY``. + + If `log_task_config` is True, the plugin will log the task configuration as a config dictionary + to the Weights and Biases logger. + + Args: + name (str): The name for the Weights & Biases run. + logger_fn (Callable[..., run.Config[WandbLogger]]): A callable that returns a Config of ``WandbLogger`` + log_task_config (bool, optional): Whether to log the task configuration to the logger. + Defaults to True. + + Raises: + logging.warning: If the task is an instance of `run.Script`, as the plugin has no effect on such tasks. + """ + + name: str + logger_fn: Callable[..., run.Config[WandbLogger]] + log_task_config: bool = True + + def setup(self, task: run.Partial | run.Script, executor: run.Executor): + if isinstance(task, run.Script): + logging.warning( + f"The {self.__class__.__name__} will have no effect on the task as it's an instance of run.Script" + ) + return + + if "WANDB_API_KEY" in os.environ: + executor.env_vars["WANDB_API_KEY"] = os.environ["WANDB_API_KEY"] + + if hasattr(task, "log") and hasattr(task.log, "wandb"): + task.log.wandb = self.logger_fn(name=self.name) + if self.log_task_config: + partial_config = yaml.safe_load(YamlSerializer().serialize(task)) + partial_config["experiment"] = { + "id": self.experiment_id, + "task_name": self.name, + "executor": executor.info(), + "remote_directory": ( + os.path.join(executor.tunnel.job_dir, Path(executor.job_dir).name) + if isinstance(executor, run.SlurmExecutor) + else None + ), + "local_directory": executor.job_dir, + } + task.log.wandb.config = partial_config + else: + logging.warning( + f"The {self.__class__.__name__} will have no effect as WANDB_API_KEY environment variable is not set." + ) diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index 1b3397f69033..171abce41f37 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -4,6 +4,6 @@ hydra-core>1.3,<=1.3.2 omegaconf<=2.3 pytorch-lightning>2.2.1 torchmetrics>=0.11.0 -transformers +transformers>=4.44.0 wandb webdataset>=0.2.86 diff --git a/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py new file mode 100644 index 000000000000..a1db7cec4f23 --- /dev/null +++ b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py @@ -0,0 +1,439 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +# This script would evaluate an N-gram language model in ARPA format in +# fusion with WFST decoders on top of a trained ASR model with CTC decoder. +# NeMo's WFST decoders use WFST decoding graphs made from ARPA LMs +# to find the best candidates. This script supports BPE level encodings only +# and models which is detected automatically from the type of the model. +# You may train the LM model with e.g. SRILM. + +# Config Help + +To discover all arguments of the script, please run : +python eval_wfst_decoding_ctc.py --help +python eval_wfst_decoding_ctc.py --cfg job + +# USAGE + +python eval_wfst_decoding_ctc.py nemo_model_file= \ + input_manifest= \ + arpa_model_file= \ + decoding_wfst_file= \ + beam_width=[] \ + lm_weight=[] \ + decoding_mode= \ + decoding_search_type= \ + open_vocabulary_decoding= \ + preds_output_folder= \ + probs_cache_file=null + ... + + +# Grid Search for Hyper parameters + +For grid search, you can provide a list of arguments as follows - + + beam_width=[5.0,10.0,15.0,20.0] \ + lm_weight=[0.1,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2,1.3,1.4,1.5,2.0] \ + +""" + + +import contextlib +import json +import os +import pickle +from dataclasses import dataclass, field, is_dataclass +from pathlib import Path +from typing import List, Optional + +import editdistance +import numpy as np +import torch +from omegaconf import MISSING, OmegaConf +from sklearn.model_selection import ParameterGrid +from tqdm.auto import tqdm + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.submodules import ctc_beam_decoding +from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig +from nemo.core.config import hydra_runner +from nemo.utils import logging + +# fmt: off + + +@dataclass +class EvalWFSTNGramConfig: + """ + Evaluate an ASR model with WFST decoding and n-gram ARPA language model. + """ + # # The path of the '.nemo' file of the ASR model or the name of a pretrained model (ngc / huggingface) + nemo_model_file: str = MISSING + + # File paths + input_manifest: str = MISSING # The manifest file of the evaluation set + arpa_model_file: Optional[str] = None # The path of the ARPA model file + decoding_wfst_file: Optional[str] = None # The path of the decoding WFST file + preds_output_folder: Optional[str] = None # The optional folder where the predictions are stored + probs_cache_file: Optional[str] = None # The cache file for storing the logprobs of the model + + # Parameters for inference + acoustic_batch_size: int = 16 # The batch size to calculate log probabilities + beam_batch_size: int = 512 # The batch size to be used for beam search decoding + device: str = "cuda" # The device to load the model onto to calculate log probabilities and run WFST decoding + use_amp: bool = False # Whether to use AMP if available to calculate log probabilities + + # WFST decoding hyperparameters + + beam_width: List[float] = field(default_factory=lambda: [10]) # The width or list of the beam widths for the WFST decoding + lm_weight: List[float] = field(default_factory=lambda: [1.0]) # The language model weight parameter or list of parameters for the WFST decoding + + open_vocabulary_decoding: bool = False # Whether to use open vocabulary mode for WFST decoding + decoding_mode: str = "nbest" + decoding_search_type: str = "riva" + decoding: ctc_beam_decoding.WfstCTCInferConfig = field( + default_factory=lambda: ctc_beam_decoding.WfstCTCInferConfig(beam_size=1) + ) + + text_processing: Optional[TextProcessingConfig] = field(default_factory=lambda: TextProcessingConfig( + punctuation_marks = ".,?", + separate_punctuation = False, + do_lowercase = False, + rm_punctuation = False, + )) +# fmt: on + + +def beam_search_eval( + model: nemo_asr.models.ASRModel, + cfg: EvalWFSTNGramConfig, + all_probs: List[torch.Tensor], + target_transcripts: List[str], + preds_output_file: str = None, + lm_weight: float = 1.0, + beam_width: float = 10.0, + beam_batch_size: int = 512, + progress_bar: bool = True, + punctuation_capitalization: PunctuationCapitalization = None, +): + level = logging.getEffectiveLevel() + logging.setLevel(logging.CRITICAL) + # Reset config + if isinstance(model, EncDecHybridRNNTCTCModel): + model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc") + else: + model.change_decoding_strategy(None) + + # Override the beam search config with current search candidate configuration + cfg.decoding.beam_width = beam_width + cfg.decoding.lm_weight = lm_weight + cfg.decoding.open_vocabulary_decoding = cfg.open_vocabulary_decoding + cfg.decoding.return_best_hypothesis = False + cfg.decoding.arpa_lm_path = cfg.arpa_model_file + cfg.decoding.wfst_lm_path = cfg.decoding_wfst_file + cfg.decoding.device = cfg.device + cfg.decoding.decoding_mode = cfg.decoding_mode + cfg.decoding.search_type = cfg.decoding_search_type + + # Update model's decoding strategy config + model.cfg.decoding.strategy = "wfst" + model.cfg.decoding.wfst = cfg.decoding + + # Update model's decoding strategy + if isinstance(model, EncDecHybridRNNTCTCModel): + model.change_decoding_strategy(model.cfg.decoding, decoder_type='ctc') + decoding = model.ctc_decoding + else: + model.change_decoding_strategy(model.cfg.decoding) + decoding = model.decoding + logging.setLevel(level) + + wer_dist_first = cer_dist_first = 0 + wer_dist_best = cer_dist_best = 0 + words_count = 0 + chars_count = 0 + sample_idx = 0 + if preds_output_file: + out_file = open(preds_output_file, 'w', encoding='utf_8', newline='\n') + + if progress_bar: + it = tqdm( + range(int(np.ceil(len(all_probs) / beam_batch_size))), + desc=f"Beam search decoding with width={beam_width}, lm_weight={lm_weight}", + ncols=120, + ) + else: + it = range(int(np.ceil(len(all_probs) / beam_batch_size))) + for batch_idx in it: + # disabling type checking + probs_batch = all_probs[batch_idx * beam_batch_size : (batch_idx + 1) * beam_batch_size] + probs_lens = torch.tensor([prob.shape[0] for prob in probs_batch]) + with torch.no_grad(): + packed_batch = torch.zeros(len(probs_batch), max(probs_lens), probs_batch[0].shape[-1], device='cpu') + + for prob_index in range(len(probs_batch)): + packed_batch[prob_index, : probs_lens[prob_index], :] = probs_batch[prob_index].to( + device=packed_batch.device, dtype=packed_batch.dtype + ) + + _, beams_batch = decoding.ctc_decoder_predictions_tensor( + packed_batch, + decoder_lengths=probs_lens, + return_hypotheses=True, + ) + + for beams_idx, beams in enumerate(beams_batch): + target = target_transcripts[sample_idx + beams_idx] + target_split_w = target.split() + target_split_c = list(target) + words_count += len(target_split_w) + chars_count += len(target_split_c) + wer_dist_min = cer_dist_min = 10000 + for candidate_idx, candidate in enumerate(beams): # type: (int, ctc_beam_decoding.rnnt_utils.Hypothesis) + pred_text = candidate.text + if cfg.text_processing.do_lowercase: + pred_text = punctuation_capitalization.do_lowercase([pred_text])[0] + if cfg.text_processing.rm_punctuation: + pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0] + if cfg.text_processing.separate_punctuation: + pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0] + pred_split_w = pred_text.split() + wer_dist = editdistance.eval(target_split_w, pred_split_w) + pred_split_c = list(pred_text) + cer_dist = editdistance.eval(target_split_c, pred_split_c) + + wer_dist_min = min(wer_dist_min, wer_dist) + cer_dist_min = min(cer_dist_min, cer_dist) + + if candidate_idx == 0: + # first candidate + wer_dist_first += wer_dist + cer_dist_first += cer_dist + + score = candidate.score + if preds_output_file: + out_file.write(f'{pred_text}\t{score}\n') + wer_dist_best += wer_dist_min + cer_dist_best += cer_dist_min + sample_idx += len(probs_batch) + + if preds_output_file: + out_file.close() + logging.info(f"Stored the predictions of beam search decoding at '{preds_output_file}'.") + + logging.info( + 'WER/CER with beam search decoding and N-gram model = {:.2%}/{:.2%}'.format( + wer_dist_first / words_count, cer_dist_first / chars_count + ) + ) + logging.info( + 'Oracle WER/CER in candidates with perfect LM= {:.2%}/{:.2%}'.format( + wer_dist_best / words_count, cer_dist_best / chars_count + ) + ) + logging.info(f"=================================================================================") + + return wer_dist_first / words_count, cer_dist_first / chars_count + + +@hydra_runner(config_path=None, config_name='EvalWFSTNGramConfig', schema=EvalWFSTNGramConfig) +def main(cfg: EvalWFSTNGramConfig): + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) # type: EvalWFSTNGramConfig + + if cfg.nemo_model_file.endswith('.nemo'): + asr_model = nemo_asr.models.ASRModel.restore_from(cfg.nemo_model_file, map_location=torch.device(cfg.device)) + else: + logging.warning( + "nemo_model_file does not end with .nemo, therefore trying to load a pretrained model with this name." + ) + asr_model = nemo_asr.models.ASRModel.from_pretrained( + cfg.nemo_model_file, map_location=torch.device(cfg.device) + ) + + target_transcripts = [] + manifest_dir = Path(cfg.input_manifest).parent + with open(cfg.input_manifest, 'r', encoding='utf_8') as manifest_file: + audio_file_paths = [] + for line in tqdm(manifest_file, desc=f"Reading Manifest {cfg.input_manifest} ...", ncols=120): + data = json.loads(line) + audio_file = Path(data['audio_filepath']) + if not audio_file.is_file() and not audio_file.is_absolute(): + audio_file = manifest_dir / audio_file + target_transcripts.append(data['text']) + audio_file_paths.append(str(audio_file.absolute())) + + punctuation_capitalization = PunctuationCapitalization(cfg.text_processing.punctuation_marks) + if cfg.text_processing.do_lowercase: + target_transcripts = punctuation_capitalization.do_lowercase(target_transcripts) + if cfg.text_processing.rm_punctuation: + target_transcripts = punctuation_capitalization.rm_punctuation(target_transcripts) + if cfg.text_processing.separate_punctuation: + target_transcripts = punctuation_capitalization.separate_punctuation(target_transcripts) + + if cfg.probs_cache_file and os.path.exists(cfg.probs_cache_file): + logging.info(f"Found a pickle file of probabilities at '{cfg.probs_cache_file}'.") + logging.info(f"Loading the cached pickle file of probabilities from '{cfg.probs_cache_file}' ...") + with open(cfg.probs_cache_file, 'rb') as probs_file: + all_probs = pickle.load(probs_file) + + if len(all_probs) != len(audio_file_paths): + raise ValueError( + f"The number of samples in the probabilities file '{cfg.probs_cache_file}' does not " + f"match the manifest file. You may need to delete the probabilities cached file." + ) + else: + + @contextlib.contextmanager + def default_autocast(): + yield + + if cfg.use_amp: + if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): + logging.info("AMP is enabled!\n") + autocast = torch.cuda.amp.autocast + + else: + autocast = default_autocast + else: + + autocast = default_autocast + + with autocast(): + with torch.no_grad(): + if isinstance(asr_model, EncDecHybridRNNTCTCModel): + asr_model.cur_decoder = 'ctc' + all_hyps = asr_model.transcribe( + audio_file_paths, batch_size=cfg.acoustic_batch_size, return_hypotheses=True + ) + all_logits = [h.y_sequence for h in all_hyps] + + all_probs = all_logits + if cfg.probs_cache_file: + os.makedirs(os.path.split(cfg.probs_cache_file)[0], exist_ok=True) + logging.info(f"Writing pickle files of probabilities at '{cfg.probs_cache_file}'...") + with open(cfg.probs_cache_file, 'wb') as f_dump: + pickle.dump(all_probs, f_dump) + + wer_dist_greedy = 0 + cer_dist_greedy = 0 + words_count = 0 + chars_count = 0 + for batch_idx, probs in enumerate(all_probs): + preds = np.argmax(probs, axis=1) + preds_tensor = preds.to(device='cpu').unsqueeze(0) + preds_lens = torch.tensor([preds_tensor.shape[1]], device='cpu') + if isinstance(asr_model, EncDecHybridRNNTCTCModel): + pred_text = asr_model.ctc_decoding.ctc_decoder_predictions_tensor(preds_tensor, preds_lens)[0][0] + else: + pred_text = asr_model.decoding.ctc_decoder_predictions_tensor(preds_tensor, preds_lens)[0][0] + + if cfg.text_processing.do_lowercase: + pred_text = punctuation_capitalization.do_lowercase([pred_text])[0] + if cfg.text_processing.rm_punctuation: + pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0] + if cfg.text_processing.separate_punctuation: + pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0] + + pred_split_w = pred_text.split() + target_split_w = target_transcripts[batch_idx].split() + pred_split_c = list(pred_text) + target_split_c = list(target_transcripts[batch_idx]) + + wer_dist = editdistance.eval(target_split_w, pred_split_w) + cer_dist = editdistance.eval(target_split_c, pred_split_c) + + wer_dist_greedy += wer_dist + cer_dist_greedy += cer_dist + words_count += len(target_split_w) + chars_count += len(target_split_c) + + logging.info('Greedy WER/CER = {:.2%}/{:.2%}'.format(wer_dist_greedy / words_count, cer_dist_greedy / chars_count)) + + asr_model = asr_model.to('cpu') + + if (cfg.arpa_model_file is None or not os.path.exists(cfg.arpa_model_file)) and ( + cfg.decoding_wfst_file is None or not os.path.exists(cfg.decoding_wfst_file) + ): + raise FileNotFoundError( + f"Could not find both the ARPA model file `{cfg.arpa_model_file}` " + f"and the decoding WFST file `{cfg.decoding_wfst_file}`." + ) + + if cfg.beam_width is None or cfg.lm_weight is None: + raise ValueError("beam_width and lm_weight are needed to perform WFST decoding.") + params = {'beam_width': cfg.beam_width, 'lm_weight': cfg.lm_weight} + hp_grid = ParameterGrid(params) + hp_grid = list(hp_grid) + + best_wer_beam_width, best_cer_beam_width = None, None + best_wer_lm_weight, best_cer_lm_weight = None, None + best_wer, best_cer = 1e6, 1e6 + + logging.info(f"==============================Starting the beam search decoding===============================") + logging.info(f"Grid search size: {len(hp_grid)}") + logging.info(f"It may take some time...") + logging.info(f"==============================================================================================") + + if cfg.preds_output_folder and not os.path.exists(cfg.preds_output_folder): + os.mkdir(cfg.preds_output_folder) + for hp in hp_grid: + if cfg.preds_output_folder: + preds_output_file = os.path.join( + cfg.preds_output_folder, + f"preds_out_beam_width{hp['beam_width']}_lm_weight{hp['lm_weight']}.tsv", + ) + else: + preds_output_file = None + + candidate_wer, candidate_cer = beam_search_eval( + asr_model, + cfg, + all_probs=all_probs, + target_transcripts=target_transcripts, + preds_output_file=preds_output_file, + beam_width=hp["beam_width"], + lm_weight=hp["lm_weight"], + beam_batch_size=cfg.beam_batch_size, + progress_bar=True, + punctuation_capitalization=punctuation_capitalization, + ) + + if candidate_cer < best_cer: + best_cer_beam_width = hp["beam_width"] + best_cer_lm_weight = hp["lm_weight"] + best_cer = candidate_cer + + if candidate_wer < best_wer: + best_wer_beam_width = hp["beam_width"] + best_wer_lm_weight = hp["lm_weight"] + best_wer = candidate_wer + + logging.info( + f'Best WER Candidate = {best_wer:.2%} :: Beam size = {best_wer_beam_width}, LM weight = {best_wer_lm_weight}' + ) + + logging.info( + f'Best CER Candidate = {best_cer:.2%} :: Beam size = {best_cer_beam_width}, LM weight = {best_cer_lm_weight}' + ) + logging.info(f"=================================================================================") + + +if __name__ == '__main__': + main() diff --git a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py index 1a0a13709421..7a7484bf9c20 100644 --- a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py @@ -26,7 +26,7 @@ ''' Example -CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \ +CUDA_VISIBLE_DEVICES="0" python /opt/NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \ --input_name_or_path \ --output_path \ --mamba_ssm_ngroups 8 \ @@ -63,10 +63,24 @@ def get_args(): def convert(args): - checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu') + checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')['model'] new_state_dict = {} if 'backbone' in list(checkpoint_weights.keys())[0]: + if 'model' in list(checkpoint_weights.keys())[0]: + checkpoint_weights = {key.replace('model.', '', 1): value for key, value in checkpoint_weights.items()} + + # Codestral Mamba Model Tokenizer Settings + tokenizer_library = 'megatron' + tokenizer_type = 'GPTSentencePieceTokenizer' + tokenizer_model = args.tokenizer_model_dir + + else: + + # Tri Dao and Albert Gu Mamba Model Tokenizer Settings + tokenizer_library = 'huggingface' + tokenizer_type = 'EleutherAI/gpt-neox-20b' + tokenizer_model = None layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'backbone\.layers\.\d+\.', key)] layer_numbers = set(int(re.search(r'backbone\.layers\.(\d+)\.', key).group(1)) for key in layer_keys) @@ -103,11 +117,6 @@ def convert(args): old_key = f'backbone.layers.{i}.{attr}' new_state_dict[new_key] = checkpoint_weights[old_key] - # Tokenizer settings - tokenizer_library = 'huggingface' - tokenizer_type = 'EleutherAI/gpt-neox-20b' - tokenizer_model = None - else: layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'decoder\.layers\.\d+\.', key)] @@ -124,11 +133,6 @@ def convert(args): tokenizer_type = 'GPTSentencePieceTokenizer' tokenizer_model = args.tokenizer_model_dir - # Tokenizer settings - tokenizer_library = 'megatron' - tokenizer_type = 'GPTSentencePieceTokenizer' - tokenizer_model = args.tokenizer_model_dir - layers = defaultdict(list) for key in new_state_dict.keys(): diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py index d0bf8f10548a..18463a3fc24a 100755 --- a/scripts/deploy/multimodal/deploy_triton.py +++ b/scripts/deploy/multimodal/deploy_triton.py @@ -35,6 +35,16 @@ def get_args(argv): formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=f"Deploy nemo models to Triton", ) + # default modality is vision, can be changed to audio + parser.add_argument( + "-mod", + "--modality", + type=str, + required=False, + default="vision", + choices=["vision", "audio"], + help="Modality of the model", + ) parser.add_argument("-vc", "--visual_checkpoint", type=str, help="Source .nemo file for visual model") parser.add_argument( "-lc", @@ -48,7 +58,7 @@ def get_args(argv): "--model_type", type=str, required=True, - choices=["neva", "video-neva", "lita", "vila", "vita"], + choices=["neva", "video-neva", "lita", "vila", "vita", "salm"], help="Type of the model that is supported.", ) parser.add_argument( @@ -123,8 +133,7 @@ def get_trt_deployable(args): raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") exporter = TensorRTMMExporter( - model_dir=trt_path, - load_model=(args.visual_checkpoint is None), + model_dir=trt_path, load_model=(args.visual_checkpoint is None), modality=args.modality ) if args.visual_checkpoint is not None: diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py index c0acd97e1b50..0ec6264d6bf0 100755 --- a/scripts/deploy/nlp/deploy_triton.py +++ b/scripts/deploy/nlp/deploy_triton.py @@ -128,7 +128,8 @@ def get_args(argv): default=False, action='store_true', help='Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ - It is beneifical when batchxnum_heads cannot fully utilize GPU.', + It is beneifical when batchxnum_heads cannot fully utilize GPU. \ + Only available when using c++ runtime.', ) parser.add_argument( "-es", '--enable_streaming', default=False, action='store_true', help="Enables streaming sentences." @@ -274,6 +275,7 @@ def get_trtllm_deployable(args): lora_ckpt_list=args.lora_ckpt, load_model=(args.nemo_checkpoint is None), use_python_runtime=(not args.use_cpp_runtime), + multi_block_mode=args.multi_block_mode, ) if args.nemo_checkpoint is not None: @@ -296,7 +298,6 @@ def get_trtllm_deployable(args): paged_kv_cache=(not args.no_paged_kv_cache), remove_input_padding=(not args.disable_remove_input_padding), dtype=args.dtype, - enable_multi_block_mode=args.multi_block_mode, use_lora_plugin=args.use_lora_plugin, lora_target_modules=args.lora_target_modules, max_lora_rank=args.max_lora_rank, diff --git a/scripts/export/export_mm_to_trtllm.py b/scripts/export/export_mm_to_trtllm.py new file mode 100644 index 000000000000..e7389f6e07af --- /dev/null +++ b/scripts/export/export_mm_to_trtllm.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports multimodal model to TensorRT and do a local inference test. +For multimodal model, it supports the following models: +- NEVA +- Video-NEVA +- LITA +- VILA +- VITA +- SALM +""" + +import argparse +import os + +from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter + + +def parse_args(): + parser = argparse.ArgumentParser(description='Export multimodal model to TensorRT') + parser.add_argument('--output_dir', required=True, help='Directory to save the exported model') + parser.add_argument( + '--visual_checkpoint_path', + required=True, + help='Path to the visual model checkpoint or perception model checkpoint', + ) + parser.add_argument('--llm_checkpoint_path', required=True, help='Source .nemo file for llm') + parser.add_argument( + '--modality', + default="vision", + choices=["vision", "audio"], + help="Modality of the model", + ) + parser.add_argument( + '--model_type', + type=str, + required=True, + choices=["neva", "video-neva", "lita", "vila", "vita", "salm"], + help="Type of the model that is supported.", + ) + + parser.add_argument( + '--llm_model_type', + type=str, + required=True, + choices=["gptnext", "gpt", "llama", "falcon", "starcoder", "mixtral", "gemma"], + help="Type of LLM. gptnext, gpt, llama, falcon, and starcoder are only supported." + " gptnext and gpt are the same and keeping it for backward compatibility", + ) + + parser.add_argument('--tensor_parallel_size', type=int, default=1, help='tensor parallelism size') + parser.add_argument('--max_input_len', type=int, default=4096, help='Maximum input length') + parser.add_argument('--max_output_len', type=int, default=256, help='Maximum output length') + parser.add_argument('--max_batch_size', type=int, default=1, help='Maximum batch size') + parser.add_argument( + '--vision_max_batch_size', + type=int, + default=1, + help='Max batch size of the visual inputs, for lita/vita model with video inference, this should be set to 256', + ) + parser.add_argument('--max_multimodal_len', type=int, default=3072, help='Maximum multimodal length') + parser.add_argument( + "--dtype", + choices=["bfloat16", "float16"], + default="bfloat16", + type=str, + help="dtype of the model on TensorRT", + ) + parser.add_argument( + '--delete_existing_files', action='store_true', help='Delete existing files in the output directory' + ) + parser.add_argument( + '--test_export_only', action='store_true', help='Only test the export without saving the model' + ) + parser.add_argument('--input_text', help='Input text for inference') + parser.add_argument('--input_media', default=None, help='Input media file for inference') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference') + parser.add_argument('--max_output', type=int, default=128, help='Maximum output length for inference') + parser.add_argument('--top_k', type=int, default=1, help='Top k for sampling') + parser.add_argument('--top_p', type=float, default=0.0, help='Top p for sampling') + parser.add_argument("--temperature", default=1.0, type=float, help="temperature") + parser.add_argument("--repetition_penalty", default=1.0, type=float, help="repetition_penalty") + parser.add_argument("--num_beams", default=1, type=int, help="num_beams") + + args = parser.parse_args() + return args + + +def main(args): + exporter = TensorRTMMExporter(model_dir=args.output_dir, load_model=False, modality=args.modality) + exporter.export( + visual_checkpoint_path=args.visual_checkpoint_path, + llm_checkpoint_path=args.llm_checkpoint_path, + model_type=args.model_type, + llm_model_type=args.llm_model_type, + tensor_parallel_size=args.tensor_parallel_size, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + max_batch_size=args.max_batch_size, + vision_max_batch_size=args.vision_max_batch_size, + max_multimodal_len=args.max_multimodal_len, + dtype=args.dtype, + delete_existing_files=args.delete_existing_files, + load_model=not args.test_export_only, + ) + test_inference = not args.test_export_only + if test_inference: + assert args.input_media is not None, "Input media file is required for inference" + assert os.path.exists(args.input_media), f"Input media file {args.input_media} does not exist" + output = exporter.forward( + input_text=args.input_text, + input_media=args.input_media, + batch_size=args.batch_size, + max_output_len=args.max_output, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + num_beams=args.num_beams, + ) + print(output) + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/scripts/installers/install_riva_decoder.sh b/scripts/installers/install_riva_decoder.sh new file mode 100755 index 000000000000..4e6e99b570ab --- /dev/null +++ b/scripts/installers/install_riva_decoder.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pip install kaldifst kaldilm riva-asrlib-decoder diff --git a/scripts/nlp_language_modeling/niv2/preprocess_niv2.py b/scripts/nlp_language_modeling/niv2/preprocess_niv2.py index 073d6da8f32c..6119768e66f2 100644 --- a/scripts/nlp_language_modeling/niv2/preprocess_niv2.py +++ b/scripts/nlp_language_modeling/niv2/preprocess_niv2.py @@ -18,8 +18,6 @@ from argparse import ArgumentParser from multiprocessing import Pool -from sacremoses import MosesDetokenizer - from nemo.collections.common.tokenizers import AutoTokenizer @@ -99,6 +97,8 @@ def write_dataset_to_file(file_name, output_file_name, detokenizer, tokenizer, i def process_folder(data_folder, output_folder, splits_file, remove_newline): + from sacremoses import MosesDetokenizer + detokenizer = MosesDetokenizer('en') tokenizer = AutoTokenizer("gpt2") assert os.path.isdir(data_folder) @@ -162,10 +162,15 @@ def process_folder(data_folder, output_folder, splits_file, remove_newline): help="Path to output folder where JSONL files will be written.", ) parser.add_argument( - "--splits_file_path", type=str, default="default", help="Path to the file that contains splits. ex: ", + "--splits_file_path", + type=str, + default="default", + help="Path to the file that contains splits. ex: ", ) parser.add_argument( - "--remove_newline", action="store_true", help="Whether to remove newlines from the input and output.", + "--remove_newline", + action="store_true", + help="Whether to remove newlines from the input and output.", ) args = parser.parse_args() process_folder(args.niv2_dataset_path, args.jsonl_output_path, args.splits_file_path, args.remove_newline) diff --git a/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py b/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py index 618c02c0cc13..53bed36ff8d0 100644 --- a/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py +++ b/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py @@ -19,7 +19,6 @@ from multiprocessing import Pool import tensorflow as tf -from sacremoses import MosesDetokenizer from tasks_splits_and_features import _TASK_SPLITS_AND_FEATURES_DICT @@ -136,6 +135,8 @@ def process_folder(data_folder, folder_name, output_folder, detokenizer, remove_ def process_all_folders(data_folder, output_folder, remove_newlines): + from sacremoses import MosesDetokenizer + detokenizer = MosesDetokenizer('en') assert os.path.isdir(data_folder) if not os.path.exists(output_folder): @@ -170,7 +171,9 @@ def process_all_folders(data_folder, output_folder, remove_newlines): help="Path to output folder where JSONL files will be written.", ) parser.add_argument( - "--remove_newlines", action="store_true", help="Whether to remove newlines from the input and output.", + "--remove_newlines", + action="store_true", + help="Whether to remove newlines from the input and output.", ) args = parser.parse_args() process_all_folders(args.p3_dataset_path, args.jsonl_output_path, args.remove_newlines) diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 0d7c555ee778..247906247091 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -24,6 +24,7 @@ from nemo.collections.asr.data import audio_to_text from nemo.collections.asr.models import configs from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE +from nemo.collections.asr.parts.submodules import ctc_beam_decoding as beam_decode from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig from nemo.collections.common import tokenizers from nemo.utils.config_utils import assert_dataclass_signature_match @@ -279,6 +280,34 @@ def test_decoding_change(self, asr_model): assert asr_model.decoding.preserve_alignments is True assert asr_model.decoding.compute_timestamps is True + new_strategy = DictConfig({}) + new_strategy.strategy = 'beam' + new_strategy.beam = DictConfig({'beam_size': 1}) + asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(asr_model.decoding.decoding, beam_decode.BeamCTCInfer) + assert asr_model.decoding.decoding.search_type == "default" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'pyctcdecode' + new_strategy.beam = DictConfig({'beam_size': 1}) + asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(asr_model.decoding.decoding, beam_decode.BeamCTCInfer) + assert asr_model.decoding.decoding.search_type == "pyctcdecode" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'flashlight' + new_strategy.beam = DictConfig({'beam_size': 1}) + asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(asr_model.decoding.decoding, beam_decode.BeamCTCInfer) + assert asr_model.decoding.decoding.search_type == "flashlight" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'wfst' + new_strategy.beam = DictConfig({'beam_size': 1}) + asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(asr_model.decoding.decoding, beam_decode.WfstCTCInfer) + assert asr_model.decoding.decoding.search_type == "riva" + @pytest.mark.unit def test_ASRDatasetConfig_for_AudioToBPEDataset(self): # ignore some additional arguments as dataclass is generic diff --git a/tests/collections/common/test_lhotse_nemo_adapters.py b/tests/collections/common/test_lhotse_nemo_adapters.py new file mode 100644 index 000000000000..a76116b10dd7 --- /dev/null +++ b/tests/collections/common/test_lhotse_nemo_adapters.py @@ -0,0 +1,188 @@ +import numpy as np +import pytest +from lhotse import AudioSource, CutSet, MonoCut, Recording, SupervisionSegment +from lhotse.serialization import save_to_jsonl +from lhotse.testing.dummies import DummyManifest + +from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator + + +@pytest.fixture +def nemo_manifest_path(tmp_path_factory): + """2 utterances of length 1s as a NeMo manifest.""" + tmpdir = tmp_path_factory.mktemp("nemo_data") + cuts = DummyManifest(CutSet, begin_id=0, end_id=2, with_data=True).save_audios(tmpdir, progress_bar=False) + nemo = [] + for c in cuts: + nemo.append( + { + "audio_filepath": c.recording.sources[0].source, + "text": "irrelevant", + "duration": c.duration, + "lang": "en", + } + ) + p = tmpdir / "nemo_manifest.json" + save_to_jsonl(nemo, p) + return p + + +def test_lazy_nemo_iterator(nemo_manifest_path): + cuts = CutSet(LazyNeMoIterator(nemo_manifest_path)) + + assert len(cuts) == 2 + + for c in cuts: + assert isinstance(c, MonoCut) + assert c.start == 0.0 + assert c.duration == 1.0 + assert c.num_channels == 1 + assert c.sampling_rate == 16000 + assert c.num_samples == 16000 + + assert c.has_recording + assert isinstance(c.recording, Recording) + assert c.recording.duration == 1.0 + assert c.recording.num_channels == 1 + assert c.recording.num_samples == 16000 + assert len(c.recording.sources) == 1 + assert isinstance(c.recording.sources[0], AudioSource) + assert c.recording.sources[0].type == "file" + + audio = c.load_audio() + assert isinstance(audio, np.ndarray) + assert audio.shape == (1, 16000) + assert audio.dtype == np.float32 + + assert len(c.supervisions) == 1 + s = c.supervisions[0] + assert isinstance(s, SupervisionSegment) + assert s.start == 0 + assert s.duration == 1 + assert s.channel == 0 + assert s.text == "irrelevant" + assert s.language == "en" + + +@pytest.fixture +def nemo_offset_manifest_path(tmp_path_factory): + """ + 4 utterances of length 0.5s as a NeMo manifest. + They are dervied from two audio files of 1s duration, so + two of them have offset 0 and the other two have offset 0.5. + """ + tmpdir = tmp_path_factory.mktemp("nemo_data_offset") + cuts = ( + DummyManifest(CutSet, begin_id=0, end_id=2, with_data=True) + .save_audios(tmpdir, progress_bar=False) + .cut_into_windows(duration=0.5, hop=0.5) + ) + nemo = [] + for c in cuts: + nemo.append( + { + "audio_filepath": c.recording.sources[0].source, + "text": "irrelevant", + "offset": c.start, + "duration": c.duration, + "lang": "en", + } + ) + p = tmpdir / "nemo_manifest.json" + save_to_jsonl(nemo, p) + return p + + +def test_lazy_nemo_iterator_with_offset(nemo_offset_manifest_path): + cuts = CutSet(LazyNeMoIterator(nemo_offset_manifest_path)) + + assert len(cuts) == 4 + + for idx, c in enumerate(cuts): + # Note we originally had 1 cut per 1s audio file. + # Then we cut them into 0.5s cuts, so we have 4 cuts in total, + # 2 of them start at 0s and the other 2 start at 0.5s. + is_even = idx % 2 == 0 + + assert isinstance(c, MonoCut) + if is_even: + assert c.start == 0.0 + else: + assert c.start == 0.5 + assert c.duration == 0.5 + assert c.num_channels == 1 + assert c.sampling_rate == 16000 + assert c.num_samples == 8000 + + assert c.has_recording + assert isinstance(c.recording, Recording) + assert c.recording.duration == 1.0 + assert c.recording.num_channels == 1 + assert c.recording.num_samples == 16000 + assert len(c.recording.sources) == 1 + assert isinstance(c.recording.sources[0], AudioSource) + assert c.recording.sources[0].type == "file" + + audio = c.load_audio() + assert isinstance(audio, np.ndarray) + assert audio.shape == (1, 8000) + assert audio.dtype == np.float32 + + assert len(c.supervisions) == 1 + s = c.supervisions[0] + assert isinstance(s, SupervisionSegment) + assert s.start == 0 + assert s.duration == 0.5 + assert s.channel == 0 + assert s.text == "irrelevant" + assert s.language == "en" + + +def test_lazy_nemo_iterator_with_offset_metadata_only(nemo_offset_manifest_path): + cuts = CutSet(LazyNeMoIterator(nemo_offset_manifest_path, metadata_only=True)) + + assert len(cuts) == 4 + + for idx, c in enumerate(cuts): + # Note we originally had 1 cut per 1s audio file. + # Then we cut them into 0.5s cuts, so we have 4 cuts in total, + # 2 of them start at 0s and the other 2 start at 0.5s. + is_even = idx % 2 == 0 + + assert isinstance(c, MonoCut) + if is_even: + assert c.start == 0.0 + else: + assert c.start == 0.5 + assert c.duration == 0.5 + assert c.num_channels == 1 + assert c.sampling_rate == 16000 + assert c.num_samples == 8000 + + # With metadata_only=True we can't actually check what's in the Recording. + # The metadata for it may be incorrect (but is correct for the actual Cut), + # but we don't have to perform any I/O to read the file for info. + assert c.has_recording + assert isinstance(c.recording, Recording) + if is_even: + assert c.recording.duration == 0.5 + assert c.recording.num_samples == 8000 + else: + assert c.recording.duration == 1.0 + assert c.recording.num_samples == 16000 + assert c.recording.num_channels == 1 + assert len(c.recording.sources) == 1 + assert isinstance(c.recording.sources[0], AudioSource) + assert c.recording.sources[0].type == "dummy" + + with pytest.raises(AssertionError): + c.load_audio() + + assert len(c.supervisions) == 1 + s = c.supervisions[0] + assert isinstance(s, SupervisionSegment) + assert s.start == 0 + assert s.duration == 0.5 + assert s.channel == 0 + assert s.text == "irrelevant" + assert s.language == "en" diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py new file mode 100644 index 000000000000..c78306201751 --- /dev/null +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -0,0 +1,598 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import subprocess +import sys +import tempfile +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, TypeVar, Union + +import megatron.core.num_microbatches_calculator +import pytest +import pytorch_lightning as pl +import torch +import torch.distributed +from megatron.core import ModelParallelConfig, parallel_state +from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.module import MegatronModule +from pytorch_lightning.loggers import TensorBoardLogger +from torch import Tensor, nn +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.datasets import MNIST + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning import NeMoLogger, io, resume +from nemo.lightning.megatron_parallel import DataT, MegatronLossReduction, ReductionT +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.optim import MegatronOptimizerModule +from nemo.lightning.pytorch.plugins import MegatronDataSampler + +TokenizerType = Any + +"""This is intended to be a minimal self-container NeMo2 example.""" + + +T = TypeVar("T") + + +@dataclass +class ExampleConfig(ModelParallelConfig): + """ExampleConfig is a dataclass that is used to configure the model. + + Timers from ModelParallelConfig are required for megatron forward compatibility. + """ + + calculate_per_token_loss: bool = False + + def configure_model(self) -> nn.Module: + """This function is called by the strategy to construct the model. + + Note: Must pass self into Model since model requires having a config object. + + Returns: + The model object. + """ + return ExampleModel(self) + + +class MSELossReduction(MegatronLossReduction): + """A class used for calculating the loss, and for logging the reduced loss across micro batches.""" + + def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT]: + """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU. + + Args: + batch: A batch of data that gets passed to the original forward inside LitAutoEncoder. + forward_out: the output of the forward method inside LitAutoEncoder. + + Returns: + A tuple containing [, ReductionT] where the loss tensor will be used for + backpropagation and the ReductionT will be passed to the reduce method + (which currently only works for logging.). + """ + x = batch["data"] + outputs = forward_out + x_hat = outputs["x_hat"] + # you could also put a latent loss on z here. + xview = x.view(x.size(0), -1) + loss = nn.functional.mse_loss(x_hat, xview) + + return loss, {"avg": loss} + + def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor: + """Works across micro-batches. (data on single gpu). + + Note: This currently only works for logging and this loss will not be used for backpropagation. + + Args: + losses_reduced_per_micro_batch: a list of the outputs of forward + + Returns: + A tensor that is the mean of the losses. (used for logging). + """ + mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch]) + return mse_losses.mean() + + +def some_first(seq: Iterable[Optional[T]]) -> T: + """Returns the first non-None value from the sequence or fails""" # noqa: D415 + for s in seq: + if s is not None: + return s + raise ValueError("non-None value not found") + + +def get_dtype_device(torch_object) -> Tuple[torch.dtype, torch.device]: # noqa: D103 + match torch_object: + case []: + raise ValueError("Looking up dtype on an empty list") + case {**data} if not data: + raise ValueError("Looking up dtype on an empty dict") + case torch.Tensor(dtype=dtype, device=device): + return dtype, device + case torch.nn.Module() as m: + try: + p = next(m.parameters()) + except StopIteration as e: + raise ValueError("Cannot get dtype on a torch module with no parameters.") from e + return p.dtype, p.device + case dict(keys=_, values=values): + val = some_first(values()) + return get_dtype_device(val) + case list() as l: + val = some_first(l) + return get_dtype_device(val) + case _: + raise TypeError("Got something we didnt expect") + + +# NOTE(SKH): These types are all wrong, but are close. The inner type must always be a torch.Tensor, but the outer container should be generic. +def batch_collator(batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]]) -> Optional[ReductionT]: + """Takes a sequence of batches and collates them into a single batch. + This is distinct from the standard pytorch default_collator since it does + not add the batch dimension, it's assumed the batch + dimension is already present in the input, as would be the case when + parallelizing across minibatches. + + IMPORTANT: The underlying data primitive _must_ be a torch Tensor. The input to this function is a recurisve type, + there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is a n-d torch.Tensor. + + Examples: + Outer container = Dict: + [{'a': torch.tensor([1]), 'b': torch.tensor([2])}, {'a': torch.tensor([2]), 'b': torch.tensor([3])}] -> {'a': torch.tensor([1, 2]), 'b': torch.tensor([2, 3])} + Outer container = List: + [[torch.tensor([1]), torch.tensor([2])], [torch.tensor([2]), torch.tensor([3])]] -> [torch.tensor([1, 2]), torch.tensor([2, 3])] + Outer container = Tuple: + ([torch.tensor([1]), torch.tensor([2])], [torch.tensor([2]), torch.tensor([3])]) -> (torch.tensor([1, 2]), torch.tensor([2, 3])) + + Args: + batches (Optional[Sequence[ReductionT]]): sequence of batches to collate into a single batch. + + Returns: + A single batch of the same type as the elements of your input sequence. + """ # noqa: D205 + match batches: + case [torch.Tensor(), *_]: + return torch.cat(batches, dim=0) + case [dict(), *_]: + return {key: batch_collator([batch[key] for batch in batches]) for key in batches[0]} + case [tuple(), *_]: + return tuple(batch_collator([batch[i] for batch in batches]) for i in range(len(batches[0]))) + case [list(), *_]: + return [batch_collator([batch[i] for batch in batches]) for i in range(len(batches[0]))] + case None: + return None + case []: + raise ValueError("Cannot process an empty sequence") + case _: + raise ValueError("Unsupported input structure in batch_collator") + + +class PassthroughLossReduction(MegatronLossReduction): + """Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is expected to return a loss. + This class hijacks that mechanism to instead pass through the forward output unperturbed as the loss (to enable inference in the predict step), and then the + reduce method is used to collate the batch of forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, + or list of tensors. The inner type _must always be a torch.Tensor_. + """ # noqa: D205 + + def forward(self, batch: DataT, forward_out: DataT) -> Tuple[torch.Tensor, DataT]: + """_summary_ + + Args: + batch (DataT): The batch of data that was passed through the model to generate output. + forward_out (torch.Tensor): The output from your model's forward pass. + + Returns: + Tuple[torch.Tensor, ReductionT]: A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified). + """ # noqa: D415 + dtype, device = get_dtype_device(forward_out) + return torch.zeros(1, device=device, dtype=dtype), forward_out + + def reduce(self, forward_out: List[DataT]) -> DataT: + """This overrides the standard reduce with a simplified version that just takes a list of your model's forward outputs + and collates them togehter into a single output. + + Args: + forward_out (List[ReductionT]): _description_ + + Returns: + ReductionT: _description_ + """ # noqa: D205 + return batch_collator(forward_out) + + +class LitAutoEncoder(pl.LightningModule, io.IOMixin, io.ConnectorMixin): + """A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract.""" + + def __init__(self, config): + """Initializes the model. + + Args: + config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters). + """ + super().__init__() + self.config = config + self.optim = MegatronOptimizerModule( + config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True), + ) + # Bind the configure_optimizers method to the model + self.optim.connect(self) + + def forward(self, batch: Dict, batch_idx: Optional[int] = None) -> Any: + """This forward will be called by the megatron scheduler and it will be wrapped. + + !!! note + + The `training_step` defines the training loop and is independent of the `forward` method here. + + Args: + batch: A dictionary of data. + batch_idx: The index of the batch. + + Returns: + The output of the model. + """ + x = batch["data"] + return self.module(x) + + def training_step(self, batch, batch_idx: Optional[int] = None): + """The training step is where the loss is calculated and the backpropagation is done. + + Background: + - NeMo's Strategy overrides this method. + - The strategies' training step will call the forward method of the model. + - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model. + - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the + MegatronParallel class. + - Which then calls the training_step function here. + + In this particular use case, we simply call the forward method of this class, the lightning module. + + Args: + batch: A dictionary of data. requires `batch_idx` as default None. + batch_idx: The index of the batch. + """ + return self(batch, batch_idx) + + def training_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + # This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss + return MSELossReduction() + + def validation_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + return MSELossReduction() + + def test_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + return MSELossReduction() + + def predict_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + # This allows us to do inference (not output the loss) + return PassthroughLossReduction() + + def configure_model(self) -> None: # noqa: D102 + self.module = self.config.configure_model() + + +class ExampleModel(MegatronModule): # noqa: D101 + def __init__(self, config: ModelParallelConfig) -> None: + """Constructor of the model. + + Args: + config: The config object is responsible for telling the strategy what model to create. + """ + super().__init__(config) + self.model_type = ModelType.encoder_or_decoder + self.linear1 = nn.Linear(28 * 28, 64) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(64, 3) + self.linear3 = nn.Linear(3, 64) + self.relu2 = nn.ReLU() + self.linear4 = nn.Linear(64, 28 * 28) + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + """Forward pass of the model. + + Args: + x: The input data. + + Returns: + x_hat: The result of the last linear layer of the network. + """ + x = x.view(x.size(0), -1) + z = self.linear1(x) + z = self.relu(z) + z = self.linear2(z) + x_hat = self.linear3(z) + x_hat = self.relu2(x_hat) + x_hat = self.linear4(x_hat) + return {"x_hat": x_hat, "z": z} + + def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None: + """This is needed because it is a megatron convention. Even if it is a no-op for single GPU testing. + + See megatron.model.transformer.set_input_tensor() + + Note: Currently this is a no-op just to get by an mcore function. + + Args: + input_tensor: Input tensor. + """ + pass + + +class MnistItem(TypedDict): + data: Tensor + label: Tensor + idx: int + + +class MNISTCustom(MNIST): # noqa: D101 + def __getitem__(self, index: int) -> MnistItem: + """Wraps the getitem method of the MNIST dataset such that we return a Dict + instead of a Tuple or tensor. + + Args: + index: The index we want to grab, an int. + + Returns: + A dict containing the data ("x"), label ("y"), and index ("idx"). + """ # noqa: D205 + x, y = super().__getitem__(index) + + return { + "data": x, + "label": y, + "idx": index, + } + + +# TODO: remove this callback after `val` loss is logged by default in training in NeMo2 +class LossLoggingCallback(pl.Callback): # noqa: D101 + def __init__(self): + """Log the loss at the end of each batch. For training do not reduce across the epoch but do so for validation/test.""" + self.val_losses = [] + self.test_losses = [] + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): # noqa: D102 + # Assuming the loss is computed internally and stored in pl_module + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if isinstance(outputs, dict): + outputs = outputs["loss"] + loss = outputs + pl_module.log("train_loss", loss, on_step=True, prog_bar=True, logger=True, rank_zero_only=True) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): # noqa: D102 + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if isinstance(outputs, dict): + outputs = outputs["loss"] + loss = outputs + self.test_losses.append(loss) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): # noqa: D102 + # Assuming the loss is computed internally and stored in pl_module + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if isinstance(outputs, dict): + outputs = outputs["loss"] + loss = outputs + self.val_losses.append(loss) + + def on_validation_epoch_end(self, trainer, pl_module): # noqa: D102 + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if len(self.val_losses) > 0: + avg_val_loss = torch.stack(self.val_losses).mean() + pl_module.log("val_loss", avg_val_loss, prog_bar=True, logger=True, rank_zero_only=True) + self.val_losses.clear() + + def on_test_epoch_end(self, trainer, pl_module): # noqa: D102 + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if len(self.test_losses) > 0: + avg_test_loss = torch.stack(self.test_losses).mean() + pl_module.log("test_loss", avg_test_loss, prog_bar=True, logger=True, rank_zero_only=True) + self.test_losses.clear() + + +class MNISTDataModule(pl.LightningDataModule): # noqa: D101 + def __init__(self, data_dir: str = "./", batch_size: int = 32) -> None: # noqa: D107 + super().__init__() + self.data_dir = data_dir + self.batch_size = batch_size + self.micro_batch_size = 8 + self.global_batch_size = 8 + self.max_len = 100 + self.rampup_batch_size = None + + # Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler. + # Wraps the datasampler with the MegatronDataSampler. The MegatronDataSampler is a wrapper that allows the sampler + # to be used with megatron. It sets up the capability to utilize micro-batching and gradient accumulation. It is also + # the place where the global batch size is constructed. + self.data_sampler = MegatronDataSampler( + seq_len=self.max_len, + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + rampup_batch_size=self.rampup_batch_size, + ) + + def setup(self, stage: str) -> None: + """Sets up the datasets + + Args: + stage: can be one of train / test / predict. + """ # noqa: D415 + self.mnist_test = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=False) + self.mnist_predict = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=False) + mnist_full = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=True) + self.mnist_train, self.mnist_val = torch.utils.data.random_split( + mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42) + ) + + def train_dataloader(self) -> DataLoader: # noqa: D102 + return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=0) + + def val_dataloader(self) -> DataLoader: # noqa: D102 + return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=0) + + def test_dataloader(self) -> DataLoader: # noqa: D102 + return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=0) + + +### Begin model environment related utilities +def _reset_megatron_parallel_state(): + """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initialized model parallel in + nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo + """ # noqa: D205, D415 + megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None + # Clean up any process groups created in testing + torch.cuda.empty_cache() + if parallel_state.is_initialized(): + parallel_state.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +@contextmanager +def reset_megatron_parallel_state() -> Iterator[None]: + """Puts you into a clean parallel state, and again tears it down at the end.""" + try: + _reset_megatron_parallel_state() + yield + finally: + _reset_megatron_parallel_state() + + +@pytest.mark.run_only_on("GPU") +@pytest.mark.integration +def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): + path = os.path.abspath(__file__) + call = f"python {path}" + # Raises a CalledProcessError if there is a failure in the subprocess + subprocess.check_call(call, shell=True, stdout=sys.stdout, stderr=sys.stdout) + + +def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): + """This is the actual test that will get run in a subprocess so it does not contaminate the state of other tests.""" + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + assert tmpdir.exists() + assert tmpdir.is_dir() + with reset_megatron_parallel_state(): + # Configure our custom Checkpointer + name = "test_experiment" + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_best_model=True, + save_last=True, + monitor="val_loss", + save_top_k=1, + every_n_train_steps=5, + # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe + enable_nemo_ckpt_io=True, + ) + root_dir = tmpdir + save_dir = root_dir / name + tb_logger = TensorBoardLogger(save_dir=str(save_dir), name=name) + # Setup the logger and train the model + nemo_logger = NeMoLogger( + dir=str(root_dir), # WARNING: passing a path in here results in mutating the Path class. + name=name, + tensorboard=tb_logger, + ckpt=checkpoint_callback, + ) + # Needed so that the trainer can find an output directory for the profiler + # nemo_logger.save_dir = tmpdir + + model = LitAutoEncoder(config=ExampleConfig()) + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ddp="megatron", + find_unused_parameters=True, + enable_nemo_ckpt_io=True, + ) + trainer = nl.Trainer( + accelerator="gpu", + devices=1, + strategy=strategy, + limit_val_batches=5, + val_check_interval=5, + max_steps=20, + num_nodes=1, + log_every_n_steps=5, + callbacks=[io.track_io(LossLoggingCallback)()], + ) + data_module = MNISTDataModule(data_dir=tmpdir) + llm.train( + model=model, + data=data_module, + trainer=trainer, + log=nemo_logger, + resume=resume.AutoResume( + path=None, # Overrides the path found by resume_if_exists when set. + resume_if_exists=True, # Looks for the -last checkpoint to continue training. + resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. + ), + ) + trainer._teardown() + with reset_megatron_parallel_state(): + pred_strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ddp="megatron", + find_unused_parameters=True, + enable_nemo_ckpt_io=True, + data_sampler=MegatronDataSampler( + seq_len=28 * 28, + micro_batch_size=2, + global_batch_size=2, + output_log=False, # Disable logs to support predict_step + ), + ) + predict_trainer = nl.Trainer( + accelerator="gpu", + devices=1, + strategy=pred_strategy, + default_root_dir=str(root_dir), # WARNING: passing a path in here results in mutating the Path class. + ) + ckpt_path = checkpoint_callback.last_model_path.replace( + ".ckpt", "" + ) # strip .ckpt off the end of the last path + + assert Path( + ckpt_path + ).exists(), f"checkpoint {ckpt_path} not found in {os.listdir(Path(ckpt_path).parent)}" + # FIXME: the below checkpoint loading strategy and manual module unwrapping probably only works in single GPU + # and maybe DDP. + unwrapped_trained_model = trainer.model.module # TODO clean this up. Would be good not to have to unwrap. + forward_output = batch_collator( + predict_trainer.predict( + unwrapped_trained_model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path + ) + ) + assert set(forward_output.keys()) == { + "z", + "x_hat", + }, f"We expect forward output from predit_step, not the loss, got: {forward_output}" + assert forward_output["x_hat"].shape == (len(data_module.mnist_test), 28 * 28) + assert forward_output["z"].shape == (len(data_module.mnist_test), 3) # latent bottleneck in model of dim 3 + predict_trainer._teardown() + + +if __name__ == "__main__": + # Have the test run this one item as a subprocess call + run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu() diff --git a/tests/lightning/test_ddp_parity_checker.py b/tests/lightning/test_ddp_parity_checker.py new file mode 100644 index 000000000000..7d180ba17dfe --- /dev/null +++ b/tests/lightning/test_ddp_parity_checker.py @@ -0,0 +1,129 @@ +import argparse +import os + +import pytest +import torch +from megatron.core.optimizer import OptimizerConfig + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.llm.gpt.data import PreTrainingDataModule +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.lightning.pytorch.callbacks import DdpParityChecker + + +def make_parser(): + parser = argparse.ArgumentParser(description='Train a small GPT model using NeMo 2.0') + parser.add_argument('--data-path', type=str, help="Path to data file") + parser.add_argument('--vocab-path', type=str, help="Path to vocab file") + parser.add_argument('--merges-path', type=str, help="Path to merges file") + + return parser + + +def wrap_config(config, trainer): + class ConfigWrapper(type(config)): + def configure_model(self, tokenizer) -> "MCoreGPTModel": + return make_byzantine_model_wrapper(super().configure_model(tokenizer), trainer) + + config.__class__ = ConfigWrapper + return config + + +def make_byzantine_model_wrapper(model, trainer): + class ByzantineModel(type(model)): + def forward(self, *ans, **kwargs): + ans = super().forward(*ans, **kwargs) + with torch.no_grad(): + import random + + rank = int(os.environ['LOCAL_RANK']) + if rank != 1: + return ans + for opt in trainer.strategy.model.optim._optimizers: + for g in opt.param_groups: + for param in g['params']: + param.fill_(random.uniform(0, 1)) + return ans + + model.__class__ = ByzantineModel + return model + + +@pytest.mark.skip(reason="tested with GH") +def test_failing(trainer, ddp_parity, optim, data, tokenizer): + config = llm.Llama2Config7B(num_layers=2) + config = wrap_config(config, trainer) + model = llm.LlamaModel(config, tokenizer=tokenizer, optim=optim) + trainer.fit(model, data) + + +@pytest.mark.skip(reason="tested with GH") +def test_working(trainer, ddp_parity, optim, data, tokenizer): + config = llm.Llama2Config7B(num_layers=2) + model = llm.LlamaModel(config, tokenizer=tokenizer, optim=optim) + trainer.fit(model, data) + + +def make_trainer_optim(args): + ddp_parity = DdpParityChecker(1) + trainer = nl.Trainer( + devices=2, + max_steps=4, + accelerator="gpu", + strategy=nl.MegatronStrategy( + ckpt_include_optimizer=False, + ), + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + limit_val_batches=1, + num_sanity_val_steps=0, + log_every_n_steps=1, + logger=None, + callbacks=[ddp_parity], + ) + + optim = nl.MegatronOptimizerModule( + config=OptimizerConfig( + optimizer="adam", + lr=1e-5, + use_distributed_optimizer=False, + fp16=False, + bf16=True, + params_dtype=torch.float32, + ), + ) + + tokenizer = get_nmt_tokenizer( + "megatron", + "GPT2BPETokenizer", + vocab_file=args.vocab_path, + merges_file=args.merges_path, + ) + data = PreTrainingDataModule( + paths=args.data_path, + seq_length=2048, + global_batch_size=32, + seed=1234, + tokenizer=tokenizer, + ) + + return trainer, ddp_parity, optim, data, tokenizer + + +@pytest.mark.skip(reason="tested with GH") +def main(): + args = make_parser().parse_args() + trainer, ddp_parity, optim, data, tokenizer = make_trainer_optim(args) + test_failing(trainer, ddp_parity, optim, data, tokenizer) + if trainer.should_stop != True: + raise ValueError("DDP parity checking failed.") + + try: + test_working(*make_trainer_optim(args)) + print("DDP parity checking worked as expected") + except: + raise + + +if __name__ == "__main__": + main() diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index a0a16150c65f..955367cb7581 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -21,7 +21,7 @@ def test_loggers(self): trainer = nl.Trainer(accelerator="cpu") logger = nl.NeMoLogger( update_logger_directory=True, - wandb=WandbLogger(save_dir="wandb_logs", offline=True), + wandb=WandbLogger(name="custom", save_dir="wandb_logs", offline=True), ) logger.setup(trainer) @@ -30,7 +30,7 @@ def test_loggers(self): assert len(trainer.loggers) == 2 assert isinstance(trainer.loggers[1], WandbLogger) assert str(trainer.loggers[1].save_dir).endswith("nemo_experiments/wandb_logs") - assert trainer.loggers[1]._name == "default" + assert trainer.loggers[1]._name == "custom" def test_explicit_log_dir(self, trainer): explicit_dir = "explicit_test_dir" diff --git a/tutorials/multimodal/SDXL Tutorial.ipynb b/tutorials/multimodal/SDXL Tutorial.ipynb new file mode 100644 index 000000000000..92667100b405 --- /dev/null +++ b/tutorials/multimodal/SDXL Tutorial.ipynb @@ -0,0 +1,253 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "d874e23f-9631-48e0-b635-84e7280bf07b", + "metadata": {}, + "source": [ + "# SDXL Training / Inference Tutorial\n", + "\n", + "### Note:\n", + "Currently, this notebook must be run in a NeMo container (> 24.09) and open_clip_torch<=2.24.0. An example command to launch the container:\n", + "\n", + "```\n", + "docker run --gpus all -it --rm -v :/opt/NeMo -v :/datasets --shm-size=8g \\\n", + " -p 8888:8888 --ulimit memlock=-1 --ulimit \\\n", + " stack=67108864 \n", + "```\n", + "\n", + "\n", + "## Introduction\n", + "\n", + "This notebook illustrates how to train and perform inference using Stable Diffusion XL with the NeMo Toolkit. Despite differences in model configs, the training and inference procedure is similar as Stable Diffusion.\n", + "\n", + "The implementation of Stable Diffusion XL is based on [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952).\n", + "\n", + "This tutorial will guide you through the following topics:\n", + "\n", + "1. Training a Stable Diffusion XL model.\n", + "2. Performing inference with the trained model.\n", + "\n", + "## Datasets\n", + "\n", + "Please refer to [Dataset Tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/multimodal/Multimodal%20Data%20Preparation.ipynb) for how to prepare a training dataset for Stable diffusion XL.\n", + "\n", + "For a pre-cached Stable Diffusion dataset, each webdataset tar file should, at a minimum, include the pickle files that store the pre-cached image and text features:\n", + "\n", + "```\n", + "t0_r0_0.tar\n", + "|---- 0000.pickle\n", + "|---- 0001.pickle\n", + "...\n", + "```\n", + "\n", + "For non-precached Stable Diffusion dataset, each webdataset tar file should contain the raw texts and corresponding images:\n", + "\n", + "```\n", + "t0_r0_0.tar\n", + "|---- 0000.jpg\n", + "|---- 0000.txt\n", + "|---- 0001.jpg\n", + "|---- 0001.txt\n", + "...\n", + "```\n", + "\n", + "## Encoders Preparation\n", + "\n", + "Depending on whether you precache the dataset, you might also need to first download the image and/or text encoders.\n", + "\n", + "### Option 1: Training on Non-Precached Dataset (Use Encoders During Training)\n", + "\n", + "#### A. Prepare VAE\n", + "To download the default VAE for Stable Diffusion:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "730cd137-0fce-4bab-8ac7-219e5c55faf2", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "! wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/vae/diffusion_pytorch_model.safetensors\n", + "! mkdir -p /sdxl_ckpts\n", + "! mv diffusion_pytorch_model.safetensors /sdxl_ckpts/vae.safetensors" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "fef8b245-7cee-4048-a9ec-3ada90432a89", + "metadata": {}, + "source": [ + "The above command will download the default VAE weights from HuggingFace and save it to `/sdxl_ckpts/vae.safetensors`.\n", + "\n", + "**Note**: if you want to customize the saved location, make sure it is also reflected in your training config.\n", + "#### B. Prepare Text Encoder\n", + "For the text encoders used in Stable Diffusion XL, it will be automatically downloaded by the training script we provide.\n", + "\n", + "The type of text encoder used in the sdxl model conditioner can be found in `conditioner_config` in the predefined training configs:\n", + "\n", + "```\n", + " conditioner_config:\n", + " _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner\n", + " emb_models:\n", + " - is_trainable: false\n", + " input_key: captions\n", + " ucg_rate: 0.1\n", + " emb_model:\n", + " _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder\n", + " layer: hidden\n", + " layer_idx: 11\n", + " - is_trainable: false\n", + " ucg_rate: 0.1\n", + " input_key: captions\n", + " emb_model:\n", + " _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder2\n", + " arch: ViT-bigG-14\n", + " version: laion2b_s39b_b160k\n", + " freeze: true\n", + " layer: penultimate\n", + " always_return_pooled: true\n", + " legacy: false\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8854eb7a-e822-43f6-a1d5-12357049485a", + "metadata": {}, + "source": [ + "\n", + "### Option 2: Training on Precached Dataset (Training UNet Only)\n", + "\n", + "When using precached dataset (please refer to the [Dataset Tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/multimodal/Multimodal%20Data%20Preparation.ipynb) for details), every text feature and image feature are stored as key-value pairs in `.pickle` file:\n", + "\n", + "```\n", + "{\n", + " image_key: torch.Tensor(),\n", + " text_key: torch.Tensor(),\n", + "}\n", + "```\n", + "\n", + "Make sure in the training config, `cond_stage_key` is associated with `text_key` and `first_stage_key` is associated with `image_key`.\n", + "\n", + "We offer an expample script to convert a dataset from `parquet` file to webdataset `tar` files at [parquet_conversion](https://github.com/NVIDIA/NeMo/blob/main/scripts/multimodal_dataset_conversion/parquet_conversion.py). Three different modes of prechaed training are provided, they are:\n", + "\n", + "1. No Caching: VAE and Text encoders are loaded during training\n", + "2. Text only: Only text features are loaded from dataset during training\n", + "3. Both: Both image and text features are loaded from dataset during training\n", + "\n", + "In each mode, the cached components should be saved in its raw format in tarfiles while cached components should be saved as torch.Tensor()." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "5762427b-f60c-4dfd-8318-e55771b25354", + "metadata": {}, + "source": [ + "## Model Config Setup\n", + "\n", + "Now we will begin setting up the config file needed for Stable Diffusion training. We will use [sd_train.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml) as the template.\n", + "\n", + "1. Modify `model.data.train.dataset_path` so that it has all the webdataset info files you want to train on\n", + "2. Modify `model.data.webdataset.local_root_path` to point to your dataset path\n", + "3. Make sure VAE path `model.first_stage_config.from_pretrained` is adjusted if using non-precached dataset\n", + "4. Make sure the `model.precache mode` is set properly with the dataset you prepared, as detailed above.\n", + "5. Configure `exp_manager.exp_dir` for experiment save directory\n", + "6. Configure `exp_manager.wandb_logger_kwargs` and/or `exp_manager.create_tensorboard_logger` if needed" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "70f858b3-f7d5-4678-b380-80582337bc23", + "metadata": {}, + "source": [ + "**Note**: Please refer to NeMo Toolkit Developer Guide's Stable Diffusion page for more details on in-depth customizations, including all available optimizations.\n", + "\n", + "## Training\n", + "\n", + "Once everything is set up, training stable diffusion is as simple as running:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "589e3a14-c881-4a56-b2bd-370653059dfc", + "metadata": {}, + "outputs": [], + "source": "! torchrun /opt/NeMo/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py trainer.max_steps=100 model.data.train.dataset_path=/path/to/wdinfo.pkl model.data.webdataset.local_root_path=/path/to/dataset trainer.devices=1 trainer.num_nodes=1 model.micro_batch_size=1 model.global_batch_size=1 model.first_stage_config.from_pretrained=/sdxl_ckpts/vae.safetensors model.fsdp=False" + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "892d72dd-c4d7-4ca4-a948-168e187af65c", + "metadata": {}, + "source": [ + "Intermediate checkpoints (during training) and final checkpoint will be saved to `exp_manager.exp_dir` folder. Note that here we use synthetic data for demo purpose." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "087c8b9a-92c3-43d3-86a3-bf7e848dfbd2", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "Stable Diffusion XL inference needs a trained NeMo Stable Diffusion checkpoint, along with both the image encoder (VAE) and text encoder (CLIP). The checkpoint can be either a fully trained `.nemo` checkpoint or an intermediate checkpoint from training (typically in `.ckpt` format). \n", + "\n", + "### Inference Config Setup\n", + "\n", + "Now we will begin setting up the config file needed for Stable Diffusion inference. We will use [sd_xl_infer_v2.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml) as the template.\n", + "\n", + "We generally use [Classifier Free Guidance](https://arxiv.org/abs/2207.12598) for better visual quality, which can be set at `sampling.base.scale`.\n", + "\n", + "NeMo Stable Diffusion supports multiple samplers. Please refer to the developer guide for more details. Samplers can be set at `sampling.base.sampler`.\n", + "\n", + "Inference supports a batch of text prompts, which can be set at `infer.prompt`. One can also generate a configurable number of images per prompt by setting `infer.num_samples`. Generated images will be saved to `out_path`.\n", + "\n", + "You will also need to set the model checkpoint path at `model.restore_from_path` if you are loading from `.nemo` checkpoint, otherwise, mannually set `unet` checkpoints and `vae` checkpoint at `model.unet_config.from_pretrained` and `model.first_stage_config.from_pretrained`, respectively.\n", + "\n", + "### Running the Inference\n", + "\n", + "Once everything is set up, Stable Diffusion inference is as simple as running:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e676c5d-d711-489e-8ab7-3ee20046d88d", + "metadata": {}, + "outputs": [], + "source": "! torchrun /opt/NeMo/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py model.restore_from_path=/path/to/stable-diffusion-xl-train.nemo out_path=/sdxl_infer_out" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/tts/Audio_Codec_Inference.ipynb b/tutorials/tts/Audio_Codec_Inference.ipynb new file mode 100644 index 000000000000..8eff02916737 --- /dev/null +++ b/tutorials/tts/Audio_Codec_Inference.ipynb @@ -0,0 +1,478 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7X-TwhdTGmlc" + }, + "source": [ + "# License" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fCQUeZRPGnoe" + }, + "source": [ + "> Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n", + ">\n", + "> Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", + ">\n", + "> http://www.apache.org/licenses/LICENSE-2.0\n", + ">\n", + "> Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rtBDkKqVGZJ8" + }, + "source": [ + "# Introduction" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pZ2QSsXuGbMe" + }, + "source": [ + "In this tutorial we show how use NeMo **neural audio codecs** at inference time. To learn more about training and finetuning neural audio codecs in NeMo, check the [Audio Codec Training tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/Audio_Codec_Training.ipynb).\n", + "\n", + "An audio codec typically consists of an encoder, a quantizer and a decoder, with a typical architecture depicted in the figure below.\n", + "An audio codec can be used to encode an input audio signal into a sequence of discrete values.\n", + "In this tutorial, the discrete values will be referred to as **audio tokens**.\n", + "The obtained audio tokens can be decoded into an output audio signal.\n", + "\n", + "Audio tokens can be used to represent the input audio for an automatic speech recognition (ASR) model [[1](https://arxiv.org/abs/2309.10922), [2](https://arxiv.org/pdf/2407.03495)], or to represent the output audio of a text-to-speech (TTS) system [[3](https://arxiv.org/abs/2406.05298), [4](https://arxiv.org/pdf/2406.17957)].\n", + "\n", + "NeMo provides several neural audio codec models, inlcuding audio codecs and mel codecs at different sampling rates.\n", + "The list of the available models can be found [here](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/tts/checkpoints.html#codec-models).\n", + "\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3OZassNG5xff" + }, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WZvQvPkIhRi3" + }, + "outputs": [], + "source": [ + "BRANCH = 'main'\n", + "# Install NeMo library. If you are running locally (rather than on Google Colab), follow the instructions at https://github.com/NVIDIA/NeMo#Installation\n", + "\n", + "if 'google.colab' in str(get_ipython()):\n", + " !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "v8NGOM0EzK8W" + }, + "outputs": [], + "source": [ + "import math\n", + "import wget\n", + "import os\n", + "import librosa\n", + "import torch\n", + "import numpy as np\n", + "import IPython.display as ipd\n", + "import matplotlib.pyplot as plt\n", + "from pathlib import Path\n", + "\n", + "\n", + "# Utility for displaying signals and metrics\n", + "def show_signal(signal: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n", + " \"\"\"Show the time-domain signal and its spectrogram.\n", + " \"\"\"\n", + " fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 2.5))\n", + "\n", + " # show waveform\n", + " t = np.arange(0, len(signal)) / sample_rate\n", + "\n", + " ax[0].plot(t, signal)\n", + " ax[0].set_xlim(0, t.max())\n", + " ax[0].grid()\n", + " ax[0].set_xlabel('time / s')\n", + " ax[0].set_ylabel('amplitude')\n", + " ax[0].set_title(tag)\n", + "\n", + " n_fft = 1024\n", + " hop_length = 256\n", + "\n", + " D = librosa.amplitude_to_db(np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)), ref=np.max)\n", + " img = librosa.display.specshow(D, y_axis='linear', x_axis='time', sr=sample_rate, n_fft=n_fft, hop_length=hop_length, ax=ax[1])\n", + " ax[1].set_title(tag)\n", + "\n", + " plt.tight_layout()\n", + " plt.colorbar(img, format=\"%+2.f dB\", ax=ax)\n", + "\n", + "\n", + "# Utility for displaying a latent representation\n", + "def show_latent(latent: np.ndarray, tag: str):\n", + " plt.figure(figsize = (16, 3))\n", + " img = plt.imshow(latent, aspect='equal')\n", + " plt.colorbar(img, ax=plt.gca())\n", + " plt.title(tag)\n", + " plt.xlabel('Time frame')\n", + " plt.ylabel('Latent vector index')\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8ZKDMTwsEY1K" + }, + "outputs": [], + "source": [ + "# Working directory\n", + "ROOT_DIR = Path().absolute() / 'codec_tutorial'\n", + "\n", + "# Create dataset directory\n", + "DATA_DIR = ROOT_DIR / 'data'\n", + "DATA_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "audio_path = DATA_DIR / 'LJ023-0089.wav'\n", + "audio_url = \"https://multilangaudiosamples.s3.us-east-2.amazonaws.com/LJ023-0089.wav\"\n", + "\n", + "if not os.path.exists(audio_path):\n", + " wget.download(audio_url, audio_path.as_posix())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KAbH7N427FdT" + }, + "source": [ + "# Load a model from NGC" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ODgdGgsAAUku" + }, + "source": [ + "Any of the [pretrained checkpoints](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/tts/checkpoints.html#codec-models) could be used for inference.\n", + "Here, we use `mel_codec_22khz_fullband_medium`, which works for 22.05 kHz audio signals.\n", + "\n", + "The model can be easily restored from NGC:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XqAYWR65aKTx" + }, + "outputs": [], + "source": [ + "from nemo.collections.tts.models.audio_codec import AudioCodecModel\n", + "\n", + "# Optionally specify a pretrained model to fine-tune from. To train from scratch, set this to 'None'.\n", + "model_name = 'mel_codec_22khz_fullband_medium'\n", + "codec_model = AudioCodecModel.from_pretrained(model_name)\n", + "codec_model.freeze()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZnnjL28pEY1L" + }, + "source": [ + "Show information about the loaded model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4xsfeHVyEY1L" + }, + "outputs": [], + "source": [ + "print(f'Loaded model from NeMo:')\n", + "print(f'\\tmodel name : {model_name}')\n", + "print(f'\\tsample rate : {codec_model.sample_rate} Hz')\n", + "print(f'\\tlatent dimension : {codec_model.vector_quantizer.codebook_dim}')\n", + "\n", + "print('\\n\\nModel summary:')\n", + "print(codec_model.summarize())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fM4QPsLTnzK7" + }, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tkZC6Dl7KRl6" + }, + "source": [ + "## Processing audio\n", + "\n", + "Here we use the codec model to process the input audio by applying the complete model. The input signal is encoded, quantized, dequantized and decoded. Finally, a reconstructed signal is obtained." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sYzvAYr2vo1K" + }, + "outputs": [], + "source": [ + "input_audio, sr = librosa.load(audio_path, sr=codec_model.sample_rate)\n", + "\n", + "# Shape (batch, time)\n", + "input_audio_tensor = torch.from_numpy(input_audio).unsqueeze(dim=0).to(codec_model.device)\n", + "\n", + "# Shape (batch,)\n", + "input_audio_len = torch.tensor([input_audio_tensor.size(-1)]).to(codec_model.device)\n", + "\n", + "# Process audio using the codec model\n", + "output_audio_tensor, _ = codec_model(audio=input_audio_tensor, audio_len=input_audio_len)\n", + "\n", + "# Output audio\n", + "output_audio = output_audio_tensor.squeeze().cpu().numpy()\n", + "\n", + "# Show signals\n", + "show_signal(input_audio, tag='Input audio', sample_rate=codec_model.sample_rate)\n", + "show_signal(output_audio, tag='Output audio', sample_rate=codec_model.sample_rate)\n", + "\n", + "# Play audio\n", + "print('Input audio')\n", + "ipd.display(ipd.Audio(input_audio, rate=codec_model.sample_rate))\n", + "\n", + "print('Output audio')\n", + "ipd.display(ipd.Audio(output_audio, rate=codec_model.sample_rate))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rynZYwg2VP5d" + }, + "source": [ + "## Audio tokens\n", + "\n", + "Audio tokens can be easily computed by using the `encode` method of the `AudioCodec` model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ASKM_jKVEY1L" + }, + "outputs": [], + "source": [ + "# Convert audio to tokens\n", + "tokens, tokens_len = codec_model.encode(audio=input_audio_tensor, audio_len=input_audio_len)\n", + "\n", + "print('tokens information:')\n", + "print(f'\\tshape (batch, codebook, time frame) : {tokens.size()}')\n", + "print(f'\\tdtype : {tokens.dtype}')\n", + "print(f'\\tmin : {tokens.min()}')\n", + "print(f'\\tmax : {tokens.max()}')\n", + "\n", + "# Number of codebooks should match the number of codebooks/groups\n", + "if hasattr(codec_model.vector_quantizer, 'num_groups'):\n", + " # Group FSQ\n", + " assert tokens.size(1) == codec_model.vector_quantizer.num_groups\n", + " print(f'\\tnum_groups : {tokens.size(1)}')\n", + "elif hasattr(codec_model.vector_quantizer, 'codebooks'):\n", + " # RVQ\n", + " assert tokens.size(1) == len(codec_model.vector_quantizer.codebooks)\n", + " print(f'\\tnum_codebooks : {tokens.size(1)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CmliPMnDEY1L" + }, + "source": [ + "Similarly, audio can be easily reconstructed from audio tokens using the `decode` method of the `AudioCodec` models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RTQ1M9PMEY1L" + }, + "outputs": [], + "source": [ + "# Convert tokens back to audio\n", + "output_audio_from_tokens_tensor, _ = codec_model.decode(tokens=tokens, tokens_len=tokens_len)\n", + "output_audio_from_tokens = output_audio_from_tokens_tensor.squeeze().cpu().numpy()\n", + "\n", + "# Show signals\n", + "show_signal(output_audio_from_tokens, tag='Output audio from tokens', sample_rate=codec_model.sample_rate)\n", + "show_signal(output_audio_from_tokens - output_audio, tag='Difference compared to forward pass', sample_rate=codec_model.sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kGqotZkqEY1M" + }, + "source": [ + "## Latent representation\n", + "\n", + "Continuous (non-discrete) latent representation at the output of the encoder can be easily computed using the `encode_audio` method of the `AudioCodec` model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r-89-gG3EY1M" + }, + "outputs": [], + "source": [ + "# Convert audio to the encoded representation\n", + "encoded, encoded_len = codec_model.encode_audio(audio=input_audio_tensor, audio_len=input_audio_len)\n", + "\n", + "print('encoded information:')\n", + "print(f'\\tshape (batch, codebook, time frame) : {encoded.size()}')\n", + "print(f'\\tdtype : {encoded.dtype}')\n", + "print(f'\\tmin : {encoded.min()}')\n", + "print(f'\\tmax : {encoded.max()}')\n", + "\n", + "\n", + "# Show the encoded representation\n", + "show_latent(encoded.squeeze().cpu().numpy(), tag='Encoder output')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3Ory1U1uEY1M" + }, + "source": [ + "The encoded representation can be easily converted to tokens, dequantized into a continuous latent representation and decoded back to audio." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "btmqUWNkEY1M" + }, + "outputs": [], + "source": [ + "# Encoder output to tokens\n", + "tokens = codec_model.quantize(encoded=encoded, encoded_len=encoded_len)\n", + "\n", + "# Tokens back to a continuous vector\n", + "dequantized = codec_model.dequantize(tokens=tokens, tokens_len=encoded_len)\n", + "\n", + "# Reconstruct audio\n", + "output_audio_from_latent_tensor, _ = codec_model.decode_audio(inputs=dequantized, input_len=encoded_len)\n", + "output_audio_from_latent = output_audio_from_latent_tensor.squeeze().cpu().numpy()\n", + "\n", + "# Show dequantized latent representation\n", + "show_latent(dequantized.squeeze().cpu().numpy(), tag='Decoder input')\n", + "\n", + "# Show signals\n", + "show_signal(output_audio_from_latent, tag='Output audio from latent', sample_rate=codec_model.sample_rate)\n", + "show_signal(output_audio_from_latent - output_audio, tag='Difference compared to forward pass', sample_rate=codec_model.sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cMvU0WxlEY1M" + }, + "source": [ + "# Related information" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_LtyHHuLkNDv" + }, + "source": [ + "To learn more about audio codec models in NeMo, look at our [documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/tts/models.html#codecs).\n", + "\n", + "For more information on training and finetuning neural audio codecs in NeMo, check the [Audio Codec Training tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/Audio_Codec_Training.ipynb)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LeqV3VvJVOb-" + }, + "source": [ + "# References" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rvu4w2x_3RSY" + }, + "source": [ + "1. [Discrete Audio Representation as an Alternative to Mel-Spectrograms for Speaker and Speech Recognition](https://arxiv.org/abs/2309.10922)\n", + "2. [Codec-ASR: Training Performant Automatic Speech Recognition Systems with Discrete Speech Representations](https://arxiv.org/pdf/2407.03495)\n", + "3. [Spectral Codecs: Spectrogram-Based Audio Codecs for High Quality Speech Synthesis](https://arxiv.org/abs/2406.05298)\n", + "4. [Improving Robustness of LLM-based Speech Synthesis by Learning Monotonic Alignment](https://arxiv.org/pdf/2406.17957)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "colab": { + "provenance": [], + "toc_visible": true + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file