diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index de250596da62..6f090bd34213 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -132,6 +132,9 @@ jobs: apt-get update && apt-get install libsox-fmt-all -y && \ popd + # AMMO installation + pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir + # PyTorch Lightning version python -c "import pytorch_lightning; print(pytorch_lightning.__version__)" @@ -220,7 +223,26 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" - + L0_Setup_Test_Data_And_Models: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python -m tests.setup --save_dir /home/TestData/nlp + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" ## - name: L2: Multimodal Imagen Train @@ -243,10 +265,9 @@ jobs: uses: actions/checkout@v4 - run: | CUDA_VISIBLE_DEVICES=0 python scripts/checkpoint_converters/convert_llama_hf_to_nemo.py \ - --input_name_or_path=/home/TestData/nlp/megatron_llama/llama-ci-hf \ - --output_path=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \ + --input_name_or_path=/home/TestData/nlp/megatron_llama/llama-ci-hf-tiny \ + --output_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ --precision=16 - rm -f /home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" @@ -322,6 +343,124 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" + L2_PTQ_Llama2_Export_Only: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + quantization.algorithm=null \ + model_save=/home/TestData/nlp/megatron_llama/ci_baseline + + rm -rf /home/TestData/nlp/megatron_llama/ci_baseline + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + L2_PTQ_Llama2_FP8: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + tensor_model_parallel_size=2 \ + trainer.devices=2 \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=fp8 \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + export.inference_tensor_parallel=2 \ 
+ model_save=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_fp8.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + L2_PTQ_Llama2_INT8_SQ: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=int8_sq \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + model_save=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + L2_PTQ_Llama2_INT4_AWQ: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + tensor_model_parallel_size=1 \ + trainer.devices=1 \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=int4_awq \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + model_save=/home/TestData/nlp/megatron_llama/ci_int4_awq.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_int4_awq.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + # L2: ASR dev run ASR_dev_run_Speech_to_Text: needs: [cicd-test-container-setup] @@ -4664,7 +4803,7 @@ jobs: --volume /mnt/datadrive/TestData:/home/TestData steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - run: | rm -rf /home/TestData/nlp/megatron_ir/working_dir diff --git a/README.rst b/README.rst index 66b3a5806c2d..0b05bd0390f8 100644 --- a/README.rst +++ b/README.rst @@ -77,6 +77,31 @@ Latest News +
+ Speech Recognition
+
+ New Standard for Speech Recognition and Translation from the NVIDIA NeMo Canary Model (2024/04/18)
+
+ The NeMo team just released Canary, a multilingual model that transcribes speech in English, Spanish, German, and French with punctuation and capitalization. Canary also provides bi-directional translation between English and the three other supported languages.
+
+ Pushing the Boundaries of Speech Recognition with NVIDIA NeMo Parakeet ASR Models (2024/04/18)
+
+ NVIDIA NeMo, an end-to-end platform for the development of multimodal generative AI models at scale anywhere—on any cloud and on-premises—released the Parakeet family of automatic speech recognition (ASR) models. These state-of-the-art ASR models, developed in collaboration with Suno.ai, transcribe spoken English with exceptional accuracy.
+
+ Turbocharge ASR Accuracy and Speed with NVIDIA NeMo Parakeet-TDT (2024/04/18)
+
+ NVIDIA NeMo, an end-to-end platform for developing multimodal generative AI models at scale anywhere—on any cloud and on-premises—recently released Parakeet-TDT. This new addition to the NeMo ASR Parakeet model family boasts better accuracy and 64% greater speed over the previously best model, Parakeet-RNNT-1.1B.
+ diff --git a/examples/audio_tasks/audio_to_audio_eval.py b/examples/audio_tasks/audio_to_audio_eval.py index 4ac68dfc84e7..ab6623df298d 100644 --- a/examples/audio_tasks/audio_to_audio_eval.py +++ b/examples/audio_tasks/audio_to_audio_eval.py @@ -61,6 +61,7 @@ import json import os import tempfile +from collections import defaultdict from dataclasses import dataclass, field, is_dataclass from typing import List, Optional @@ -101,6 +102,9 @@ class AudioEvaluationConfig(process_audio.ProcessConfig): # Metrics to calculate metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi']) + # Return metric values for each example + return_values_per_example: bool = False + def get_evaluation_dataloader(config): """Prepare a dataloader for evaluation. @@ -174,6 +178,9 @@ def main(cfg: AudioEvaluationConfig): # Setup metrics metrics = get_metrics(cfg) + if cfg.return_values_per_example and cfg.batch_size > 1: + raise ValueError('return_example_values is only supported for batch_size=1.') + # Processing if not cfg.only_score_manifest: # Process audio using the configured model and save in the output directory @@ -236,6 +243,10 @@ def main(cfg: AudioEvaluationConfig): num_files += 1 + if cfg.max_utts is not None and num_files >= cfg.max_utts: + logging.info('Reached max_utts: %s', cfg.max_utts) + break + # Prepare dataloader config = { 'manifest_filepath': temporary_manifest_filepath, @@ -249,6 +260,8 @@ def main(cfg: AudioEvaluationConfig): } temporary_dataloader = get_evaluation_dataloader(config) + metrics_value_per_example = defaultdict(list) + # Calculate metrics for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'): processed_signal, processed_length, target_signal, target_length = eval_batch @@ -257,7 +270,9 @@ def main(cfg: AudioEvaluationConfig): raise RuntimeError(f'Length mismatch.') for name, metric in metrics.items(): - metric.update(preds=processed_signal, target=target_signal, input_length=target_length) + value = metric(preds=processed_signal, target=target_signal, input_length=target_length) + if cfg.return_values_per_example: + metrics_value_per_example[name].append(value.item()) # Convert to a dictionary with name: value metrics_value = {name: metric.compute().item() for name, metric in metrics.items()} @@ -277,6 +292,7 @@ def main(cfg: AudioEvaluationConfig): # Inject the metric name and score into the config, and return the entire config with open_dict(cfg): cfg.metrics_value = metrics_value + cfg.metrics_value_per_example = dict(metrics_value_per_example) return cfg diff --git a/examples/audio_tasks/conf/beamforming.yaml b/examples/audio_tasks/conf/beamforming.yaml index 18e04f0bd12a..3abc4f134e64 100644 --- a/examples/audio_tasks/conf/beamforming.yaml +++ b/examples/audio_tasks/conf/beamforming.yaml @@ -44,7 +44,6 @@ model: _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram - power: null decoder: _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio diff --git a/examples/audio_tasks/conf/masking.yaml b/examples/audio_tasks/conf/masking.yaml index c667bec53076..68adca116aa5 100644 --- a/examples/audio_tasks/conf/masking.yaml +++ b/examples/audio_tasks/conf/masking.yaml @@ -1,5 +1,3 @@ -# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer. 
-# name: "masking" model: @@ -44,7 +42,6 @@ model: _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram - power: null decoder: _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio diff --git a/examples/audio_tasks/conf/predictive.yaml b/examples/audio_tasks/conf/predictive.yaml new file mode 100644 index 000000000000..b141ba6fd1ee --- /dev/null +++ b/examples/audio_tasks/conf/predictive.yaml @@ -0,0 +1,130 @@ +name: "predictive_model" + +model: + type: predictive + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + normalize_input: true # normalize the input signal to 0dBFS + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256 + random_offset: true + normalization_signal: input_signal + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + encoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus + in_channels: 1 # single-channel noisy input + out_channels: 1 # single-channel estimate + num_res_blocks: 3 # increased number of res blocks + pad_time_to: 64 # pad to 64 frames for the time dimension + pad_dimension_to: 0 # no padding in the frequency dimension + + loss: + _target_: nemo.collections.asr.losses.MSELoss # computed in the time domain + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: null + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. 
+ enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: False # use original weights for validation calculation? + + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_sisdr + mode: max + save_top_k: 5 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/audio_tasks/conf/score_based_generative.yaml b/examples/audio_tasks/conf/score_based_generative.yaml new file mode 100644 index 000000000000..c0b36bd750a2 --- /dev/null +++ b/examples/audio_tasks/conf/score_based_generative.yaml @@ -0,0 +1,149 @@ +name: score_based_generative_model + +model: + type: score_based + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + normalize_input: true + max_utts_evaluation_metrics: 50 # metric calculation needs full inference and is slow, so we limit to first few files + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256 + random_offset: true + normalization_signal: input_signal + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? 
+ input_key: noisy_filepath + target_key: clean_filepath + normalize_input: false # load data as is for validation, the model will normalize it for inference + batch_size: 4 + shuffle: false + num_workers: 4 + pin_memory: true + + encoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + conditioned_on_time: true + num_res_blocks: 3 # increased number of res blocks + pad_time_to: 64 # pad to 64 frames for the time dimension + pad_dimension_to: 0 # no padding in the frequency dimension + + sde: + _target_: nemo.collections.asr.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE + stiffness: 1.5 + std_min: 0.05 + std_max: 0.5 + num_steps: 1000 + + sampler: + _target_: nemo.collections.asr.parts.submodules.diffusion.PredictorCorrectorSampler + predictor: reverse_diffusion + corrector: annealed_langevin_dynamics + num_steps: 50 + num_corrector_steps: 1 + snr: 0.5 + + loss: + _target_: nemo.collections.asr.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: null + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? 
+ + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_sisdr + mode: max + save_top_k: 5 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/audio_tasks/speech_enhancement.py b/examples/audio_tasks/speech_enhancement.py index 250d212d2a25..33a25c1c107c 100644 --- a/examples/audio_tasks/speech_enhancement.py +++ b/examples/audio_tasks/speech_enhancement.py @@ -26,25 +26,64 @@ PyTorch Lightning Trainer arguments and args of the model and the optimizer can be added or overriden from CLI """ +from enum import Enum + import pytorch_lightning as pl import torch from omegaconf import OmegaConf -from nemo.collections.asr.models import EncMaskDecAudioToAudioModel +from nemo.collections.asr.models.enhancement_models import ( + EncMaskDecAudioToAudioModel, + PredictiveAudioToAudioModel, + ScoreBasedGenerativeAudioToAudioModel, +) from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager +class ModelType(str, Enum): + """Enumeration with the available model types. + """ + + MaskBased = 'mask_based' + Predictive = 'predictive' + ScoreBased = 'score_based' + + +def get_model_class(model_type: ModelType): + """Get model class for a given model type. + """ + if model_type == ModelType.MaskBased: + return EncMaskDecAudioToAudioModel + elif model_type == ModelType.Predictive: + return PredictiveAudioToAudioModel + elif model_type == ModelType.ScoreBased: + return ScoreBasedGenerativeAudioToAudioModel + else: + raise ValueError(f'Unknown model type: {model_type}') + + @hydra_runner(config_path="./conf", config_name="masking") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}') trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) - model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer) - # Initialize the weights of the model from another model, if provided via config + # Get model class + model_type = cfg.model.get('type') + if model_type is None: + model_type = ModelType.MaskBased + logging.warning('model_type not found in config. 
Using default: %s', model_type) + + logging.info('Get class for model type: %s', model_type) + model_class = get_model_class(model_type) + + logging.info('Instantiate model %s', model_class.__name__) + model = model_class(cfg=cfg.model, trainer=trainer) + + logging.info('Initialize the weights of the model from another model, if provided via config') model.maybe_init_from_pretrained_checkpoint(cfg) # Train the model diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml index 8ce009d5458f..dff963590864 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml @@ -49,8 +49,8 @@ model: precision: ${trainer.precision} # specify micro_batch_size, global_batch_size, and model parallelism # gradient accumulation will be done automatically based on data_parallel_size - micro_batch_size: 1 # limited by GPU memory - global_batch_size: 1 # will use more micro batches to reach global batch size + micro_batch_size: 16 # limited by GPU memory + global_batch_size: 16 # will use more micro batches to reach global batch size native_amp_init_scale: 65536.0 # Init scale for grad scaler used at fp16 @@ -97,15 +97,15 @@ model: unet_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel from_pretrained: #/ckpts/nemo-v1-2.ckpt - from_NeMo: True #Must be specified when from pretrained is not None, False means loading unet from HF ckpt + from_NeMo: False #Must be specified when from pretrained is not None, False means loading unet from HF ckpt image_size: 32 # unused in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: - - 4 - - 2 - - 1 + - 4 + - 2 + - 1 num_res_blocks: 2 channel_mult: - 1 @@ -121,6 +121,7 @@ model: use_flash_attention: True unet_precision: fp32 resblock_gn_groups: 32 + use_te_fp8: False first_stage_config: _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL @@ -140,22 +141,22 @@ model: - 4 - 4 num_res_blocks: 2 - attn_resolutions: [] + attn_resolutions: [ ] dropout: 0.0 lossconfig: target: torch.nn.Identity cond_stage_config: - _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder - restore_from_path: /ckpts/openai.nemo + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + version: openai/clip-vit-large-patch14 device: cuda - freeze: True - layer: "last" - # For compatibility of history version that uses HF clip model - # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder - # version: openai/clip-vit-large-patch14 - # device: cuda - # max_length: 77 + max_length: 77 + # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder + # restore_from_path: /ckpts/openai-old.nemo + # device: cuda + # freeze: True + # layer: "last" + # miscellaneous @@ -163,7 +164,7 @@ model: resume_from_checkpoint: null # manually set the checkpoint file to load from apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) - ddp_overlap: True # True for using PyTorch DDP overlap. 
+ ddp_overlap: False # True for using PyTorch DDP overlap. optim: name: fused_adam @@ -191,7 +192,7 @@ model: synthetic_data_length: 10000 train: dataset_path: - - /datasets/coyo/test.pkl + - /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl augmentations: resize_smallest_side: 512 center_crop_h_w: 512, 512 diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py index f1e5e2872ea7..58e9e6e64470 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py @@ -28,6 +28,9 @@ def model_cfg_modifier(model_cfg): model_cfg.unet_config.use_flash_attention = False model_cfg.unet_config.from_pretrained = None model_cfg.first_stage_config.from_pretrained = None + model_cfg.first_stage_config._target_ = ( + 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL' + ) torch.backends.cuda.matmul.allow_tf32 = True trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml new file mode 100644 index 000000000000..b37d64a325e5 --- /dev/null +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml @@ -0,0 +1,204 @@ +# An example model that works with this config is "https://huggingface.co/yuvalkirstain/PickScore_v1" +model: + precision: 32 + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 32 # limited by GPU memory + global_batch_size: 32 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_pretrained: null # used in fine-tuning + # multimodal configs + output_dim: 1024 + # As the number of devices used to train increases, so does the space complexity of + # the logit matrix. Using a naïve all-gather scheme, space complexity will be + # `O(n^2)`. Instead, complexity may become effectively linear if the flags + # `--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one + # numerical results as the naïve method. + local_loss: False # calculate loss w/ local features @ global (instead of realizing full global @ global matrix) + gather_with_grad: True # enable full distributed gradient for feature gather, set this to False may cause convergence issue + + vision: + precision: 32 + # vision configs + patch_dim: 14 + img_h: 224 + img_w: 224 + image_mean: null + image_std: null + num_channels: 3 + drop_patch_rate: 0.0 + drop_path_rate: 0.0 + global_average_pool: False + output_dim: ${model.output_dim} + class_token_length: 1 + preprocess_layernorm: True # apply layer norm to embedded tokens + + # model architecture + encoder_seq_length: 196 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_parameters + num_layers: 32 + hidden_size: 1280 + ffn_hidden_size: 5120 # Transformer FFN hidden size. Usually 4 * hidden_size. 
+ num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: True + activation: gelu + + + + text: + precision: 32 + # text configs + output_dim: ${model.output_dim} + + # model architecture + encoder_seq_length: 77 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_parameters + num_layers: 24 + hidden_size: 1024 + ffn_hidden_size: 4096 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. 
+ + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: True + + transformer_engine: False + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + activation: gelu + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'huggingface' + type: 'openai/clip-vit-large-patch14' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + + data: + num_workers: 8 + train: + dataset_path: # List of paths to pkl files or tar files + - /datasets/coyo/test.pkl + validation: # List of paths to pkl files or tar files + dataset_path: + - /datasets/coyo/test.pkl + webdataset: + infinite_sampler: False + local_root_path: /datasets/coyo + + imagenet_val: null # Path to imagenet val set for conducting zero shot evaluation. 
+ + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 1e-3 + weight_decay: 0.2 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 2000 + constant_steps: 0 + min_lr: 1e-5 \ No newline at end of file diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml index 40347f317fbb..6517b62010b4 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml @@ -101,6 +101,7 @@ model: position_embedding_strategy: null # used only when weight_tying is True lora_tuning: + variant: "nemo" # can be "nemo" or "canonical" target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) adapter_dim: 32 alpha: ${model.peft.lora_tuning.adapter_dim} diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml index 67d43eb303f4..592eed6c4420 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml @@ -89,6 +89,7 @@ model: position_embedding_strategy: null # used only when weight_tying is True lora_tuning: + variant: "nemo" # can be either "canonical" or "nemo" target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) adapter_dim: 32 adapter_dropout: 0.0 diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/asr/data/audio_to_audio.py index a3c6dd0cc1b3..4f4727239a4b 100644 --- a/nemo/collections/asr/data/audio_to_audio.py +++ b/nemo/collections/asr/data/audio_to_audio.py @@ -130,13 +130,19 @@ class ASRAudioProcessor: sample_rate: sample rate used for all audio signals random_offset: If `True`, offset will be randomized when loading a subsegment from a file. + normalization_signal: Normalize all audio with a factor that ensures the signal + `example[normalization_signal]` in `process` is in range [-1, 1]. + All other audio signals are scaled by the same factor. Default is + `None`, corresponding to no normalization. """ def __init__( - self, sample_rate: float, random_offset: bool, + self, sample_rate: float, random_offset: bool, normalization_signal: Optional[str] = None, eps: float = 1e-8, ): self.sample_rate = sample_rate self.random_offset = random_offset + self.normalization_signal = normalization_signal + self.eps = eps self.sync_setup = None self.async_setup = None @@ -314,7 +320,20 @@ def process_audio(self, audio: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso Returns: An ordered dictionary of signals and their tensors. """ - # Currently, not doing any processing of the loaded signals. + if self.normalization_signal: + # Normalize all audio with a factor that ensures the normalization signal is in range [-1, 1]. 
+ norm_scale = audio[self.normalization_signal].abs().max() + + # Do not normalize embeddings + skip_signals = self.embedding_setup.signals if self.embedding_setup is not None else [] + + # Normalize audio signals + for signal in audio: + if signal not in skip_signals: + # All audio signals are scaled by the same factor. + # This ensures that the relative level between signals is preserved. + audio[signal] = audio[signal] / (norm_scale + self.eps) + return audio def load_sync_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: @@ -812,6 +831,9 @@ class AudioToTargetDataset(BaseAudioDataset): If `None`, all channels will be loaded. target_channel_selector: Optional, select subset of channels from each input audio file. If `None`, all channels will be loaded. + normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1]. + All audio signals are scaled by the same factor. Supported values are `None` (no normalization), + 'input_signal', 'target_signal'. """ def __init__( @@ -827,6 +849,7 @@ def __init__( max_utts: Optional[int] = None, input_channel_selector: Optional[int] = None, target_channel_selector: Optional[int] = None, + normalization_signal: Optional[str] = None, ): audio_to_manifest_key = { 'input_signal': input_key, @@ -841,7 +864,9 @@ def __init__( max_number=max_utts, ) - audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor = ASRAudioProcessor( + sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + ) audio_processor.sync_setup = SignalSetup( signals=['input_signal', 'target_signal'], duration=audio_duration, @@ -932,6 +957,9 @@ class AudioToTargetWithReferenceDataset(BaseAudioDataset): from input and target. reference_duration: Optional, can be used to set a fixed duration of the reference utterance. If `None`, complete audio file will be loaded. + normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1]. + All audio signals are scaled by the same factor. Supported values are `None` (no normalization), + 'input_signal', 'target_signal', 'reference_signal'. """ def __init__( @@ -951,6 +979,7 @@ def __init__( reference_channel_selector: Optional[int] = None, reference_is_synchronized: bool = True, reference_duration: Optional[float] = None, + normalization_signal: Optional[str] = None, ): audio_to_manifest_key = { 'input_signal': input_key, @@ -966,7 +995,9 @@ def __init__( max_number=max_utts, ) - audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor = ASRAudioProcessor( + sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + ) if reference_is_synchronized: audio_processor.sync_setup = SignalSetup( @@ -1063,6 +1094,9 @@ class AudioToTargetWithEmbeddingDataset(BaseAudioDataset): If `None`, all channels will be loaded. target_channel_selector: Optional, select subset of channels from each input audio file. If `None`, all channels will be loaded. + normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1]. + All audio signals are scaled by the same factor. Supported values are `None` (no normalization), + 'input_signal', 'target_signal'. 
""" def __init__( @@ -1079,6 +1113,7 @@ def __init__( max_utts: Optional[int] = None, input_channel_selector: Optional[int] = None, target_channel_selector: Optional[int] = None, + normalization_signal: Optional[str] = None, ): audio_to_manifest_key = { 'input_signal': input_key, @@ -1094,7 +1129,9 @@ def __init__( max_number=max_utts, ) - audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor = ASRAudioProcessor( + sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + ) audio_processor.sync_setup = SignalSetup( signals=['input_signal', 'target_signal'], duration=audio_duration, diff --git a/nemo/collections/asr/data/audio_to_audio_dataset.py b/nemo/collections/asr/data/audio_to_audio_dataset.py index b296d64b1f2a..46e47020fda0 100644 --- a/nemo/collections/asr/data/audio_to_audio_dataset.py +++ b/nemo/collections/asr/data/audio_to_audio_dataset.py @@ -36,6 +36,7 @@ def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDat max_utts=config.get('max_utts', 0), input_channel_selector=config.get('input_channel_selector', None), target_channel_selector=config.get('target_channel_selector', None), + normalization_signal=config.get('normalization_signal', None), ) return dataset @@ -65,6 +66,7 @@ def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.A reference_channel_selector=config.get('reference_channel_selector', None), reference_is_synchronized=config.get('reference_is_synchronized', True), reference_duration=config.get('reference_duration', None), + normalization_signal=config.get('normalization_signal', None), ) return dataset @@ -91,5 +93,6 @@ def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.A max_utts=config.get('max_utts', 0), input_channel_selector=config.get('input_channel_selector', None), target_channel_selector=config.get('target_channel_selector', None), + normalization_signal=config.get('normalization_signal', None), ) return dataset diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 3e50cea1d692..c03f7a48ffe3 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss -from nemo.collections.asr.losses.audio_losses import SDRLoss +from nemo.collections.asr.losses.audio_losses import MSELoss, SDRLoss from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.lattice_losses import LatticeLoss from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss diff --git a/nemo/collections/asr/losses/audio_losses.py b/nemo/collections/asr/losses/audio_losses.py index 62ce4a9f7edd..b0214375a713 100644 --- a/nemo/collections/asr/losses/audio_losses.py +++ b/nemo/collections/asr/losses/audio_losses.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import List, Optional +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -21,31 +21,33 @@ from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like from nemo.collections.asr.parts.utils.audio_utils import toeplitz from nemo.core.classes import Loss, Typing, typecheck -from nemo.core.neural_types import AudioSignal, LengthsType, LossType, MaskType, NeuralType +from nemo.core.neural_types import AudioSignal, LengthsType, LossType, MaskType, NeuralType, VoidType from nemo.utils import logging -__all__ = ['SDRLoss'] +__all__ = ['SDRLoss', 'MSELoss'] -def temporal_mean( +def calculate_mean( input: torch.Tensor, input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, + dim: Union[int, Tuple[int]] = -1, keepdim: bool = False, eps: float = 1e-10, ) -> torch.Tensor: - """Calculate mean along temporal dimension with optionally + """Calculate mean along dimension `dim` with optionally averaging only over valid samples (based on the input length). Args: - input: Batch of signals, shape (B, C, T) + input: signal, for example (B, C, T) or (B, C, D, T) input_length: Optional, length of each example in the batch, shape (B,) - mask: Optional, temporal mask for each example in the batch, shape (B, T) + mask: Optional, temporal mask for each example in the batch, same shape as the input signal + dim: dimension or dimensions to reduce keepdim: Whether to keep the temporal dimension eps: Regularization to avoid division by zero Returns: - (B, C, 1) if keepdim=True, otherwise (B, C) + Mean over dimensions `dim`. """ if input_length is not None: if mask is not None: @@ -53,17 +55,18 @@ def temporal_mean( 'Argument `input_length` is mutually exclusive with `mask`. Both cannot be used at the same time.' 
) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=input, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=input, time_dim=-1, valid_ones=True) + mask = mask.expand_as(input) if mask is None: # No length information, assume all samples are valid - mean = torch.mean(input, dim=-1, keepdim=keepdim) + mean = torch.mean(input, dim=dim, keepdim=keepdim) else: # Average using temporal mask - mean = mask.unsqueeze(1) * input - mean = torch.sum(mean, axis=-1, keepdim=keepdim) - normalization = torch.sum(mask, axis=-1, keepdim=keepdim) - mean = mean / (normalization.unsqueeze(1) + eps) + mean = mask * input + mean = torch.sum(mean, dim=dim, keepdim=keepdim) + normalization = torch.sum(mask, dim=dim, keepdim=keepdim) + mean = mean / (normalization + eps) return mean @@ -101,16 +104,17 @@ def scale_invariant_target( ) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) - estimate_dot_target = temporal_mean(estimate * target, mask=mask, keepdim=True, eps=eps) - target_pow = temporal_mean(torch.abs(target) ** 2, mask=mask, keepdim=True, eps=eps) + estimate_dot_target = calculate_mean(estimate * target, mask=mask, dim=-1, keepdim=True, eps=eps) + target_pow = calculate_mean(torch.abs(target) ** 2, mask=mask, dim=-1, keepdim=True, eps=eps) scale = estimate_dot_target / (target_pow + eps) target_scaled = scale * target # Mask to keep only the valid samples if mask is not None: - target_scaled = mask.unsqueeze(1) * target_scaled + target_scaled = mask * target_scaled return target_scaled @@ -162,12 +166,13 @@ def convolution_invariant_target( ) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) # Apply a mask, if available if mask is not None: - estimate = mask.unsqueeze(1) * estimate - target = mask.unsqueeze(1) * target + estimate = mask * estimate + target = mask * target # Calculate filtered target input_shape = estimate.shape @@ -207,7 +212,7 @@ def convolution_invariant_target( # Mask to keep only the valid samples if mask is not None: - target_filt = mask.unsqueeze(1) * target_filt + target_filt = mask * target_filt return target_filt @@ -261,11 +266,12 @@ def calculate_sdr_batch( ) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) if remove_mean: - estimate = estimate - temporal_mean(estimate, mask=mask, keepdim=True, eps=eps) - target = target - temporal_mean(target, mask=mask, keepdim=True, eps=eps) + estimate = estimate - calculate_mean(estimate, mask=mask, dim=-1, keepdim=True, eps=eps) + target = target - calculate_mean(target, mask=mask, dim=-1, keepdim=True, eps=eps) if scale_invariant or (convolution_invariant and convolution_filter_length == 1): target = scale_invariant_target(estimate=estimate, target=target, mask=mask, eps=eps) @@ -276,8 +282,8 @@ def calculate_sdr_batch( distortion = estimate - target - target_pow = temporal_mean(torch.abs(target) ** 2, mask=mask, eps=eps) - distortion_pow = 
temporal_mean(torch.abs(distortion) ** 2, mask=mask, eps=eps) + target_pow = calculate_mean(torch.abs(target) ** 2, mask=mask, dim=-1, eps=eps) + distortion_pow = calculate_mean(torch.abs(distortion) ** 2, mask=mask, dim=-1, eps=eps) if sdr_max is not None: distortion_pow = distortion_pow + 10 ** (-sdr_max / 10) * target_pow @@ -353,7 +359,7 @@ def input_types(self): "estimate": NeuralType(signal_shape, AudioSignal()), "target": NeuralType(signal_shape, AudioSignal()), "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), - "mask": NeuralType(('B', 'T'), MaskType(), optional=True), + "mask": NeuralType(('B', 'C', 'T'), MaskType(), optional=True), } @property @@ -376,10 +382,10 @@ def forward( perform averaging across channels (weighting optional), and apply reduction across the batch. Args: - estimate: Batch of signals, shape (B, T, C) - target: Batch of signals, shape (B, T, C) + estimate: Batch of signals, shape (B, C, T) + target: Batch of signals, shape (B, C, T) input_length: Batch of lengths, shape (B,) - mask: Batch of temporal masks, shape (B, T) + mask: Batch of temporal masks for each channel, shape (B, C, T) Returns: Scalar loss. @@ -410,3 +416,161 @@ def forward( sdr = self.reduce(sdr) return -sdr + + +def calculate_mse_batch( + estimate: torch.Tensor, + target: torch.Tensor, + input_length: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Calculate MSE per channel. + + MSE = ||estimate - target||_2^2 / input_length + + Args: + estimate: estimated signal, shape (B, C, T) or (B, C, D, T) + target: target signal, shape (B, C, T) or (B, C, D, T) + input_length: Optional, length of valid samples, shape (B,) + mask: Optional, temporal mask, same shape as signals + + Returns: + MSE for each channel, shape (B, C) + """ + assert ( + estimate.shape == target.shape + ), f'Estimate shape ({estimate.shape}) not matching target shape ({target.shape})' + + if input_length is not None: + if mask is not None: + raise RuntimeError( + 'Argument `input_length` is mutually exclusive with `mask`. Both cannot be used at the same time.' + ) + + # Construct a binary mask + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) + + # error + err = estimate - target + + # dimensions for averaging + if estimate.ndim == 3: + # average across time + dim = -1 + elif estimate.ndim == 4: + # average across time and features + dim = (-2, -1) + else: + raise RuntimeError(f'Unexpected dimension of the input: {estimate.shape}') + + # calculate masked mean + mse = calculate_mean(torch.abs(err) ** 2, mask=mask, dim=dim) + + return mse + + +class MSELoss(Loss, Typing): + """ + Computes MSE loss with weighted average across channels. + + Args: + weight: weight for loss of each output channel, used for averaging the loss across channels. Defaults to `None` (averaging). + reduction: batch reduction. Defaults to `mean` over the batch. + ndim: Number of dimensions for the input signal + """ + + def __init__( + self, weight: Optional[List[float]] = None, reduction: str = 'mean', ndim: int = 3, + ): + super().__init__() + + # weight buffer + if weight is not None: + if any([w <= 0 for w in weight]): + raise ValueError(f'Weight must be positive! 
Current value: {weight}') + elif not np.isclose(sum(weight), 1, atol=1e-6): + raise ValueError(f'Weight should add to one, current weight: {weight}') + weight = torch.tensor(weight).reshape(1, -1) + logging.info(f'Channel weight set to %s', weight) + self.register_buffer('weight', weight) + self.weight: Optional[Tensor] + + # Batch reduction + self.reduction = reduction + if reduction == 'mean': + self.reduce = torch.mean + else: + raise ValueError(f'Unexpected reduction mode {reduction}.') + + # Input dimension + self.ndim = ndim + + if self.ndim == 3: + # Time-domain input + self.signal_shape = ('B', 'C', 'T') + elif self.ndim == 4: + # Spectral-domain input + self.signal_shape = ('B', 'C', 'D', 'T') + else: + raise ValueError(f'Unexpected input dimension: {self.ndim}') + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tweight: %s', self.weight) + logging.debug('\treduction: %s', self.reduction) + logging.debug('\tndim: %s', self.ndim) + logging.debug('\tsignal_shape: %s', self.signal_shape) + + @property + def input_types(self): + """Input types definitions for SDRLoss. + """ + return { + "estimate": NeuralType(self.signal_shape, VoidType()), + "target": NeuralType(self.signal_shape, VoidType()), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "mask": NeuralType(self.signal_shape, MaskType(), optional=True), + } + + @property + def output_types(self): + """Output types definitions for SDRLoss. + loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward( + self, + estimate: torch.Tensor, + target: torch.Tensor, + input_length: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """For input batch of multi-channel signals, calculate SDR between estimate and target for each channel, + perform averaging across channels (weighting optional), and apply reduction across the batch. + + Args: + estimate: Estimate of the target signal + target: Target signal + input_length: Length of each example in the batch + mask: Mask for each signal + + Returns: + Scalar loss. 
+ """ + mse = calculate_mse_batch(estimate=estimate, target=target, input_length=input_length, mask=mask,) + + # channel averaging + if self.weight is None: + mse = torch.mean(mse, dim=1) + else: + # weighting across channels + mse = mse * self.weight + mse = torch.sum(mse, dim=1) + + # reduction + mse = self.reduce(mse) + + return mse diff --git a/nemo/collections/asr/metrics/audio.py b/nemo/collections/asr/metrics/audio.py index 5e8c2915e3fa..db63ac19c098 100644 --- a/nemo/collections/asr/metrics/audio.py +++ b/nemo/collections/asr/metrics/audio.py @@ -57,6 +57,7 @@ class AudioMetricWrapper(Metric): """ full_state_update: bool = False + num_examples: torch.Tensor def __init__( self, metric: Metric, channel: Optional[int] = None, metric_using_batch_averaging: Optional[bool] = None @@ -74,6 +75,7 @@ def __init__( self._metric = metric self._channel = channel + self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum') logging.debug('Setup metric %s, channel %s', metric, str(channel)) def _select_channel(self, preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -144,6 +146,8 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, input_length: Option for b_preds, b_target in self._trim_inputs(preds=preds, target=target, input_length=input_length): self._metric.update(preds=b_preds, target=b_target) + self.num_examples += preds.size(0) + def compute(self) -> torch.Tensor: """Compute the underlying metric. """ @@ -179,6 +183,9 @@ def forward( def reset(self) -> None: """Reset the underlying metric. """ + # reset the internal states + super().reset() + # reset the underlying metric self._metric.reset() def __repr__(self) -> str: diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 019c57f9c4e3..23c759afc80d 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -23,7 +23,11 @@ from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel -from nemo.collections.asr.models.enhancement_models import EncMaskDecAudioToAudioModel +from nemo.collections.asr.models.enhancement_models import ( + EncMaskDecAudioToAudioModel, + PredictiveAudioToAudioModel, + ScoreBasedGenerativeAudioToAudioModel, +) from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel from nemo.collections.asr.models.k2_sequence_models import ( diff --git a/nemo/collections/asr/models/audio_to_audio_model.py b/nemo/collections/asr/models/audio_to_audio_model.py index 49364843e8b8..094dbc38b72a 100644 --- a/nemo/collections/asr/models/audio_to_audio_model.py +++ b/nemo/collections/asr/models/audio_to_audio_model.py @@ -12,15 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json +import os +import tempfile from abc import ABC, abstractmethod -from typing import List, Union +from typing import Dict, List, Optional, Union import hydra +import librosa +import soundfile as sf import torch from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer +from tqdm import tqdm +from nemo.collections.asr.data import audio_to_audio_dataset +from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset +from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config from nemo.collections.asr.metrics.audio import AudioMetricWrapper +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.core.classes import ModelPT from nemo.utils import logging, model_utils @@ -158,23 +169,384 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'test') - @abstractmethod + @torch.no_grad() def process( - self, paths2audio_files: List[str], output_dir: str, batch_size: int = 4 - ) -> List[Union[str, List[str]]]: + self, + paths2audio_files: List[str], + output_dir: str, + batch_size: int = 1, + num_workers: Optional[int] = None, + input_channel_selector: Optional[ChannelSelectorType] = None, + ) -> List[str]: + """ + Process audio files provided in paths2audio_files. + Processed signals will be saved in output_dir. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + output_dir: + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + num_workers: Number of workers for the dataloader + input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. 
+ + Returns: + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + # Output + paths2processed_files = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + + try: + # Switch model to evaluation mode + self.eval() + # Freeze weights + self.freeze() + + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + + # Processing + with tempfile.TemporaryDirectory() as tmpdir: + # Save temporary manifest + temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') + with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: + for audio_file in paths2audio_files: + entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} + fp.write(json.dumps(entry) + '\n') + + config = { + 'manifest_filepath': temporary_manifest_filepath, + 'input_key': 'input_filepath', + 'input_channel_selector': input_channel_selector, + 'batch_size': min(batch_size, len(paths2audio_files)), + 'num_workers': num_workers, + } + + # Create output dir if necessary + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + # DataLoader for the input files + temporary_dataloader = self._setup_process_dataloader(config) + + # Indexing of the original files, used to form the output file name + file_idx = 0 + + # Process batches + for test_batch in tqdm(temporary_dataloader, desc="Processing"): + input_signal = test_batch[0] + input_length = test_batch[1] + + # Expand channel dimension, if necessary + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = input_signal.unsqueeze(1) + + processed_batch, _ = self.forward( + input_signal=input_signal.to(device), input_length=input_length.to(device) + ) + + for example_idx in range(processed_batch.size(0)): + # This assumes the data loader is not shuffling files + file_name = os.path.basename(paths2audio_files[file_idx]) + # Prepare output file + output_file = os.path.join(output_dir, f'processed_{file_name}') + # Crop the output signal to the actual length + output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() + # Write audio + sf.write(output_file, output_signal.T, self.sample_rate, 'float') + # Update the file counter + file_idx += 1 + # Save processed file + paths2processed_files.append(output_file) + + del test_batch + del processed_batch + + finally: + # set mode back to its original value + self.train(mode=mode) + if mode is True: + self.unfreeze() + logging.set_verbosity(logging_level) + + return paths2processed_files + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + + if config.get("use_lhotse", False): + return get_lhotse_dataloader_from_config( + config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset() + ) + + is_concat = config.get('is_concat', False) + if is_concat: + raise NotImplementedError('Concat not implemented') + + # TODO: Consider moving `inject` from `audio_to_text_dataset` to a utility module? 
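        # For illustration only (keys and values here are assumed, not prescribed), a config
        # reaching this point typically looks like:
        #   {'manifest_filepath': 'train_manifest.json', 'input_key': 'input_filepath',
        #    'target_key': 'target_filepath', 'batch_size': 8, 'shuffle': True, 'num_workers': 4}
        # The model-level `sample_rate` is then copied into it by the `inject` call below.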
+ # Automatically inject args from model config to dataloader config + inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') + + # Instantiate tarred dataset loader or normal dataset loader + if config.get('is_tarred', False): + raise NotImplementedError('Tarred datasets not supported') + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config) + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + # support datasets that are lists of entries + collate_fn = dataset.datasets[0].collate_fn + else: + # support datasets that are lists of lists + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of a training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + raise NotImplementedError('Tarred datasets not supported') + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of a validation dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of a test dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """Prepare a dataloader for processing files. 
+ + Args: + config: A python dictionary which contains the following keys: + manifest_filepath: path to a manifest file + input_key: key with audio filepaths in the manifest + input_channel_selector: Optional, used to select a subset of channels from input audio files + batch_size: batch size for the dataloader + num_workers: number of workers for the dataloader + + Returns: + A pytorch DataLoader for the given manifest filepath. + """ + dl_config = { + 'manifest_filepath': config['manifest_filepath'], + 'sample_rate': self.sample_rate, + 'input_key': config['input_key'], + 'input_channel_selector': config.get('input_channel_selector', None), + 'target_key': None, + 'target_channel_selector': None, + 'batch_size': config['batch_size'], + 'shuffle': False, + 'num_workers': config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)), + 'pin_memory': True, + } + + temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_dataloader + + @staticmethod + def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor: + """Trim or pad the output to match the batch length. + + Args: + input: tensor with shape (B, C, T) + batch_length: int + + Returns: + Tensor with shape (B, C, T), where T matches the + batch length. + """ + input_length = input.size(-1) + pad_length = batch_length - input_length + pad = (0, pad_length) + # pad with zeros or crop + return torch.nn.functional.pad(input, pad, 'constant', 0) + + @torch.no_grad() + def process( + self, + paths2audio_files: List[str], + output_dir: str, + batch_size: int = 1, + num_workers: Optional[int] = None, + input_channel_selector: Optional[ChannelSelectorType] = None, + ) -> List[str]: """ Takes paths to audio files and returns a list of paths to processed audios. Args: paths2audio_files: paths to audio files to be processed - output_dir: directory to save processed files - batch_size: batch size for inference + output_dir: directory to save the processed files + batch_size: (int) batch size to use during inference. + num_workers: Number of workers for the dataloader + input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. + If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Returns: Paths to processed audio signals. 
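        Example (an illustrative sketch; the checkpoint name and file paths are assumed):

            model = EncMaskDecAudioToAudioModel.restore_from('enhancement_model.nemo')
            processed_paths = model.process(
                paths2audio_files=['noisy_0.wav', 'noisy_1.wav'],
                output_dir='enhanced/',
                batch_size=2,
            )
            # each file is written as 'enhanced/processed_<original file name>' and the
            # returned list contains the output paths in the same order as the inputs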
""" - pass + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + # Output + paths2processed_files = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + + try: + # Switch model to evaluation mode + self.eval() + # Freeze weights + self.freeze() + + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + + # Processing + with tempfile.TemporaryDirectory() as tmpdir: + # Save temporary manifest + temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') + with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: + for audio_file in paths2audio_files: + entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} + fp.write(json.dumps(entry) + '\n') + + config = { + 'manifest_filepath': temporary_manifest_filepath, + 'input_key': 'input_filepath', + 'input_channel_selector': input_channel_selector, + 'batch_size': min(batch_size, len(paths2audio_files)), + 'num_workers': num_workers, + } + + # Create output dir if necessary + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + # DataLoader for the input files + temporary_dataloader = self._setup_process_dataloader(config) + + # Indexing of the original files, used to form the output file name + file_idx = 0 + + # Process batches + for test_batch in tqdm(temporary_dataloader, desc="Processing"): + input_signal = test_batch[0] + input_length = test_batch[1] + + # Expand channel dimension, if necessary + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = input_signal.unsqueeze(1) + + processed_batch, _ = self.forward( + input_signal=input_signal.to(device), input_length=input_length.to(device) + ) + + for example_idx in range(processed_batch.size(0)): + # This assumes the data loader is not shuffling files + file_name = os.path.basename(paths2audio_files[file_idx]) + # Prepare output file + output_file = os.path.join(output_dir, f'processed_{file_name}') + # Crop the output signal to the actual length + output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() + # Write audio + sf.write(output_file, output_signal.T, self.sample_rate, 'float') + # Update the file counter + file_idx += 1 + # Save processed file + paths2processed_files.append(output_file) + + del test_batch + del processed_batch + + finally: + # set mode back to its original value + self.train(mode=mode) + if mode is True: + self.unfreeze() + logging.set_verbosity(logging_level) + + return paths2processed_files @classmethod def list_available_models(cls) -> 'List[PretrainedModelInfo]': diff --git a/nemo/collections/asr/models/enhancement_models.py b/nemo/collections/asr/models/enhancement_models.py index b80c357364aa..b765ae0fddad 100644 --- a/nemo/collections/asr/models/enhancement_models.py +++ b/nemo/collections/asr/models/enhancement_models.py @@ -16,6 +16,8 @@ import tempfile from typing import Dict, List, Optional, Union +import einops +import hydra import librosa import soundfile as sf import torch @@ -23,17 +25,13 @@ from pytorch_lightning import Trainer from tqdm import tqdm -from nemo.collections.asr.data import audio_to_audio_dataset -from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset -from nemo.collections.asr.data.audio_to_text_dataset import 
inject_dataloader_value_from_model_config + from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.core.classes.common import PretrainedModelInfo, typecheck -from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType from nemo.utils import logging -__all__ = ['EncMaskDecAudioToAudioModel'] +__all__ = ['EncMaskDecAudioToAudioModel', 'ScoreBasedGenerativeAudioToAudioModel', 'PredictiveAudioToAudioModel'] class EncMaskDecAudioToAudioModel(AudioToAudioModel): @@ -69,10 +67,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): logging.debug('Mixture consistency not used') self.mixture_consistency = None - # Future enhancement: - # If subclasses need to modify the config before calling super() - # Check ASRBPE* classes do with their mixin - # Setup augmentation if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None: logging.debug('Using channel augmentation') @@ -84,254 +78,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup optional Optimization flags self.setup_optimization_flags() - @torch.no_grad() - def process( - self, - paths2audio_files: List[str], - output_dir: str, - batch_size: int = 1, - num_workers: Optional[int] = None, - input_channel_selector: Optional[ChannelSelectorType] = None, - ) -> List[str]: - """ - Process audio files provided in paths2audio_files. - Processed signals will be saved in output_dir. - - Args: - paths2audio_files: (a list) of paths to audio files. \ - Recommended length per file is between 5 and 25 seconds. \ - But it is possible to pass a few hours long file if enough GPU memory is available. - output_dir: - batch_size: (int) batch size to use during inference. - Bigger will result in better throughput performance but would use more memory. - num_workers: Number of workers for the dataloader - input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. 
- - Returns: - """ - if paths2audio_files is None or len(paths2audio_files) == 0: - return {} - - if num_workers is None: - num_workers = min(batch_size, os.cpu_count() - 1) - - # Output - paths2processed_files = [] - - # Model's mode and device - mode = self.training - device = next(self.parameters()).device - - try: - # Switch model to evaluation mode - self.eval() - # Freeze weights - self.freeze() - - logging_level = logging.get_verbosity() - logging.set_verbosity(logging.WARNING) - - # Processing - with tempfile.TemporaryDirectory() as tmpdir: - # Save temporary manifest - temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') - with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: - for audio_file in paths2audio_files: - entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} - fp.write(json.dumps(entry) + '\n') - - config = { - 'manifest_filepath': temporary_manifest_filepath, - 'input_key': 'input_filepath', - 'input_channel_selector': input_channel_selector, - 'batch_size': min(batch_size, len(paths2audio_files)), - 'num_workers': num_workers, - } - - # Create output dir if necessary - if not os.path.isdir(output_dir): - os.makedirs(output_dir) - - # DataLoader for the input files - temporary_dataloader = self._setup_process_dataloader(config) - - # Indexing of the original files, used to form the output file name - file_idx = 0 - - # Process batches - for test_batch in tqdm(temporary_dataloader, desc="Processing"): - input_signal = test_batch[0] - input_length = test_batch[1] - - # Expand channel dimension, if necessary - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = input_signal.unsqueeze(1) - - processed_batch, _ = self.forward( - input_signal=input_signal.to(device), input_length=input_length.to(device) - ) - - for example_idx in range(processed_batch.size(0)): - # This assumes the data loader is not shuffling files - file_name = os.path.basename(paths2audio_files[file_idx]) - # Prepare output file - output_file = os.path.join(output_dir, f'processed_{file_name}') - # Crop the output signal to the actual length - output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() - # Write audio - sf.write(output_file, output_signal.T, self.sample_rate, 'float') - # Update the file counter - file_idx += 1 - # Save processed file - paths2processed_files.append(output_file) - - del test_batch - del processed_batch - - finally: - # set mode back to its original value - self.train(mode=mode) - if mode is True: - self.unfreeze() - logging.set_verbosity(logging_level) - - return paths2processed_files - - def _setup_dataloader_from_config(self, config: Optional[Dict]): - - if config.get("use_lhotse", False): - return get_lhotse_dataloader_from_config( - config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset() - ) - - is_concat = config.get('is_concat', False) - if is_concat: - raise NotImplementedError('Concat not implemented') - - # TODO: Consider moving `inject` from `audio_to_text_dataset` to a utility module? 
- # Automatically inject args from model config to dataloader config - inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') - - # Instantiate tarred dataset loader or normal dataset loader - if config.get('is_tarred', False): - raise NotImplementedError('Tarred datasets not supported') - - if 'manifest_filepath' in config and config['manifest_filepath'] is None: - logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") - return None - - dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config) - - if hasattr(dataset, 'collate_fn'): - collate_fn = dataset.collate_fn - elif hasattr(dataset.datasets[0], 'collate_fn'): - # support datasets that are lists of entries - collate_fn = dataset.datasets[0].collate_fn - else: - # support datasets that are lists of lists - collate_fn = dataset.datasets[0].datasets[0].collate_fn - - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config['batch_size'], - collate_fn=collate_fn, - drop_last=config.get('drop_last', False), - shuffle=config['shuffle'], - num_workers=config.get('num_workers', 0), - pin_memory=config.get('pin_memory', False), - ) - - def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): - """ - Sets up the training data loader via a Dict-like object. - - Args: - train_data_config: A config that contains the information regarding construction - of a training dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` - """ - if 'shuffle' not in train_data_config: - train_data_config['shuffle'] = True - - # preserve config - self._update_dataset_config(dataset_name='train', config=train_data_config) - - self._train_dl = self._setup_dataloader_from_config(config=train_data_config) - - if 'is_tarred' in train_data_config and train_data_config['is_tarred']: - raise NotImplementedError('Tarred datasets not supported') - - def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): - """ - Sets up the validation data loader via a Dict-like object. - - Args: - val_data_config: A config that contains the information regarding construction - of a validation dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` - """ - if 'shuffle' not in val_data_config: - val_data_config['shuffle'] = False - - # preserve config - self._update_dataset_config(dataset_name='validation', config=val_data_config) - - self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) - - def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): - """ - Sets up the test data loader via a Dict-like object. - - Args: - test_data_config: A config that contains the information regarding construction - of a test dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` - """ - if 'shuffle' not in test_data_config: - test_data_config['shuffle'] = False - - # preserve config - self._update_dataset_config(dataset_name='test', config=test_data_config) - - self._test_dl = self._setup_dataloader_from_config(config=test_data_config) - - def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': - """Prepare a dataloader for processing files. 
- - Args: - config: A python dictionary which contains the following keys: - manifest_filepath: path to a manifest file - input_key: key with audio filepaths in the manifest - input_channel_selector: Optional, used to select a subset of channels from input audio files - batch_size: batch size for the dataloader - num_workers: number of workers for the dataloader - - Returns: - A pytorch DataLoader for the given manifest filepath. - """ - dl_config = { - 'manifest_filepath': config['manifest_filepath'], - 'sample_rate': self.sample_rate, - 'input_key': config['input_key'], - 'input_channel_selector': config.get('input_channel_selector', None), - 'target_key': None, - 'target_channel_selector': None, - 'batch_size': config['batch_size'], - 'shuffle': False, - 'num_workers': config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)), - 'pin_memory': True, - } - - temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) - return temporary_dataloader - @property def input_types(self) -> Dict[str, NeuralType]: return { @@ -350,23 +96,6 @@ def output_types(self) -> Dict[str, NeuralType]: "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), } - def match_batch_length(self, input: torch.Tensor, batch_length: int): - """Trim or pad the output to match the batch length. - - Args: - input: tensor with shape (B, C, T) - batch_length: int - - Returns: - Tensor with shape (B, C, T), where T matches the - batch length. - """ - input_length = input.size(-1) - pad_length = batch_length - input_length - pad = (0, pad_length) - # pad with zeros or crop - return torch.nn.functional.pad(input, pad, 'constant', 0) - @typecheck() def forward(self, input_signal, input_length=None): """ @@ -380,6 +109,7 @@ def forward(self, input_signal, input_length=None): sequences. Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. 
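            Example (an illustrative sketch; the model instance `model` and all shapes are assumed):

                import torch

                input_signal = torch.randn(2, 1, 16000)      # (B, C, T)
                input_length = torch.tensor([16000, 12000])  # valid samples per example
                output, output_length = model(input_signal=input_signal, input_length=input_length)
                # `output` has shape (B, C_out, T) with the same T as the input, since the
                # estimate is trimmed or zero-padded to the input batch length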
""" batch_length = input_signal.size(-1) @@ -414,12 +144,11 @@ def training_step(self, batch, batch_idx): else: input_signal, input_length, target_signal, _ = batch - # Expand channel dimension, if necessary # For consistency, the model uses multi-channel format, even if the channel dimension is 1 if input_signal.ndim == 2: - input_signal = input_signal.unsqueeze(1) + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') if target_signal.ndim == 2: - target_signal = target_signal.unsqueeze(1) + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') # Apply channel augmentation if self.training and self.channel_augmentation is not None: @@ -449,12 +178,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = else: input_signal, input_length, target_signal, _ = batch - # Expand channel dimension, if necessary # For consistency, the model uses multi-channel format, even if the channel dimension is 1 if input_signal.ndim == 2: - input_signal = input_signal.unsqueeze(1) + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') if target_signal.ndim == 2: - target_signal = target_signal.unsqueeze(1) + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) @@ -485,3 +213,406 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]: results = [] return results + + +class PredictiveAudioToAudioModel(AudioToAudioModel): + """This models aims to directly estimate the coefficients + in the encoded domain by applying a neural model. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate + + # Setup processing modules + self.encoder = self.from_config_dict(self._cfg.encoder) + self.decoder = self.from_config_dict(self._cfg.decoder) + + # Neural estimator + self.estimator = self.from_config_dict(self._cfg.estimator) + + # Normalization + self.normalize_input = self._cfg.get('normalize_input', False) + + # Term added to the denominator to improve numerical stability + self.eps = self._cfg.get('eps', 1e-8) + + # Setup optional Optimization flags + self.setup_optimization_flags() + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tnormalize_input: %s', self.normalize_input) + logging.debug('\teps: %s', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return { + "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @typecheck() + def forward(self, input_signal, input_length=None): + """Forward pass of the model. + + Args: + input_signal: time-domain signal + input_length: valid length of each example in the batch + + Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. 
+ """ + batch_length = input_signal.size(-1) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + + # Encoder + encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + + # Backbone + estimated, estimated_length = self.estimator(input=encoded, input_length=encoded_length) + + # Decoder + output, output_length = self.decoder(input=estimated, input_length=estimated_length) + + if self.normalize_input: + # rescale to the original scale + output = output * norm_scale + + # Trim or pad the estimated signal to match input length + output = self.match_batch_length(input=output, batch_length=batch_length) + return output, output_length + + # PTL-specific methods + def training_step(self, batch, batch_idx): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Estimate the signal + output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) + + # Calculate the loss + loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) + + # Logs + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Estimate the signal + output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) + + # Prepare output + loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) + + # Update metrics + if hasattr(self, 'metrics') and tag in self.metrics: + # Update metrics for this (tag, dataloader_idx) + for name, metric in self.metrics[tag][dataloader_idx].items(): + metric.update(preds=output_signal, target=target_signal, input_length=input_length) + + # Log global step + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return {f'{tag}_loss': loss} + + +class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel): + """This models is using a score-based diffusion process to generate + an encoded representation of the enhanced signal. 
+ + The model consists of the following blocks: + - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform) + - estimator: neural model, estimates a score for the diffusion process + - sde: stochastic differential equation (SDE) defining the forward and reverse diffusion process + - sampler: sampler for the reverse diffusion process, estimates coefficients of the target signal + - decoder: transforms sampler output into the time domain (synthesis transform) + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate + + # Setup processing modules + self.encoder = self.from_config_dict(self._cfg.encoder) + self.decoder = self.from_config_dict(self._cfg.decoder) + + # Neural score estimator + self.estimator = self.from_config_dict(self._cfg.estimator) + + # SDE + self.sde = self.from_config_dict(self._cfg.sde) + + # Sampler + if 'sde' in self._cfg.sampler: + raise ValueError('SDE should be defined in the model config, not in the sampler config') + if 'score_estimator' in self._cfg.sampler: + raise ValueError('Score estimator should be defined in the model config, not in the sampler config') + + self.sampler = hydra.utils.instantiate(self._cfg.sampler, sde=self.sde, score_estimator=self.estimator) + + # Normalization + self.normalize_input = self._cfg.get('normalize_input', False) + + # Metric evaluation + self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics') + + if self.max_utts_evaluation_metrics is not None: + logging.warning( + 'Metrics will be evaluated on first %d examples of the evaluation datasets.', + self.max_utts_evaluation_metrics, + ) + + # Term added to the denominator to improve numerical stability + self.eps = self._cfg.get('eps', 1e-8) + + # Setup optional Optimization flags + self.setup_optimization_flags() + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tnormalize_input: %s', self.normalize_input) + logging.debug('\teps: %s', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return { + "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @typecheck() + @torch.inference_mode() + def forward(self, input_signal, input_length=None): + """Forward pass of the model. + + Forward pass of the model aplies the following steps: + - encoder to obtain the encoded representation of the input signal + - sampler to generate the estimated coefficients of the target signal + - decoder to transform the sampler output into the time domain + + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + + Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. 
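            Example (an illustrative sketch; the checkpoint name, input shapes and values are assumed):

                import torch

                model = ScoreBasedGenerativeAudioToAudioModel.restore_from('score_based_enhancement.nemo')
                model.eval()
                noisy = torch.randn(1, 1, 16000)   # (B, C, T)
                lengths = torch.tensor([16000])
                enhanced, enhanced_len = model(input_signal=noisy, input_length=lengths)
                # forward already runs under `torch.inference_mode()`: the encoded input is used both
                # as the prior mean and as the score condition for the reverse-diffusion sampler, and
                # the decoded output is trimmed or padded to the input length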
+ """ + batch_length = input_signal.size(-1) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + + # Encoder + encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + + # Sampler + generated, generated_length = self.sampler( + prior_mean=encoded, score_condition=encoded, state_length=encoded_length + ) + + # Decoder + output, output_length = self.decoder(input=generated, input_length=generated_length) + + if self.normalize_input: + # rescale to the original scale + output = output * norm_scale + + # Trim or pad the estimated signal to match input length + output = self.match_batch_length(input=output, batch_length=batch_length) + return output, output_length + + @typecheck( + input_types={ + "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_length": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"loss": NeuralType(None, LossType()),}, + ) + def _step(self, target_signal, input_signal, input_length=None): + """Randomly generate a time step for each example in the batch, estimate + the score and calculate the loss value. + + Note that this step does not include sampler. + """ + batch_size = target_signal.size(0) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + # scale the target signal + target_signal = target_signal / (norm_scale + self.eps) + + # Apply encoder to both target and the input + input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length) + target_enc, _ = self.encoder(input=target_signal, input_length=input_length) + + # Generate random time steps + sde_time = self.sde.generate_time(size=batch_size, device=input_enc.device) + + # Get the mean and the variance of the perturbation kernel + pk_mean, pk_std = self.sde.perturb_kernel_params(state=target_enc, prior_mean=input_enc, time=sde_time) + + # Generate a random sample from a standard normal distribution + z_norm = torch.randn_like(input_enc) + + # Prepare perturbed data + perturbed_enc = pk_mean + pk_std * z_norm + + # Score is conditioned on the perturbed data and the input + estimator_input = torch.cat([perturbed_enc, input_enc], dim=-3) + + # Estimate the score using the neural estimator + # SDE time is used to inform the estimator about the current time step + # Note: + # - some implementations use `score = -self._raw_dnn_output(x, t, y)` + # - this seems to be unimportant, and is an artifact of transfering code from the original Song's repo + score_est, score_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=sde_time) + + # Score loss weighting as in Section 4.2 in http://arxiv.org/abs/1907.05600 + score_est = score_est * pk_std + score_ref = -z_norm + + # Score matching loss on the normalized scores + loss = self.loss(estimate=score_est, target=score_ref, input_length=score_len) + + return loss + + # PTL-specific methods + def training_step(self, batch, batch_idx): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, 
_ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate the loss + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + + # Logs + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate loss + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + + # Update metrics + update_metrics = False + if self.max_utts_evaluation_metrics is None: + # Always update if max is not configured + update_metrics = True + # Number of examples to process + num_examples = input_signal.size(0) # batch size + else: + # Check how many examples have been used for metric calculation + first_metric_name = next(iter(self.metrics[tag][dataloader_idx])) + num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples + # Update metrics if some examples were not processed + update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics + # Number of examples to process + num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0)) + + if update_metrics: + # Generate output signal + output_signal, _ = self.forward( + input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples] + ) + + # Update metrics + if hasattr(self, 'metrics') and tag in self.metrics: + # Update metrics for this (tag, dataloader_idx) + for name, metric in self.metrics[tag][dataloader_idx].items(): + metric.update( + preds=output_signal, + target=target_signal[:num_examples, ...], + input_length=input_length[:num_examples], + ) + + # Log global step + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return {f'{tag}_loss': loss} diff --git a/nemo/collections/asr/modules/audio_modules.py b/nemo/collections/asr/modules/audio_modules.py index 82cfbefeb8d9..67a923099cde 100644 --- a/nemo/collections/asr/modules/audio_modules.py +++ b/nemo/collections/asr/modules/audio_modules.py @@ -17,7 +17,7 @@ import numpy as np import torch -from nemo.collections.asr.losses.audio_losses import temporal_mean +from nemo.collections.asr.losses.audio_losses import calculate_mean from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like from nemo.collections.asr.parts.submodules.multichannel_modules import ( @@ -39,6 +39,7 @@ 'MaskReferenceChannel', 'MaskBasedBeamformer', 'MaskBasedDereverbWPE', + 
'MixtureConsistencyProjection', ] @@ -158,7 +159,7 @@ def get_mean_time_channel(input: torch.Tensor, input_length: Optional[torch.Tens mean = torch.mean(input, dim=(-1, -3), keepdim=True) else: # temporal mean - mean = temporal_mean(input, input_length, keepdim=True) + mean = calculate_mean(input, input_length, dim=-1, keepdim=True) # channel mean mean = torch.mean(mean, dim=-3, keepdim=True) @@ -186,7 +187,7 @@ def get_mean_std_time_channel( mean = cls.get_mean_time_channel(input, input_length) std = (input - mean).pow(2) # temporal mean - std = temporal_mean(std, input_length, keepdim=True) + std = calculate_mean(std, input_length, dim=-1, keepdim=True) # channel mean std = torch.mean(std, dim=-3, keepdim=True) # final value diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index cc5312403255..643bc4a69d69 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -709,9 +709,11 @@ class AudioToSpectrogram(NeuralModule): hop_length: length of hops/shifts of the sliding window power: exponent for magnitude spectrogram. Default `None` will return a complex-valued spectrogram + magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power. + scale: Positive scaling of the spectrogram. """ - def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = None): + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): if not HAVE_TORCHAUDIO: logging.error('Could not import torchaudio. Some features might not work.') @@ -726,12 +728,26 @@ def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = No raise ValueError(f'fft_length = {fft_length} must be divisible by 2') self.stft = torchaudio.transforms.Spectrogram( - n_fft=fft_length, hop_length=hop_length, power=power, pad_mode='constant' + n_fft=fft_length, hop_length=hop_length, power=None, pad_mode='constant' ) # number of subbands self.F = fft_length // 2 + 1 + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + @property def num_subbands(self) -> int: return self.F @@ -776,6 +792,14 @@ def forward( with torch.cuda.amp.autocast(enabled=False): output = self.stft(input.float()) + if self.magnitude_power != 1: + # apply power on the magnitude + output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle()) + + if self.scale != 1: + # apply scaling of the coefficients + output = self.scale * output + if input_length is not None: # Mask padded frames output_length = self.get_output_length(input_length=input_length) @@ -810,11 +834,11 @@ class SpectrogramToAudio(NeuralModule): Args: fft_length: length of FFT hop_length: length of hops/shifts of the sliding window - power: exponent for magnitude spectrogram. Default `None` will - return a complex-valued spectrogram + magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power). 
+ scale: Spectrogram will be scaled with 1/scale before the inverse transform. """ - def __init__(self, fft_length: int, hop_length: int): + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): if not HAVE_TORCHAUDIO: logging.error('Could not import torchaudio. Some features might not work.') @@ -834,6 +858,20 @@ def __init__(self, fft_length: int, hop_length: int): self.F = fft_length // 2 + 1 + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + @property def num_subbands(self) -> int: return self.F @@ -875,7 +913,16 @@ def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = No # iSTFT output (B, C, T) with torch.cuda.amp.autocast(enabled=False): - output = self.istft(input.cfloat()) + output = input.cfloat() + + if self.scale != 1: + # apply 1/scale on the coefficients + output = output / self.scale + + if self.magnitude_power != 1: + # apply 1/power on the magnitude + output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle()) + output = self.istft(output) if input_length is not None: # Mask padded samples diff --git a/nemo/collections/asr/parts/submodules/diffusion.py b/nemo/collections/asr/parts/submodules/diffusion.py new file mode 100644 index 000000000000..db3d30f49701 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/diffusion.py @@ -0,0 +1,1310 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC, abstractmethod +from typing import Dict, Optional, Sequence, Tuple, Type + +import einops +import einops.layers.torch +import numpy as np +import torch +import torch.nn.functional as F + +from nemo.collections.common.parts.utils import activation_registry +from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType +from nemo.utils import logging + +__all__ = [ + 'OrnsteinUhlenbeckVarianceExplodingSDE', + 'SpectrogramNoiseConditionalScoreNetworkPlusPlus', + 'NoiseConditionalScoreNetworkPlusPlus', + 'PredictorCorrectorSampler', +] + + +class StochasticDifferentialEquation(NeuralModule, ABC): + """Base class for stochastic differential equations. 
+ """ + + def __init__(self, time_min: float, time_max: float, num_steps: int): + super().__init__() + + # min and max time + if time_min <= 0: + raise ValueError(f'time_min should be positive, current value {time_min}') + + if time_max <= time_min: + raise ValueError(f'time_max should be larger than time_min, current max {time_max} and min {time_min}') + + self.time_min = time_min + self.time_max = time_max + + # number of steps + if num_steps <= 0: + raise ValueError(f'num_steps needs to be positive: current value {num_steps}') + + self.num_steps = num_steps + + @property + def dt(self) -> float: + """Time step for this SDE. + This denotes the step size between `0` and `self.time_max` when using `self.num_steps`. + """ + return self.time_max / self.num_steps + + @property + def time_delta(self) -> float: + """Time range for this SDE. + """ + return self.time_max - self.time_min + + def generate_time(self, size: int, device: torch.device) -> torch.Tensor: + """Generate random time steps in the valid range. + + Time steps are generated between `self.time_min` and `self.time_max`. + + Args: + size: number of samples + device: device to use + + Returns: + A tensor of floats with shape (size,) + """ + time = torch.rand(size, device=device) * self.time_delta + self.time_min + return time + + @abstractmethod + def coefficients(self, state: torch.Tensor, time: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + state: tensor of shape (B, C, D, T) + time: tensor of shape (B,) + + Returns: + Tuple with drift and diffusion coefficients. + """ + pass + + @typecheck( + input_types={"prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + output_types={"sample": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + ) + @abstractmethod + def prior_sampling(self, prior_mean: torch.Tensor) -> torch.Tensor: + """Generate a sample from the prior distribution p_T. + + Args: + prior_mean: Mean of the prior distribution + + Returns: + A sample from the prior distribution. + """ + pass + + def discretize( + self, *, state: torch.Tensor, time: torch.Tensor, state_length: Optional[torch.Tensor] = None, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Assume we have the following SDE: + + dx = drift(x, t) * dt + diffusion(x, t) * dwt + + where `wt` is the standard Wiener process. + + We assume the following discretization: + + new_state = current_state + total_drift + total_diffusion * z_norm + + where `z_norm` is sampled from normal distribution with zero mean and unit variance. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + state_length: length of the valid time steps for each example in the batch, shape (B,) + **kwargs: other parameters + + Returns: + Drift and diffusion. + """ + # Get coefficients + drift_coefficient, diffusion_coefficient = self.coefficients( + state=state, time=time, state_length=state_length, **kwargs + ) + + # Discretized drift + drift = drift_coefficient * self.dt + + # Note: + # Scale with sqrt(dt) because z_norm is sampled from a normal distribution with zero mean and + # unit variance and dwt is normally distributed with zero mean and variance dt + diffusion = diffusion_coefficient * np.sqrt(self.dt) + + return drift, diffusion + + @abstractmethod + def copy(self): + """Create a copy of this SDE. 
+ """ + pass + + def __repr__(self): + desc = f'{self.__class__.__name__}(time_min={self.time_min}, time_max={self.time_max}, num_steps={self.num_steps})' + desc += f'\n\tdt: {self.dt}' + desc += f'\n\ttime_delta: {self.time_delta}' + return desc + + +class OrnsteinUhlenbeckVarianceExplodingSDE(StochasticDifferentialEquation): + """This class implements the Ornstein-Uhlenbeck SDE with variance exploding noise schedule. + + The SDE is given by: + + dx = theta * (y - x) dt + g(t) dw + + where `theta` is the stiffness parameter and `g(t)` is the diffusion coefficient: + + g(t) = std_min * (std_max/std_min)^t * sqrt(2 * log(std_max/std_min)) + + References: + Richter et al., Speech Enhancement and Dereverberation with Diffusion-based Generative Models, Tr. ASLP 2023 + """ + + def __init__( + self, + stiffness: float, + std_min: float, + std_max: float, + num_steps: int = 100, + time_min: float = 3e-2, + time_max: float = 1.0, + eps: float = 1e-8, + ): + super().__init__(time_min=time_min, time_max=time_max, num_steps=num_steps) + + # Small regularization + if eps <= 0: + raise ValueError(f'eps should be positive, current value {eps}') + self.eps = eps + + # stifness + self.stiffness = stiffness + + # noise schedule + if std_min <= 0: + raise ValueError(f'std_min should be positive, current value {std_min}') + + if std_max <= std_min: + raise ValueError(f'std_max should be larger than std_min, current max {std_max} and min {std_min}') + + self.std_min = std_min + self.std_max = std_max + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tstiffness: %s', self.stiffness) + logging.debug('\tstd_min: %s', self.std_min) + logging.debug('\tstd_max: %s', self.std_max) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\teps: %s', self.eps) + + @property + def std_ratio(self) -> float: + return self.std_max / (self.std_min + self.eps) + + @property + def log_std_ratio(self) -> float: + return np.log(self.std_ratio + self.eps) + + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), FloatType()), + }, + output_types={"mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()),}, + ) + def perturb_kernel_mean(self, state: torch.Tensor, prior_mean: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + """Return the mean of the perturbation kernel for this SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + prior_mean: mean of the prior distribution + time: current time of the process, shape (B,) + + Returns: + A tensor of shape (B, C, D, T) + """ + # exponential weighting + weight = torch.exp(-self.stiffness * time) + + # view as [B, C, D, T] + weight = weight.view(-1, 1, 1, 1) + + # closed-form mean + mean = weight * state + (1 - weight) * prior_mean + + return mean + + @typecheck( + input_types={"time": NeuralType(tuple('B'), FloatType()),}, + output_types={"std": NeuralType(tuple('B'), FloatType()),}, + ) + def perturb_kernel_std(self, time: torch.Tensor) -> torch.Tensor: + """Return the standard deviation of the perturbation kernel for this SDE. + + Note that the standard deviation depends on the time and the noise schedule, + which is parametrized using `self.stiffness`, `self.std_min` and `self.std_max`. 
+ + Args: + time: current time of the process, shape (B,) + + Returns: + A tensor of shape (B,) + """ + var = (self.std_min ** 2) * self.log_std_ratio + var *= torch.pow(self.std_ratio, 2 * time) - torch.exp(-2 * self.stiffness * time) + var /= self.stiffness + self.log_std_ratio + std = torch.sqrt(var) + return std + + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), FloatType()), + }, + output_types={ + "mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + "std": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + }, + ) + def perturb_kernel_params(self, state: torch.Tensor, prior_mean: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + """Return the mean and standard deviation of the perturbation kernel for this SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + prior_mean: mean of the prior distribution + time: current time of the process, shape (B,) + """ + assert torch.all(time <= self.time_max) + assert torch.all(time >= self.time_min) + + # compute the mean + mean = self.perturb_kernel_mean(state=state, prior_mean=prior_mean, time=time) + + # compute the standard deviation + std = self.perturb_kernel_std(time=time) + # view as [B, C, D, T] + std = std.view(-1, 1, 1, 1) + + return mean, std + + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), VoidType()), + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={ + "drift_coefficient": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + "diffusion_coefficient": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + }, + ) + def coefficients( + self, + state: torch.Tensor, + time: torch.Tensor, + prior_mean: torch.Tensor, + state_length: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute drift and diffusion coefficients for this SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + prior_mean: mean of the prior distribution + state_length: length of the valid time steps for each example in the batch + + Returns: + Drift and diffusion coefficients. + """ + # Drift coefficient + drift_coefficient = self.stiffness * (prior_mean - state) + + # Diffusion coefficient + diffusion_coefficient = self.std_min * torch.pow(self.std_ratio, time) * np.sqrt(2 * self.log_std_ratio) + # View in the same shape as the state + diffusion_coefficient = diffusion_coefficient.view(-1, *([1] * (state.dim() - 1))) + + if state_length is not None: + drift_coefficient = mask_sequence_tensor(drift_coefficient, state_length) + diffusion_coefficient = mask_sequence_tensor(diffusion_coefficient, state_length) + + return drift_coefficient, diffusion_coefficient + + def prior_sampling(self, prior_mean: torch.Tensor) -> torch.Tensor: + """Generate a sample from the prior distribution p_T. 
+ + Args: + prior_mean: Mean of the prior distribution + """ + # Final time step for all samples in the batch + time = self.time_max * torch.ones(prior_mean.shape[0], device=prior_mean.device) + + # Compute the std of the prior distribution + std = self.perturb_kernel_std(time=time) + + # view as [B, C, D, T] + std = std.view(-1, 1, 1, 1) + + # Generate a sample from a normal distribution centered at prior_mean + sample = prior_mean + torch.randn_like(prior_mean) * std + + return sample + + def copy(self): + return OrnsteinUhlenbeckVarianceExplodingSDE( + stiffness=self.stiffness, + std_min=self.std_min, + std_max=self.std_max, + num_steps=self.num_steps, + time_min=self.time_min, + time_max=self.time_max, + eps=self.eps, + ) + + def __repr__(self): + desc = f'{self.__class__.__name__}(stiffness={self.stiffness}, std_min={self.std_min}, std_max={self.std_max}, num_steps={self.num_steps}, time_min={self.time_min}, time_max={self.time_max}, eps={self.eps})' + desc += f'\n\tdt: {self.dt}' + desc += f'\n\ttime_delta: {self.time_delta}' + desc += f'\n\tstd_ratio: {self.std_ratio}' + desc += f'\n\tlog_std_ratio: {self.log_std_ratio}' + + return desc + + +class ReverseStochasticDifferentialEquation(StochasticDifferentialEquation): + def __init__(self, *, sde: Type[StochasticDifferentialEquation], score_estimator: Type[NeuralModule]): + """Use the forward SDE and a score estimator to define the reverse SDE. + + Args: + sde: forward SDE + score_estimator: neural score estimator + """ + super().__init__(time_min=sde.time_min, time_max=sde.time_max, num_steps=sde.num_steps) + self.score_estimator = score_estimator + self.forward_sde = sde + + logging.debug('Initialized %s', self.__class__.__name__) + + def coefficients( + self, + state: torch.Tensor, + time: torch.Tensor, + score_condition: Optional[torch.Tensor] = None, + state_length: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute drift and diffusion coefficients for the reverse SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + """ + raise NotImplementedError('Coefficients not necessary for the reverse SDE.') + + def prior_sampling(self, shape: torch.Size, device: torch.device) -> torch.Tensor: + """Prior sampling is not necessary for the reverse SDE. + """ + raise NotImplementedError('Prior sampling not necessary for the reverse SDE.') + + def discretize( + self, + *, + state: torch.Tensor, + time: torch.Tensor, + score_condition: Optional[torch.Tensor] = None, + state_length: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Discretize the reverse SDE. 
+ + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: condition for the score estimator + state_length: length of the valid time steps for each example in the batch + **kwargs: other parameters for discretization of the forward SDE + """ + # Drift and diffusion from the forward SDE + forward_drift, forward_diffusion = self.forward_sde.discretize(state=state, time=time, **kwargs) + + # For input for the score estimator: + # - if no condition is provided, use the state + # - if a condition is provided, concatenate the state and the condition along the channel dimension + score_input = state if score_condition is None else torch.cat([state, score_condition], dim=1) + + # Estimate score + score, _ = self.score_estimator(input=score_input, input_length=state_length, condition=time) + + # Adjust drift + drift = forward_drift - forward_diffusion.pow(2) * score + + # Adjust diffusion + diffusion = forward_diffusion + + if state_length is not None: + drift = mask_sequence_tensor(drift, state_length) + diffusion = mask_sequence_tensor(diffusion, state_length) + + return drift, diffusion + + def copy(self): + return ReverseStochasticDifferentialEquation(sde=self.forward_sde.copy(), score_estimator=self.score_estimator) + + def __repr__(self): + desc = f'{self.__class__.__name__}(sde={self.forward_sde}, score_estimator={self.score_estimator})' + return desc + + +class SpectrogramNoiseConditionalScoreNetworkPlusPlus(NeuralModule): + """This model handles complex-valued inputs by stacking real and imaginary components. + Stacked tensor is processed using NCSN++ and the output is projected to generate real + and imaginary components of the output channels. + + Args: + in_channels: number of input complex-valued channels + out_channels: number of output complex-valued channels + """ + + def __init__(self, *, in_channels: int = 1, out_channels: int = 1, **kwargs): + super().__init__() + + # Number of input signals for this estimator + if in_channels < 1: + raise ValueError( + f'Number of input channels needs to be larger or equal to one, current value {in_channels}' + ) + + self.in_channels = in_channels + + # Number of output signals for this estimator + if out_channels < 1: + raise ValueError( + f'Number of output channels needs to be larger or equal to one, current value {out_channels}' + ) + + self.out_channels = out_channels + + # Instantiate noise conditional score network NCSN++ + ncsnpp_params = kwargs.copy() + ncsnpp_params['in_channels'] = ncsnpp_params['out_channels'] = 2 * self.in_channels # stack real and imag + self.ncsnpp = NoiseConditionalScoreNetworkPlusPlus(**ncsnpp_params) + + # Output projection to generate real and imaginary components of the output channels + self.output_projection = torch.nn.Conv2d( + in_channels=2 * self.in_channels, out_channels=2 * self.out_channels, kernel_size=1 + ) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + "condition": NeuralType(('B',), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. 
+        """
+        return {
+            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
+            "output_length": NeuralType(('B',), LengthsType(), optional=True),
+        }
+
+    @typecheck()
+    def forward(self, input, input_length=None, condition=None):
+        # Stack real and imaginary components
+        B, C_in, D, T = input.shape
+
+        if C_in != self.in_channels:
+            raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}')
+
+        # Stack real and imaginary parts
+        input_real_imag = torch.stack([input.real, input.imag], dim=2)
+        input = einops.rearrange(input_real_imag, 'B C RI F T -> B (C RI) F T')
+
+        # Process using NCSN++
+        output, output_length = self.ncsnpp(input=input, input_length=input_length, condition=condition)
+
+        # Output projection
+        output = self.output_projection(output)
+
+        # Convert to complex-valued signal
+        output = output.reshape(B, 2, self.out_channels, D, T)
+        # Move real/imag dimension to the end
+        output = output.permute(0, 2, 3, 4, 1)
+        output = torch.view_as_complex(output.contiguous())
+
+        return output, output_length
+
+
+class NoiseConditionalScoreNetworkPlusPlus(NeuralModule):
+    """Implementation of Noise Conditional Score Network (NCSN++) architecture.
+
+    References:
+        - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021
+        - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018
+    """
+
+    def __init__(
+        self,
+        nonlinearity: str = "swish",
+        in_channels: int = 2,  # number of channels in the input image
+        out_channels: int = 2,  # number of channels in the output image
+        channels: Sequence[int] = (128, 128, 256, 256, 256),  # number of channels at start + at every resolution
+        num_res_blocks: int = 2,
+        num_resolutions: int = 4,
+        init_scale: float = 1e-5,
+        conditioned_on_time: bool = False,
+        fourier_embedding_scale: float = 16.0,
+        dropout_rate: float = 0.0,
+        pad_time_to: Optional[int] = None,
+        pad_dimension_to: Optional[int] = None,
+        **_,
+    ):
+        # Network topology is a flavor of UNet, example chart for num_resolutions=4
+        #
+        # 1: Image  → Image/2  → Image/4  → Image/8
+        #      ↓         ↓          ↓          ↓
+        # 2: Hidden → Hidden/2 → Hidden/4 → Hidden/8
+        #      ↓         ↓          ↓          ↓
+        # 3: Hidden ← Hidden/2 ← Hidden/4 ← Hidden/8
+        #      ↓         ↓          ↓          ↓
+        # 4: Image  ← Image/2  ← Image/4  ← Image/8
+
+        # Horizontal arrows in (1) are downsampling
+        # Vertical arrows from (1) to (2) are channel upconversions
+        #
+        # Horizontal arrows in (2) are blocks with downsampling where necessary
+        # Horizontal arrows in (3) are blocks with upsampling where necessary
+        #
+        # Vertical arrows from (1) to (2) are downsampling and channel upconversions
+        # Vertical arrows from (2) to (3) are sum connections (also scaled by 1 / sqrt(2))
+        # Vertical arrows from (3) to (4) are channel downconversions
+        # Horizontal arrows in (4) are upsampling and addition
+        super().__init__()
+
+        # same nonlinearity is used throughout the whole network
+        self.activation: torch.nn.Module = activation_registry[nonlinearity]()
+        self.init_scale: float = init_scale
+
+        self.downsample = torch.nn.Upsample(scale_factor=0.5, mode="bilinear")
+        self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.channels = channels
+        self.num_res_blocks = num_res_blocks
+        self.num_resolutions = num_resolutions
+        self.conditioned_on_time = conditioned_on_time
+
+        # padding setup
+        self.pad_time_to = pad_time_to or 2 ** self.num_resolutions
+        self.pad_dimension_to = pad_dimension_to or 2 **
self.num_resolutions + + if self.conditioned_on_time: + self.time_embedding = torch.nn.Sequential( + GaussianFourierProjection(embedding_size=self.channels[0], scale=fourier_embedding_scale), + torch.nn.Linear(self.channels[0] * 2, self.channels[0] * 4), + self.activation, + torch.nn.Linear(self.channels[0] * 4, self.channels[0] * 4), + ) + + self.input_pyramid = torch.nn.ModuleList() + for ch in self.channels[:-1]: + self.input_pyramid.append(torch.nn.Conv2d(in_channels=self.in_channels, out_channels=ch, kernel_size=1)) + + # each block takes an image and outputs an image + # possibly changes number of channels + # output blocks ("reverse" path of the unet) reuse outputs of input blocks ("forward" path) + # so great care must be taken to in/out channels of each block + # resolutions are handled in `forward` + block_params = { + "activation": self.activation, + "dropout_rate": dropout_rate, + "init_scale": self.init_scale, + "diffusion_step_embedding_dim": channels[0] * 4 if self.conditioned_on_time else None, + } + self.input_blocks = torch.nn.ModuleList() + for in_ch, out_ch in zip(self.channels[:-1], self.channels[1:]): + for n in range(num_res_blocks): + block = ResnetBlockBigGANPlusPlus(in_ch=in_ch if n == 0 else out_ch, out_ch=out_ch, **block_params) + self.input_blocks.append(block) + + self.output_blocks = torch.nn.ModuleList() + for in_ch, out_ch in zip(reversed(self.channels[1:]), reversed(self.channels[:-1])): + for n in reversed(range(num_res_blocks)): + block = ResnetBlockBigGANPlusPlus(in_ch=in_ch, out_ch=out_ch if n == 0 else in_ch, **block_params) + self.output_blocks.append(block) + + self.projection_blocks = torch.nn.ModuleList() + for ch in self.channels[:-1]: + self.projection_blocks.append(torch.nn.Conv2d(ch, out_channels, kernel_size=1)) + + assert len(self.input_pyramid) == self.num_resolutions + assert len(self.input_blocks) == self.num_resolutions * self.num_res_blocks + assert len(self.output_blocks) == self.num_resolutions * self.num_res_blocks + assert len(self.projection_blocks) == self.num_resolutions + + self.init_weights_() + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + logging.debug('\tchannels: %s', self.channels) + logging.debug('\tnum_res_blocks: %s', self.num_res_blocks) + logging.debug('\tnum_resolutions: %s', self.num_resolutions) + logging.debug('\tconditioned_on_time: %s', self.conditioned_on_time) + logging.debug('\tpad_time_to: %s', self.pad_time_to) + logging.debug('\tpad_dimension_to: %s', self.pad_dimension_to) + + def init_weights_(self): + for module in self.modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + # torch.nn submodules with scaled init + for module in self.projection_blocks: + torch.nn.init.xavier_uniform_(module.weight, gain=self.init_scale) + + # non-torch.nn submodules can have their own init schemes + for module in self.modules(): + if module is self: + continue + + if hasattr(module, "init_weights_"): + module.init_weights_() + + @typecheck( + input_types={"input": NeuralType(('B', 'C', 'D', 'T')),}, + output_types={"output": NeuralType(('B', 'C', 'D', 'T')),}, + ) + def pad_input(self, input: torch.Tensor) -> torch.Tensor: + """Pad input tensor to match the required dimensions across `T` and `D`. 
+        """
+        *_, D, T = input.shape
+        output = input
+
+        # padding across time
+        if T % self.pad_time_to != 0:
+            output = F.pad(output, (0, self.pad_time_to - T % self.pad_time_to))
+
+        # padding across dimension
+        if D % self.pad_dimension_to != 0:
+            output = F.pad(output, (0, 0, 0, self.pad_dimension_to - D % self.pad_dimension_to))
+
+        return output
+
+    @property
+    def input_types(self) -> Dict[str, NeuralType]:
+        """Returns definitions of module input ports.
+        """
+        return {
+            "input": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+            "input_length": NeuralType(('B',), LengthsType(), optional=True),
+            "condition": NeuralType(('B',), FloatType(), optional=True),
+        }
+
+    @property
+    def output_types(self) -> Dict[str, NeuralType]:
+        """Returns definitions of module output ports.
+        """
+        return {
+            "output": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+            "output_length": NeuralType(('B',), LengthsType(), optional=True),
+        }
+
+    @typecheck()
+    def forward(
+        self, *, input: torch.Tensor, input_length: Optional[torch.Tensor], condition: Optional[torch.Tensor] = None
+    ):
+        """Forward pass of the model.
+
+        Args:
+            input: input tensor, shape (B, C, D, T)
+            input_length: length of the valid time steps for each example in the batch, shape (B,)
+            condition: scalar condition (time) for the model, will be embedded using `self.time_embedding`
+        """
+        assert input.shape[1] == self.in_channels
+
+        # apply padding at the input
+        *_, D, T = input.shape
+        input = self.pad_input(input=input)
+
+        if input_length is None:
+            # assume all time frames are valid
+            input_length = torch.LongTensor([input.shape[-1]] * input.shape[0]).to(input.device)
+
+        lengths = input_length
+
+        if condition is not None:
+            if len(condition.shape) != 1:
+                raise ValueError(
+                    f"Expected condition to be a 1-dim tensor, got a {len(condition.shape)}-dim tensor of shape {tuple(condition.shape)}"
+                )
+            if condition.shape[0] != input.shape[0]:
+                raise ValueError(
+                    f"Condition {tuple(condition.shape)} and input {tuple(input.shape)} should match along the batch dimension"
+                )
+
+            condition = self.time_embedding(torch.log(condition))
+
+        # downsample and project input image to add later in the downsampling path
+        pyramid = [input]
+        for resolution_num in range(self.num_resolutions - 1):
+            pyramid.append(self.downsample(pyramid[-1]))
+        pyramid = [block(image) for image, block in zip(pyramid, self.input_pyramid)]
+
+        # downsampling path
+        history = []
+        hidden = torch.zeros_like(pyramid[0])
+        input_blocks = iter(self.input_blocks)
+        for resolution_num, image in enumerate(pyramid):
+            hidden = (hidden + image) / math.sqrt(2.0)
+            hidden = mask_sequence_tensor(hidden, lengths)
+
+            for _ in range(self.num_res_blocks):
+                hidden = next(input_blocks)(hidden, condition)
+                hidden = mask_sequence_tensor(hidden, lengths)
+                history.append(hidden)
+
+            final_resolution = resolution_num == self.num_resolutions - 1
+            if not final_resolution:
+                hidden = self.downsample(hidden)
+                lengths = (lengths / 2).ceil().long()
+
+        # upsampling path
+        to_project = []
+        for residual, block in zip(reversed(history), self.output_blocks):
+            if hidden.shape != residual.shape:
+                to_project.append(hidden)
+                hidden = self.upsample(hidden)
+                lengths = (lengths * 2).long()
+
+            hidden = (hidden + residual) / math.sqrt(2.0)
+            hidden = block(hidden, condition)
+            hidden = mask_sequence_tensor(hidden, lengths)
+
+        to_project.append(hidden)
+
+        # projecting to images
+        images = []
+        for tensor, projection in zip(to_project, reversed(self.projection_blocks)):
+            image = projection(tensor)
+            images.append(F.interpolate(image, size=input.shape[-2:]))  # TODO write this loop using self.upsample
+
+        result = sum(images)
+
+        assert result.shape[-2:] == input.shape[-2:]
+
+        # remove padding
+        result = result[:, :, :D, :T]
+        return result, input_length
+
+
+class GaussianFourierProjection(NeuralModule):
+    """Gaussian Fourier embeddings for input scalars.
+
+    The input scalars are typically time or noise levels.
+    """
+
+    def __init__(self, embedding_size: int = 256, scale: float = 1.0):
+        super().__init__()
+        self.W = torch.nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+    @property
+    def input_types(self) -> Dict[str, NeuralType]:
+        """Returns definitions of module input ports.
+        """
+        return {
+            "input": NeuralType(('B',), FloatType()),
+        }
+
+    @property
+    def output_types(self) -> Dict[str, NeuralType]:
+        """Returns definitions of module output ports.
+        """
+        return {
+            "output": NeuralType(('B', 'D'), VoidType()),
+        }
+
+    def forward(self, input):
+        x_proj = input[:, None] * self.W[None, :] * 2 * math.pi
+        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+
+class ResnetBlockBigGANPlusPlus(torch.nn.Module):
+    """Implementation of a ResNet block for the BigGAN model.
+
+    References:
+        - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021
+        - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018
+    """
+
+    def __init__(
+        self,
+        activation: torch.nn.Module,
+        in_ch: int,
+        out_ch: int,
+        diffusion_step_embedding_dim: Optional[int] = None,
+        init_scale: float = 1e-5,
+        dropout_rate: float = 0.1,
+        in_num_groups: Optional[int] = None,
+        out_num_groups: Optional[int] = None,
+        eps: float = 1e-6,
+    ):
+        """
+        Args:
+            activation (torch.nn.Module): activation layer (ReLU, SiLU, etc)
+            in_ch (int): number of channels in the input image
+            out_ch (int, optional): number of channels in the output image
+            diffusion_step_embedding_dim (int, optional): dimension of diffusion timestep embedding. Defaults to None (no embedding).
+            dropout_rate (float, optional): dropout rate. Defaults to 0.1.
+            init_scale (float, optional): scaling for weight initialization. Defaults to 1e-5.
+            in_num_groups (int, optional): num_groups in the first GroupNorm. Defaults to min(in_ch // 4, 32)
+            out_num_groups (int, optional): num_groups in the second GroupNorm. Defaults to min(out_ch // 4, 32)
+            eps (float, optional): eps parameter of GroupNorms. Defaults to 1e-6.
+ """ + super().__init__() + in_num_groups = in_num_groups or min(in_ch // 4, 32) + out_num_groups = out_num_groups or min(out_ch // 4, 32) + + self.init_scale = init_scale + + self.input_block = torch.nn.Sequential( + torch.nn.GroupNorm(num_groups=in_num_groups, num_channels=in_ch, eps=eps), activation, + ) + + self.middle_conv = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1) + if diffusion_step_embedding_dim is not None: + self.diffusion_step_projection = torch.nn.Sequential( + activation, + torch.nn.Linear(diffusion_step_embedding_dim, out_ch), + einops.layers.torch.Rearrange("batch dim -> batch dim 1 1"), + ) + + self.output_block = torch.nn.Sequential( + torch.nn.GroupNorm(num_groups=out_num_groups, num_channels=out_ch, eps=eps), + activation, + torch.nn.Dropout(dropout_rate), + torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1), + ) + + if in_ch != out_ch: + self.residual_projection = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1) + + self.act = activation + self.in_ch = in_ch + self.out_ch = out_ch + + self.init_weights_() + + def init_weights_(self): + """Weight initialization + """ + for module in self.modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + # a single Conv2d is initialized with gain + torch.nn.init.xavier_uniform_(self.output_block[-1].weight, gain=self.init_scale) + + def forward(self, x: torch.Tensor, diffusion_time_embedding: Optional[torch.Tensor] = None): + """Forward pass of the model. + + Args: + x: input tensor + diffusion_time_embedding: embedding of the diffusion time step + + Returns: + Output tensor + """ + h = self.input_block(x) + h = self.middle_conv(h) + + if diffusion_time_embedding is not None: + h = h + self.diffusion_step_projection(diffusion_time_embedding) + + h = self.output_block(h) + + if x.shape != h.shape: # matching number of channels + x = self.residual_projection(x) + return (x + h) / math.sqrt(2.0) + + +class PredictorCorrectorSampler(NeuralModule): + """Predictor-Corrector sampler for the reverse SDE. 
+ + Args: + sde: forward SDE + score_estimator: neural score estimator + predictor: predictor for the reverse process + corrector: corrector for the reverse process + num_steps: number of time steps for the reverse process + num_corrector_steps: number of corrector steps + time_max: maximum time + time_min: minimum time + snr: SNR for Annealed Langevin Dynamics + output_type: type of the output ('state' for the final state, or 'mean' for the mean of the final state) + + References: + - Song et al., Score-based generative modeling through stochastic differential equations, 2021 + """ + + def __init__( + self, + sde, + score_estimator, + predictor: str = 'reverse_diffusion', + corrector: str = 'annealed_langevin_dynamics', + num_steps: int = 50, + num_corrector_steps: int = 1, + time_max: Optional[float] = None, + time_min: Optional[float] = None, + snr: float = 0.5, + output_type: str = 'mean', + ): + super().__init__() + # Create a copy of SDE + self.sde = sde.copy() + + # Update SDE parameters for sampling + if time_max is not None: + self.sde.time_max = time_max + logging.info('sde.time_max set to: %s', self.sde.time_max) + + if time_min is not None: + self.sde.time_min = time_min + logging.info('sde.time_min set to: %s', self.sde.time_min) + + self.sde.num_steps = num_steps + logging.info('sde.num_steps set to: %s', self.sde.num_steps) + + # Update local values + self.time_max = self.sde.time_max + self.time_min = self.sde.time_min + self.num_steps = self.sde.num_steps + + # Predictor setup + if predictor == 'reverse_diffusion': + self.predictor = ReverseDiffusionPredictor(sde=self.sde, score_estimator=score_estimator) + else: + raise RuntimeError(f'Unexpected predictor: {predictor}') + + # Corrector setup + if corrector == 'annealed_langevin_dynamics': + self.corrector = AnnealedLangevinDynamics( + sde=self.sde, score_estimator=score_estimator, snr=snr, num_steps=num_corrector_steps + ) + else: + raise RuntimeError(f'Unexpected corrector: {corrector}') + + if output_type not in ['mean', 'state']: + raise ValueError(f'Unexpected output type: {output_type}') + self.output_type = output_type + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tpredictor: %s', predictor) + logging.debug('\tcorrector: %s', corrector) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\tnum_corrector_steps: %s', num_corrector_steps) + logging.debug('\tsnr: %s', snr) + logging.debug('\toutput_type: %s', self.output_type) + + @typecheck( + input_types={ + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "score_condition": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType(), optional=True), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={ + "sample": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + ) + @torch.inference_mode() + def forward( + self, prior_mean: torch.Tensor, score_condition: torch.Tensor, state_length: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Takes prior (noisy) mean and generates a sample by solving the reverse SDE. 
+ + Args: + prior_mean: mean for the prior distribution, e.g., noisy observation + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + Generated `sample` and the corresponding `sample_length`. + """ + # Sample from the prior distribution + state = self.sde.prior_sampling(prior_mean=prior_mean) + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + # Time steps for evaluation + time_steps = torch.linspace(self.time_max, self.time_min, self.num_steps, device=state.device) + + # Sampling + for t in time_steps: + # time steps for the whole batch + time = t * torch.ones(state.shape[0], device=state.device) + + # corrector step + state, _ = self.corrector( + state=state, time=time, score_condition=score_condition, state_length=state_length + ) + + # predictor step + state, state_mean = self.predictor( + state=state, + time=time, + score_condition=score_condition, + prior_mean=prior_mean, + state_length=state_length, + ) + + # Final output + if self.output_type == 'state': + sample = state + elif self.output_type == 'mean': + sample = state_mean + else: + raise RuntimeError(f'Unexpected output type: {self.output_type}') + + if state_length is not None: + sample = mask_sequence_tensor(sample, state_length) + + return sample, state_length + + +class Predictor(torch.nn.Module, ABC): + """Predictor for the reverse process. + + Args: + sde: forward SDE + score_estimator: neural score estimator + """ + + def __init__(self, sde, score_estimator): + super().__init__() + self.reverse_sde = ReverseStochasticDifferentialEquation(sde=sde, score_estimator=score_estimator) + + @abstractmethod + @torch.inference_mode() + def forward( + self, + *, + state: torch.Tensor, + time: torch.Tensor, + score_condition: Optional[torch.Tensor] = None, + state_length: Optional[torch.Tensor] = None, + **kwargs, + ): + """Predict the next state of the reverse process. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean. + """ + pass + + +class ReverseDiffusionPredictor(Predictor): + """Predict the next state of the reverse process using the reverse diffusion process. + + Args: + sde: forward SDE + score_estimator: neural score estimator + """ + + def __init__(self, sde, score_estimator): + super().__init__(sde=sde, score_estimator=score_estimator) + + @torch.inference_mode() + def forward(self, *, state, time, score_condition=None, state_length=None, **kwargs): + """Predict the next state of the reverse process using the reverse diffusion process. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean of the diffusion process. 
+ """ + drift, diffusion = self.reverse_sde.discretize( + state=state, time=time, score_condition=score_condition, state_length=state_length, **kwargs + ) + + # Generate a random sample from a standard normal distribution + z_norm = torch.randn_like(state) + + # Compute the mean of the next state + mean = state - drift + + # Compute new state by sampling + new_state = mean + diffusion * z_norm + + if state_length is not None: + new_state = mask_sequence_tensor(new_state, state_length) + mean = mask_sequence_tensor(mean, state_length) + + return new_state, mean + + +class Corrector(NeuralModule, ABC): + """Corrector for the reverse process. + + Args: + sde: forward SDE + score_estimator: neural score estimator + snr: SNR for Annealed Langevin Dynamics + num_steps: number of steps for the corrector + """ + + def __init__( + self, + sde: Type[StochasticDifferentialEquation], + score_estimator: Type[NeuralModule], + snr: float, + num_steps: int, + ): + super().__init__() + self.sde = sde + self.score_estimator = score_estimator + self.snr = snr + self.num_steps = num_steps + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tsnr: %s', snr) + logging.debug('\tnum_steps: %s', num_steps) + + @abstractmethod + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), FloatType()), + "score_condition": NeuralType(('B', 'C', 'D', 'T'), VoidType(), optional=True), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={"state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + ) + @torch.inference_mode() + def forward(self, state, time, score_condition=None, state_length=None): + """ + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean. + """ + pass + + +class AnnealedLangevinDynamics(Corrector): + """Annealed Langevin Dynamics for the reverse process. + + References: + - Song et al., Score-based generative modeling through stochastic differential equations, 2021 + """ + + def __init__(self, sde, **kwargs): + if not isinstance(sde, OrnsteinUhlenbeckVarianceExplodingSDE): + raise ValueError(f'Expected an instance of OrnsteinUhlenbeckVarianceExplodingSDE, got {type(sde)}') + super().__init__(sde=sde, **kwargs) + + @torch.inference_mode() + def forward(self, state, time, score_condition=None, state_length=None): + """Correct the state using Annealed Langevin Dynamics. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean of the diffusion process. + + References: + Alg. 
4 in http://arxiv.org/abs/2011.13456 + """ + # Compute the standard deviation of the diffusion process + std = self.sde.perturb_kernel_std(time=time) + # View as [B, 1, 1, 1] + std = std.view(-1, *([1] * (state.dim() - 1))) + + for i in range(self.num_steps): + # prepare input for the score estimator, concatenate conditioning along the channel dimension + score_input = state if score_condition is None else torch.cat([state, score_condition], dim=1) + + # calculate the score + score, _ = self.score_estimator(input=score_input, input_length=state_length, condition=time) + + # generate a sample from a standard normal distribution + z_norm = torch.randn_like(state) + + # compute the step size + # note: this is slightly different than in the paper, where std = ||z_norm||_2 / ||score||_2 + step_size = 2 * (self.snr * std).pow(2) + + # update the mean + mean = state + step_size * score + + # update the state + state = mean + z_norm * torch.sqrt(step_size * 2) + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + mean = mask_sequence_tensor(mean, state_length) + + return state, mean diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py index 7023f57652b5..6ea4314ab71f 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -1674,7 +1674,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # megatron_amp_O2 is not yet supported in diffusion models self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) - if self.cfg.precision in ['16', 16, 'bf16']: + if self.megatron_amp_O2 and self.cfg.precision in ['16', 16, 'bf16']: self.model_parallel_config.enable_autocast = False if not hasattr(self.cfg.unet_config, 'unet_precision') or not '16' in str( self.cfg.unet_config.unet_precision diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py index c92980d904f6..3fcab2127f4f 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
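For orientation, a minimal sketch of how the pieces introduced above (the variance-exploding Ornstein-Uhlenbeck SDE, the spectrogram score estimator, and the predictor-corrector sampler) might be wired together at inference time. The import path, tensor shapes, and hyperparameter values below are illustrative assumptions, not values prescribed by this PR:

# The three classes are defined in the diffusion module added above; its import path is not shown in this hunk.
import torch

# The score estimator sees the current state (1 complex channel) concatenated with the noisy
# observation used as conditioning (1 complex channel), hence in_channels=2.
score_estimator = SpectrogramNoiseConditionalScoreNetworkPlusPlus(
    in_channels=2, out_channels=1, conditioned_on_time=True
)
sde = OrnsteinUhlenbeckVarianceExplodingSDE(stiffness=1.5, std_min=0.05, std_max=0.5)
sampler = PredictorCorrectorSampler(
    sde=sde,
    score_estimator=score_estimator,
    predictor='reverse_diffusion',
    corrector='annealed_langevin_dynamics',
    num_steps=50,
    num_corrector_steps=1,
    snr=0.5,
    output_type='mean',
)

# Noisy complex spectrogram, shape (B, C, D, T), and the number of valid frames per example.
noisy_spec = torch.randn(2, 1, 256, 384, dtype=torch.cfloat)
lengths = torch.tensor([384, 320])

# The noisy observation serves both as the prior mean and as the score conditioning.
enhanced, enhanced_len = sampler(prior_mean=noisy_spec, score_condition=noisy_spec, state_length=lengths)

The noisy observation appears twice because the predictor and corrector concatenate the state with the conditioning along the channel dimension before calling the score estimator, which is why the estimator above is built with two input channels.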
import math +import os from inspect import isfunction import torch @@ -21,6 +22,13 @@ from torch import einsum, nn from torch._dynamo import disable +if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1": + from nemo.gn_native import GroupNormNormlization as GroupNorm +else: + from apex.contrib.group_norm import GroupNorm + +from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP + from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, @@ -96,13 +104,19 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) - self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out)) + if use_te: + activation = 'gelu' if not glu else 'geglu' + # TODO: more parameters to be confirmed, dropout, seq_length + self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,) + else: + norm = nn.LayerNorm(dim) + project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -225,10 +239,15 @@ def __init__( dropout=0.0, use_flash_attention=False, lora_network_alpha=None, + use_te=False, ): super().__init__() self.inner_dim = dim_head * heads + if context_dim is None: + self.is_self_attn = True + else: + self.is_self_attn = False # cross-attention context_dim = default(context_dim, query_dim) # make attention part be aware of self-attention/cross-attention self.context_dim = context_dim @@ -238,10 +257,19 @@ def __init__( self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha) self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha) self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha) + self.use_te = use_te + if use_te: + return_layernorm_output = True if self.is_self_attn else False + self.norm_to_q = LayerNormLinear( + query_dim, self.inner_dim, bias=False, return_layernorm_output=return_layernorm_output + ) + else: + self.norm = nn.LayerNorm(query_dim) + self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False) + self.to_out = nn.Sequential( LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout) ) @@ -262,8 +290,18 @@ def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_cr # add additional token x = torch.cat([additional_tokens, x], dim=1) - q = self.to_q(x) - context = default(context, x) + if self.use_te: + q_out = self.norm_to_q(x) + if self.is_self_attn: + q, ln_out = q_out + context = default(context, ln_out) + else: + q = q_out + context = default(context, x) + else: + x = self.norm(x) + q = self.to_q(x) + context = default(context, x) k = self.to_k(context) v = self.to_v(context) @@ -351,6 +389,7 @@ def __init__( use_flash_attention=False, disable_self_attn=False, 
lora_network_alpha=None, + use_te=False, ): super().__init__() self.disable_self_attn = disable_self_attn @@ -362,8 +401,9 @@ def __init__( use_flash_attention=use_flash_attention, context_dim=context_dim if self.disable_self_attn else None, lora_network_alpha=lora_network_alpha, + use_te=use_te, ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te) self.attn2 = CrossAttention( query_dim=dim, context_dim=context_dim, @@ -372,10 +412,8 @@ def __init__( dropout=dropout, use_flash_attention=use_flash_attention, lora_network_alpha=lora_network_alpha, + use_te=use_te, ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) self.use_checkpoint = use_checkpoint def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): @@ -397,15 +435,15 @@ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_at def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): x = ( self.attn1( - self.norm1(x), + x, context=context if self.disable_self_attn else None, additional_tokens=additional_tokens, n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, ) + x ) - x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x - x = self.ff(self.norm3(x)) + x + x = self.attn2(x, context=context, additional_tokens=additional_tokens) + x + x = self.ff(x) + x return x @@ -431,6 +469,7 @@ def __init__( use_checkpoint=False, use_flash_attention=False, lora_network_alpha=None, + use_te=False, ): super().__init__() logging.info( @@ -473,6 +512,7 @@ def __init__( use_flash_attention=use_flash_attention, disable_self_attn=disable_self_attn, lora_network_alpha=lora_network_alpha, + use_te=use_te, ) for d in range(depth) ] diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 5ff0f6aa8a8a..b610f921a22a 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
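To make the attention.py refactor above concrete: with `use_te=False`, the LayerNorm now lives inside `FeedForward` (and `CrossAttention`), so `BasicTransformerBlock` no longer applies `norm1`/`norm2`/`norm3` before its sub-modules. Below is a rough stand-in for the non-TE `FeedForward` path, with plain `torch.nn.Linear` replacing NeMo's `LinearWrapper` and the GEGLU/TE branches omitted (assumptions made only for illustration):

import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
    """Mirrors the non-TE branch above: LayerNorm -> Linear + GELU -> Dropout -> Linear."""

    def __init__(self, dim: int, dim_out: int = None, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim
        self.net = nn.Sequential(
            nn.LayerNorm(dim),             # the norm now lives inside the module
            nn.Linear(dim, inner_dim),     # stand-in for LinearWrapper
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)


x = torch.randn(2, 77, 320)      # (batch, tokens, dim), illustrative shape
y = FeedForwardSketch(320)(x)    # residual add stays in BasicTransformerBlock: x = ff(x) + x

The residual connection remains with the caller, as in the updated `_forward` above (`x = self.ff(x) + x`).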
import math +import os +import re from abc import abstractmethod from collections.abc import Iterable +from contextlib import nullcontext from functools import partial from typing import Iterable @@ -22,6 +25,9 @@ import torch as th import torch.nn as nn import torch.nn.functional as F + +# FP8 related import +import transformer_engine from apex.contrib.group_norm import GroupNorm from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer @@ -62,6 +68,34 @@ def convert_module_to_fp32(module, enable_norm_layers=False): convert_module_to_dtype(module, torch.float32, enable_norm_layers) +def convert_module_to_fp8(model): + def _set_module(model, submodule_key, module): + tokens = submodule_key.split('.') + sub_tokens = tokens[:-1] + cur_mod = model + for s in sub_tokens: + cur_mod = getattr(cur_mod, s) + setattr(cur_mod, tokens[-1], module) + + import copy + + from transformer_engine.pytorch.module import Linear as te_Linear + + for n, v in model.named_modules(): + if isinstance(v, torch.nn.Linear): + # if n in ['class_embed', 'bbox_embed.layers.0', 'bbox_embed.layers.1', 'bbox_embed.layers.2']: continue + logging.info(f'[INFO] Replace Linear: {n}, weight: {v.weight.shape}') + if v.bias is None: + is_bias = False + else: + is_bias = True + newlinear = te_Linear(v.in_features, v.out_features, bias=is_bias) + newlinear.weight = copy.deepcopy(v.weight) + if v.bias is not None: + newlinear.bias = copy.deepcopy(v.bias) + _set_module(model, n, newlinear) + + class AttentionPool2d(nn.Module): """ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py @@ -553,6 +587,7 @@ def __init__( unet_precision: str = "fp32", lora_network_alpha=None, timesteps=1000, + use_te_fp8: bool = False, ): super().__init__() from omegaconf.listconfig import ListConfig @@ -663,6 +698,7 @@ def __init__( input_block_chans = [model_channels] ch = model_channels ds = 1 + self.use_te_fp8 = use_te_fp8 for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ @@ -713,6 +749,7 @@ def __init__( use_checkpoint=use_checkpoint, use_flash_attention=use_flash_attention, lora_network_alpha=lora_network_alpha, + use_te=self.use_te_fp8, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -778,6 +815,7 @@ def __init__( use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, use_flash_attention=use_flash_attention, + use_te=self.use_te_fp8, lora_network_alpha=lora_network_alpha, ), ResBlock( @@ -844,6 +882,7 @@ def __init__( use_checkpoint=use_checkpoint, use_flash_attention=use_flash_attention, lora_network_alpha=lora_network_alpha, + use_te=self.use_te_fp8, ) ) if level and i == self.num_res_blocks[level]: @@ -899,6 +938,34 @@ def __init__( self.convert_to_fp16() elif unet_precision == 'fp16': self.convert_to_fp16(enable_norm_layers=True) + elif self.use_te_fp8: + assert unet_precision != 'fp16', "fp8 training can't work with fp16 O2 amp recipe" + convert_module_to_fp8(self) + + fp8_margin = int(os.getenv("FP8_MARGIN", '0')) + fp8_interval = int(os.getenv("FP8_INTERVAL", '1')) + fp8_format = os.getenv("FP8_FORMAT", "hybrid") + fp8_amax_history_len = int(os.getenv("FP8_HISTORY_LEN", '1024')) + fp8_amax_compute_algo = os.getenv("FP8_COMPUTE_ALGO", 'max') + fp8_wgrad = os.getenv("FP8_WGRAD", '1') == '1' + + fp8_format_dict = { + 'hybrid': transformer_engine.common.recipe.Format.HYBRID, + 'e4m3': transformer_engine.common.recipe.Format.E4M3, + } + fp8_format = fp8_format_dict[fp8_format] + + self.fp8_recipe = 
transformer_engine.common.recipe.DelayedScaling( + margin=fp8_margin, + interval=fp8_interval, + fp8_format=fp8_format, + amax_history_len=fp8_amax_history_len, + amax_compute_algo=fp8_amax_compute_algo, + override_linear_precision=(False, False, not fp8_wgrad), + ) + old_state_dict = self.state_dict() + new_state_dict = self.te_fp8_key_mapping(old_state_dict) + self.load_state_dict(new_state_dict, strict=False) self.unet_precision = unet_precision @@ -1000,8 +1067,65 @@ def _sdxl_embedding_mapping(self, sdxl_dict): res_dict[new_key_] = value_ return res_dict + def _legacy_unet_ckpt_mapping(self, unet_dict): + new_dict = {} + key_map = { + 'transformer_blocks.0.norm1.weight': 'transformer_blocks.0.attn1.norm.weight', + 'transformer_blocks.0.norm1.bias': 'transformer_blocks.0.attn1.norm.bias', + 'transformer_blocks.0.norm2.weight': 'transformer_blocks.0.attn2.norm.weight', + 'transformer_blocks.0.norm2.bias': 'transformer_blocks.0.attn2.norm.bias', + 'transformer_blocks.0.norm3.weight': 'transformer_blocks.0.ff.net.0.weight', + 'transformer_blocks.0.norm3.bias': 'transformer_blocks.0.ff.net.0.bias', + 'transformer_blocks.0.ff.net.0.proj.weight': 'transformer_blocks.0.ff.net.1.proj.weight', + 'transformer_blocks.0.ff.net.0.proj.bias': 'transformer_blocks.0.ff.net.1.proj.bias', + 'transformer_blocks.0.ff.net.2.weight': 'transformer_blocks.0.ff.net.3.weight', + 'transformer_blocks.0.ff.net.2.bias': 'transformer_blocks.0.ff.net.3.bias', + } + + pattern = re.compile(r'(input_blocks|output_blocks)\.[\d\w]+\.[\d\w]+\.') + pattern_middle_block = re.compile(r'middle_block\.[\d\w]+\.') + for old_key, value in unet_dict.items(): + match = pattern.match(old_key) + match_middle = pattern_middle_block.match(old_key) + if match or match_middle: + prefix = match.group(0) if match else match_middle.group(0) + suffix = old_key.split('.', 3)[-1] if match else old_key.split('.', 2)[-1] + if suffix in key_map: + new_key = prefix + key_map[suffix] + new_dict[new_key] = value + else: + new_dict[old_key] = value + else: + new_dict[old_key] = value + + return new_dict + + def te_fp8_key_mapping(self, unet_dict): + new_state_dict = {} + for key in unet_dict.keys(): + if 'extra_state' in key: + continue + + ### LayerNormLinear + # norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias} + # norm_to_q.weight -> to_q.weight + new_key = key.replace('attn1.norm.', 'attn1.norm_to_q.layer_norm_') + new_key = new_key.replace('attn1.to_q.weight', 'attn1.norm_to_q.weight',) + new_key = new_key.replace('attn2.norm.', 'attn2.norm_to_q.layer_norm_') + new_key = new_key.replace('attn2.to_q.weight', 'attn2.norm_to_q.weight',) + + ### LayerNormMLP + # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias} + # ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias} + # ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias} + new_key = new_key.replace('ff.net.0.', 'ff.net.layer_norm_') + new_key = new_key.replace('ff.net.1.proj.', 'ff.net.fc1_') + new_key = new_key.replace('ff.net.3.', 'ff.net.fc2_') + + new_state_dict[new_key] = unet_dict[key] + return new_state_dict + def _state_key_mapping(self, state_dict: dict): - import re res_dict = {} input_dict = {} @@ -1027,13 +1151,7 @@ def _state_key_mapping(self, state_dict: dict): mid_dict = self._mid_blocks_mapping(mid_dict) other_dict = self._other_blocks_mapping(other_dict) sdxl_dict = self._sdxl_embedding_mapping(sdxl_dict) - # key_list = state_dict.keys() - # key_str = " ".join(key_list) - # for key_, val_ in state_dict.items(): - # key_ = key_.replace("down_blocks", 
"input_blocks")\ - # .replace("up_blocks", 'output_blocks') - # res_dict[key_] = val_ res_dict.update(input_dict) res_dict.update(output_dict) res_dict.update(mid_dict) @@ -1046,6 +1164,7 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from state_dict = self._strip_unet_key_prefix(state_dict) if not from_NeMo: state_dict = self._state_key_mapping(state_dict) + state_dict = self._legacy_unet_ckpt_mapping(state_dict) model_state_dict = self.state_dict() loaded_keys = [k for k in state_dict.keys()] @@ -1151,7 +1270,7 @@ def convert_to_fp16(self, enable_norm_layers=False): """ self.apply(lambda module: convert_module_to_fp16(module=module, enable_norm_layers=enable_norm_layers)) - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def _forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. @@ -1170,7 +1289,6 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - if self.unet_precision == "fp16-mixed" or self.unet_precision == "fp16": x = x.type(torch.float16) if context is not None: @@ -1197,6 +1315,13 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): else: return self.out(h) + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe, + ) if self.use_te_fp8 else nullcontext(): + out = self._forward(x, timesteps, context, y, **kwargs) + return out + class EncoderUNetModel(nn.Module): """ diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index a5e886f3b479..16ded8e2c682 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -37,6 +37,8 @@ LoraDenseAttentionAdapterConfig, LoraHto4HAdapterConfig, LoraKQVAdapterConfig, + LoraUnfusedHto4HAdapterConfig, + LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, PromptEncoderAdapterConfig, @@ -67,7 +69,12 @@ def mcore_register_adapters(self): Setup NeMo LoRA or IA3 adapter to this MCore layer. 
""" self.set_accepted_adapter_types( - [LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_] + [ + LoraUnfusedKQVAdapterConfig._target_, + LoraKQVAdapterConfig._target_, + LoraDenseAttentionAdapterConfig._target_, + InfusedAdapterConfig._target_, + ] ) self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp if ( @@ -102,12 +109,20 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): # LoRA logic if self.is_adapter_available(): + lora_adapter = None lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER) + lora_unfused_kqv_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_KQV_ADAPTER) if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']: + lora_adapter = lora_kqv_adapter + if lora_unfused_kqv_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_KQV_ADAPTER]['enabled']: + assert lora_adapter is None, "Expected only one of lora_kqv_adapter or lora_unfused_kqv_adapter" + lora_adapter = lora_unfused_kqv_adapter + + if lora_adapter: if layernorm_output is not None: - lora_mixed_qkv = lora_kqv_adapter(layernorm_output) + lora_mixed_qkv = lora_adapter(layernorm_output) else: - lora_mixed_qkv = lora_kqv_adapter(hidden_states) + lora_mixed_qkv = lora_adapter(hidden_states) mixed_qkv = mixed_qkv + lora_mixed_qkv @@ -251,7 +266,12 @@ def mcore_register_adapters(self): Setup NeMo IA3 adapter to this MCore layer. """ self.set_accepted_adapter_types( - [LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_] + [ + LoraUnfusedHto4HAdapterConfig._target_, + LoraHto4HAdapterConfig._target_, + Lora4HtoHAdapterConfig._target_, + MLPInfusedAdapterConfig._target_, + ] ) # only self attn (packed qkv) for now self.linear_fc1.return_layernorm_output = True # need layernorm output for lora mlp if ( @@ -274,9 +294,17 @@ def forward(self, hidden_states): # LoRA logic if self.is_adapter_available(): - lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) - if lora_linear_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']: - lora_output = lora_linear_fc1_adapter(layernorm_output) + lora_adapter = None + lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) + lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER) + if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']: + lora_adapter = lora_fc1_adapter + if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']: + assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER" + lora_adapter = lora_unfused_fc1_adapter + + if lora_adapter: + lora_output = lora_adapter(layernorm_output) intermediate_parallel = intermediate_parallel + lora_output if self.config.bias_activation_fusion: diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 5037bb1b3634..2a5372d11ab5 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -75,11 +75,13 @@ class AdapterName(str, enum.Enum): POST_ATTN_ADAPTER = 'adapter_2' PTUNING_ADAPTER = "ptuning_adapter" LORA_KQV_ADAPTER = "lora_kqv_adapter" + LORA_UNFUSED_KQV_ADAPTER = 
"lora_unfused_kqv_adapter" LORA_KV_ADAPTER = "lora_kv_adapter" LORA_Q_ADAPTER = "lora_q_adapter" MM_LINEAR_ADAPTER = "mm_linear_adapter" LORA_DENSE_ATTENTION_ADAPTER = "lora_dense_attention_adapter" LORA_Hto4H_ADAPTER = "lora_hto4h_adapter" + LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter" LORA_4HtoH_ADAPTER = "lora_4htoh_adapter" MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter" PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter" @@ -457,6 +459,183 @@ class Lora4HtoHAdapterConfig(ParallelLinearAdapterConfig): input_is_parallel: bool = True +class LoraUnfusedHto4HAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + in_features: int, + out_features: int, + dim: int, + activation: str = 'swish', + norm_position: Optional[str] = 'post', + norm_type: Optional[str] = 'mixedfusedlayernorm', + column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise. + row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise. + gather_output: bool = True, + input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers + dropout: float = 0.0, + model_parallel_config: Optional[ModelParallelConfig] = None, + alpha: float | None = None, + dropout_position: str = 'post', + a2a_experimental: bool = False, # TODO: should rename this or make it a default feature + **kwargs, + ): + super().__init__() + self.gate_adapter = ParallelLinearAdapter( + in_features, + out_features // 2, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + self.up_adapter = ParallelLinearAdapter( + in_features, + out_features // 2, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + + def forward(self, x): + gate_x = self.gate_adapter(x) + up_x = self.up_adapter(x) + x = torch.concat([gate_x, up_x], dim=2) + return x + + +@dataclass +class LoraUnfusedHto4HAdapterConfig(ParallelLinearAdapterConfig): + _target_: str = "{0}.{1}".format(LoraUnfusedHto4HAdapter.__module__, LoraUnfusedHto4HAdapter.__name__) + + +class LoraUnfusedKQVAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + in_features: int, + dim: int, + activation: str = 'swish', + norm_position: Optional[str] = 'post', + norm_type: Optional[str] = 'mixedfusedlayernorm', + column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise. + row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise. 
+ gather_output: bool = True, + input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers + dropout: float = 0.0, + model_parallel_config: Optional[ModelParallelConfig] = None, + alpha: float | None = None, + dropout_position: str = 'post', + a2a_experimental: bool = False, # TODO: should rename this or make it a default feature + num_query_groups: Optional[int] = None, + kv_channels: Optional[int] = None, + **kwargs, + ): + super().__init__() + if num_query_groups is not None and kv_channels is not None: + out_features = kv_channels * num_query_groups + else: + out_features = in_features + + self.q_adapter = ParallelLinearAdapter( + in_features, + in_features, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + + self.k_adapter = ParallelLinearAdapter( + in_features, + out_features, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + self.v_adapter = ParallelLinearAdapter( + in_features, + out_features, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + + def forward(self, x): + qx = self.q_adapter(x) + kx = self.k_adapter(x) + vx = self.v_adapter(x) + x = torch.concat([qx, kx, vx], dim=2) + return x + + +@dataclass +class LoraUnfusedKQVAdapterConfig(AdapterConfig): + in_features: int + dim: int + activation: str = 'swish' + norm_position: Optional[str] = 'post' + norm_type: Optional[str] = 'mixedfusedlayernorm' + column_init_method: str = 'xavier' + row_init_method: str = 'zero' + gather_output: bool = True + input_is_parallel: bool = False + dropout: float = 0.0 + dropout_position: str = 'post' + alpha: float | None = None + network_alpha: int | None = None + a2a_experimental: bool = False + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + _target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__) + + class PromptEncoderAdapter(nn.Module, AdapterModuleUtil): """ The Tensor Parallel MLP prompt encoder network that is used to generate the virtual diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index a6f68f0666b5..0a030759fe9b 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -73,6 +73,7 @@ try: from apex.transformer.pipeline_parallel.utils import get_num_microbatches + from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam HAVE_APEX = True @@ -1057,6 +1058,31 @@ def should_process(key): new_state_dict[key_] = state_dict[key_] state_dict = new_state_dict + if conf.get('unet_config') and conf.get('unet_config').get('use_te_fp8') == False: + # Mapping potential fp8 ckpt to fp16 model + # remove _extra_state in fp8 if there is. 
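To make the key renaming in the hunk that follows concrete, here is a toy sketch of the same idea on a few hypothetical keys (only a subset of the real mapping and the `_extra_state` filtering is shown):

```python
# Toy illustration with made-up keys: TE fused-module parameter names on the left,
# the plain fp16 module names they are remapped to on the right.
fp8_state = {
    "unet.block.attn.norm_to_q.layer_norm_weight": 1,  # LayerNormLinear -> norm.weight
    "unet.block.attn.norm_to_q.weight": 2,             # LayerNormLinear -> to_q.weight
    "unet.block.ff.net.fc1_weight": 3,                 # LayerNormMLP    -> ff.net.1.proj.weight
    "unet.block.attn._extra_state": b"fp8-metadata",   # fp8 extra state, dropped entirely
}

remapped = {}
for key, value in fp8_state.items():
    if "extra_state" in key:
        continue
    new_key = key.replace("norm_to_q.layer_norm_", "norm.")
    new_key = new_key.replace("norm_to_q.weight", "to_q.weight")
    new_key = new_key.replace("ff.net.fc1_", "ff.net.1.proj.")
    remapped[new_key] = value

print(sorted(remapped))
# ['unet.block.attn.norm.weight', 'unet.block.attn.to_q.weight', 'unet.block.ff.net.1.proj.weight']
```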
+ new_state_dict = {} + for key in state_dict.keys(): + if 'extra_state' in key: + continue + + ### LayerNormLinear + # norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias} + # norm_to_q.weight -> to_q.weight + new_key = key.replace('norm_to_q.layer_norm_', 'norm.') + new_key = new_key.replace('norm_to_q.weight', 'to_q.weight') + + ### LayerNormMLP + # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias} + # ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias} + # ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias} + new_key = new_key.replace('ff.net.layer_norm_', 'ff.net.0.') + new_key = new_key.replace('ff.net.fc1_', 'ff.net.1.proj.') + new_key = new_key.replace('ff.net.fc2_', 'ff.net.3.') + + new_state_dict[new_key] = state_dict[key] + state_dict = new_state_dict + return state_dict def _load_state_dict_from_disk(self, model_weights, map_location=None): diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 63caa409b218..47d5167d630e 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -36,6 +36,8 @@ LoraHto4HAdapterConfig, LoraKQVAdapterConfig, LoraKQVAdapterWeightTyingConfig, + LoraUnfusedHto4HAdapterConfig, + LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, ParallelLinearAdapterWeightTyingConfig, @@ -132,11 +134,26 @@ def __init__(self, cfg): for module in target_modules: if module == PEFT_MODULE_MAP["qkv_module"]: - adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, LoraKQVAdapterConfig - ) - name_key_to_cfg[AdapterName.LORA_KQV_ADAPTER] = adapter_cfg - name_key_to_mcore_mixins[AdapterName.LORA_KQV_ADAPTER] = [("self_attention", MCoreSelfAttentionMixin)] + if lora_cfg.get("variant", "nemo") == "canonical": + _adapter_name = AdapterName.LORA_UNFUSED_KQV_ADAPTER + _adapter_cfg_cls = LoraUnfusedKQVAdapterConfig + adapter_cfg = self._create_lora_config( + cfg, + lora_cfg, + cfg.hidden_size, + qkv_projection_size, + _adapter_cfg_cls, + num_query_groups=num_query_groups, + kv_channels=kv_channels, + ) + else: + _adapter_name = AdapterName.LORA_KQV_ADAPTER + _adapter_cfg_cls = LoraKQVAdapterConfig + adapter_cfg = self._create_lora_config( + cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, _adapter_cfg_cls + ) + name_key_to_cfg[_adapter_name] = adapter_cfg + name_key_to_mcore_mixins[_adapter_name] = [("self_attention", MCoreSelfAttentionMixin)] elif module == PEFT_MODULE_MAP["dense_module"]: adapter_cfg = self._create_lora_config( @@ -149,11 +166,18 @@ def __init__(self, cfg): elif module == PEFT_MODULE_MAP["hto4h_module"]: hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size + if lora_cfg.get("variant", "nemo") == "canonical": + _adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER + _adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig + else: + _adapter_name = AdapterName.LORA_Hto4H_ADAPTER + _adapter_cfg_cls = LoraHto4HAdapterConfig + adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, LoraHto4HAdapterConfig + cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls ) - name_key_to_cfg[AdapterName.LORA_Hto4H_ADAPTER] = adapter_cfg - name_key_to_mcore_mixins[AdapterName.LORA_Hto4H_ADAPTER] = [("mlp", MCoreMLPMixin)] + name_key_to_cfg[_adapter_name] = adapter_cfg + name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] elif module == PEFT_MODULE_MAP["4htoh_module"]: adapter_cfg = 
self._create_lora_config( cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig @@ -170,7 +194,9 @@ def __init__(self, cfg): self.name_key_to_mcore_mixins = name_key_to_mcore_mixins super().__init__(lora_cfg, name_key_to_cfg) - def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls): + def _create_lora_config( + self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls, num_query_groups=None, kv_channels=None + ): config_args = { "in_features": in_features, "out_features": out_features, @@ -187,6 +213,12 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_ "a2a_experimental": lora_cfg.get("a2a_experimental", False), } + if adapter_cfg_cls == LoraUnfusedKQVAdapterConfig: + assert num_query_groups is not None, "num_query_groups must be provided for canonical Lora" + assert kv_channels is not None, "kv_channels must be provided for canonical Lora" + config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels}) + config_args.pop("out_features") + if lora_cfg.weight_tying: position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None) if position_embedding_strategy is None: diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py index 2663f8fe9bac..783f47a08e79 100644 --- a/nemo/export/quantize/quantizer.py +++ b/nemo/export/quantize/quantizer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import tarfile from contextlib import nullcontext from typing import List, Optional @@ -21,7 +20,6 @@ import torch.distributed as dist from megatron.core import parallel_state from megatron.core.transformer.module import Float16Module -from megatron.training.utils import unwrap_model from omegaconf import OmegaConf from omegaconf.omegaconf import DictConfig, open_dict from pytorch_lightning.trainer.trainer import Trainer @@ -31,7 +29,7 @@ from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import logging from nemo.utils.distributed import temporary_directory -from nemo.utils.model_utils import load_config, save_artifacts +from nemo.utils.model_utils import load_config, save_artifacts, unwrap_model try: import ammo.torch.quantization as atq diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 95d1bc414625..f4eefd39a9ea 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -24,7 +24,7 @@ from enum import Enum from functools import lru_cache from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import wrapt @@ -92,6 +92,24 @@ def load_config(model_file: str) -> DictConfig: return model_config +def unwrap_model(model, module_instances: Union[Type, Tuple[Type]]): + """Unwrap model from wrapper classes like Float16Module, for example.""" + + # TODO: Import this from megatron.core once moved there from megatron.training. 
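Once the body that follows is filled in, the helper simply walks `.module` while the object is an instance of the given wrapper types. A small usage sketch with a hypothetical stand-in wrapper (real callers pass wrappers such as `Float16Module`; assumes this change is applied so the import resolves):

```python
import torch.nn as nn

from nemo.utils.model_utils import unwrap_model


class DummyWrapper(nn.Module):
    """Stand-in for wrappers like Float16Module that keep the real model under .module."""

    def __init__(self, module):
        super().__init__()
        self.module = module


core = nn.Linear(8, 8)
wrapped = DummyWrapper(DummyWrapper(core))

assert unwrap_model(wrapped, (DummyWrapper,)) is core      # single module in -> single module out
assert unwrap_model([wrapped], (DummyWrapper,)) == [core]  # list in -> list out
```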
+ return_list = True + if not isinstance(model, list): + model = [model] + return_list = False + unwrapped_model = [] + for model_module in model: + while isinstance(model_module, module_instances): + model_module = model_module.module + unwrapped_model.append(model_module) + if not return_list: + return unwrapped_model[0] + return unwrapped_model + + def param_is_not_shared(param): return not hasattr(param, 'shared') or not param.shared diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index b7863714eb2d..30e839fd2ca8 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -1,5 +1,6 @@ braceexpand editdistance +einops g2p_en ipywidgets jiwer diff --git a/scripts/checkpoint_converters/convert_zarr_to_torch_dist.py b/scripts/checkpoint_converters/convert_zarr_to_torch_dist.py new file mode 100644 index 000000000000..29b56aa706fa --- /dev/null +++ b/scripts/checkpoint_converters/convert_zarr_to_torch_dist.py @@ -0,0 +1,193 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Conversion script to convert zarr checkpoints into torch distributed checkpoint. + Example to run this conversion script: + python -m torch.distributed.launch --nproc_per_node= * \ + megatron_zarr_ckpt_to_torch_dist.py \ + --model_type \ + --checkpoint_folder \ + --checkpoint_name \ + --path_to_save \ + --tensor_model_parallel_size \ + --pipeline_model_parallel_size \ + --hparams_file \ + --gpus_per_node +""" + +import os +from argparse import ArgumentParser + +import torch +from megatron.core import parallel_state +from omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import AppState, logging +from nemo.utils.distributed import initialize_distributed + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--checkpoint_folder", + type=str, + default=None, + required=True, + help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints", + ) + parser.add_argument( + "--checkpoint_name", + type=str, + default=None, + required=True, + help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=0.14-step=20-consumed_samples=160.0-last", + ) + + parser.add_argument( + "--hparams_file", + type=str, + default=None, + required=True, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. 
Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--path_to_save", type=str, default=None, required=True, help="Path to output ckpt files.") + parser.add_argument( + "--save_to_nemo", action="store_true", help="If passed, output will be written as .nemo file.", + ) + parser.add_argument("--gpus_per_node", type=int, required=True, default=None) + parser.add_argument("--tensor_model_parallel_size", type=int, required=True, default=None) + parser.add_argument("--pipeline_model_parallel_size", type=int, required=True, default=None) + parser.add_argument( + "--pipeline_model_parallel_split_rank", + type=int, + required=False, + default=None, + help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.", + ) + parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1)) + parser.add_argument("--cluster_type", required=False, default=None, help="Whether on BCP platform") + parser.add_argument( + "--precision", + type=str, + required=False, + default='bf16-mixed', + choices=['32-true', '16-mixed', 'bf16-mixed'], + help="Precision value for the trainer that matches with precision of the ckpt", + ) + + parser.add_argument( + "--model_type", type=str, required=True, default="gpt", choices=["gpt", "sft", "bert"], + ) + + args = parser.parse_args() + return args + + +def convert(local_rank, rank, world_size, args): + + app_state = AppState() + app_state.data_parallel_rank = 0 + num_nodes = world_size // args.gpus_per_node + + cfg = { + 'trainer': { + 'devices': args.gpus_per_node, + 'num_nodes': num_nodes, + 'accelerator': 'gpu', + 'precision': args.precision, + }, + 'model': { + 'native_amp_init_scale': 2 ** 32, + 'native_amp_growth_interval': 1000, + 'hysteresis': 2, + 'gradient_as_bucket_view': True, + }, + 'cluster_type': args.cluster_type, + } + cfg = OmegaConf.create(cfg) + + # Set precision None after precision plugins are created as PTL >= 2.1 does not allow both + # precision plugins and precision to exist + cfg.trainer.precision = None + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + + app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = args.tensor_model_parallel_size + app_state.pipeline_model_parallel_split_rank = None + + app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size + + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=app_state.tensor_model_parallel_size, + pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, + ) + + app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() + app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() + + # check for distributed checkpoint + checkpoint_path = os.path.join(args.checkpoint_folder, args.checkpoint_name) + + logging.info( + f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}' + ) + + if args.model_type == "gpt": + model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer) + elif args.model_type == "sft": + model = MegatronGPTSFTModel.load_from_checkpoint( + checkpoint_path, hparams_file=args.hparams_file, trainer=trainer + ) + # 
we force the target for the loaded model to have the correct target + # because the hparams.yaml sometimes contains MegatronGPTModel as the target. + with open_dict(model.cfg): + model.cfg.target = f"{MegatronGPTSFTModel.__module__}.{MegatronGPTSFTModel.__name__}" + elif args.model_type == 'bert': + model = MegatronBertModel.load_from_checkpoint( + checkpoint_path, hparams_file=args.hparams_file, trainer=trainer + ) + + with open_dict(model.cfg): + model.cfg.torch_distributed_checkpoint = True + + model._save_restore_connector = NLPSaveRestoreConnector() + save_file_path = args.path_to_save + if not args.save_to_nemo: + # With --save_to_nemo, save_to_path is expected to be a directory. + # Adding a dummy model filename here conforms with SaveRestoreConnector's convention. + model._save_restore_connector.pack_nemo_file = False + save_file_path = os.path.join(save_file_path, 'model.nemo') + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + model.save_to(save_file_path) + + logging.info(f'NeMo model saved to: {args.path_to_save}') + + +if __name__ == '__main__': + args = get_args() + + local_rank, rank, world_size = initialize_distributed(args) + + convert(local_rank, rank, world_size, args) diff --git a/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py b/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py new file mode 100644 index 000000000000..f2974aca1642 --- /dev/null +++ b/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py @@ -0,0 +1,212 @@ +#!/usr/bin/env +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Convert nemo style (fused) lora checkpoint to canonical (unfused) lora checkpoint. +Currently supports TP=PP=1 only. 
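Before the usage example and the implementation, a quick numerical sketch (with made-up sizes) of why the conversion can reuse the fused LoRA A (`linear_in`) weights for q, k and v and only slice the LoRA B (`linear_out`) matrix row-wise, as `convert_qkv` below does; on the config side the script additionally flips `peft.lora_tuning.variant` to "canonical" so the result loads through the unfused adapters:

```python
import torch

rank, hidden, kv = 8, 32, 16                      # made-up sizes; kv < hidden mimics grouped-query attention
A = torch.randn(rank, hidden)                     # shared LoRA A (linear_in)
B_fused = torch.randn(hidden + 2 * kv, rank)      # fused LoRA B (linear_out) over the [q | k | v] rows

x = torch.randn(5, hidden)
fused_delta = x @ A.T @ B_fused.T                 # what the fused adapter adds to mixed_qkv

# Canonical form: the same A for q, k and v; B sliced row-wise into the three projections.
B_q, B_k, B_v = torch.split(B_fused, [hidden, kv, kv], dim=0)
canonical_delta = torch.cat([x @ A.T @ B_q.T, x @ A.T @ B_k.T, x @ A.T @ B_v.T], dim=-1)

assert torch.allclose(fused_delta, canonical_delta, atol=1e-5)
```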
+ +Example usage: +python scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py \ + --lora_path nemo_style_lora_model.nemo \ + --output_path ./canonical_style_lora_model.nemo + +""" +import tempfile +from argparse import ArgumentParser +from typing import Dict + +import torch +from omegaconf import OmegaConf, open_dict +from scripts.nlp_language_modeling.merge_lora_weights.merge import replace_number_add_offset + +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector + + +def rename_keys(key): + new_keys = [] + if "lora_kqv_adapter" in key: + new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.q_adapter.")) + new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.k_adapter.")) + new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.v_adapter.")) + elif "lora_hto4h_adapter" in key: + new_keys.append(key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.gate_adapter.")) + new_keys.append(key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.up_adapter.")) + return new_keys + + +def reformat_module_names_to_hf(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + new_tensors = dict() + for module_name, module_weight in tensors.items(): + # map linear_in and linear_out to lora_a/lora_b counterparts + new_module_name = "base_model." + module_name.replace("linear_in", "lora_A").replace("linear_out", "lora_B") + + # map target modules to their vLLM/HF counterparts + new_module_name = new_module_name.replace("q_adapter", "q_proj") + new_module_name = new_module_name.replace("k_adapter", "k_proj") + new_module_name = new_module_name.replace("v_adapter", "v_proj") + new_module_name = new_module_name.replace("lora_dense_attention_adapter", "o_proj") + new_module_name = new_module_name.replace("lora_4htoh_adapter", "down_proj") + new_module_name = new_module_name.replace("gate_adapter", "gate_proj") + new_module_name = new_module_name.replace("up_adapter", "up_proj") + + # map other parts of the module names to fit vLLM/huggingface + new_module_name = new_module_name.replace(".adapter_layer", "") + new_module_name = new_module_name.replace(".lora_unfused_kqv_proj", "") + new_module_name = new_module_name.replace(".lora_unfused_hto4h_adapter", "") + new_module_name = new_module_name.replace("self_attention", "self_attn") + new_module_name = new_module_name.replace("decoder", "model") + + new_tensors[new_module_name] = module_weight + return new_tensors + + +def convert_hto4h(lora_weights, lora_config): + assert len(lora_weights) == 1, "Only single TP supported for now" + keys_to_update = [] + for key in lora_weights[0].keys(): + if "lora_hto4h_adapter" in key: + keys_to_update.append(key) + + for key in keys_to_update: + if "linear_in" in key: + for new_key in rename_keys(key): + lora_weights[0][new_key] = lora_weights[0][key] + print(new_key, lora_weights[0][new_key].shape) + elif "linear_out" in key: + for idx, new_key in enumerate(rename_keys(key)): + orginal_shape = lora_weights[0][key].shape[0] + lora_weights[0][new_key] = lora_weights[0][key][ + idx * (orginal_shape // 2) : (idx + 1) * (orginal_shape // 2) + ] + print(new_key, lora_weights[0][new_key].shape) + + lora_weights[0].pop(key) + return lora_weights + + +def convert_qkv(lora_weights, lora_model_cfg): + assert len(lora_weights) == 1, "Only single TP supported for now" + if ( + lora_model_cfg.get("num_query_groups", lora_model_cfg.num_attention_heads) + != lora_model_cfg.num_attention_heads + ): 
+ kv_channels = int(lora_model_cfg.hidden_size / lora_model_cfg.num_attention_heads) + kv_size = int(lora_model_cfg.num_query_groups * kv_channels) + else: + kv_size = int(lora_model_cfg.hidden_size) + q_size = lora_model_cfg.hidden_size + k_size, v_size = kv_size, kv_size + + keys_to_update = [] + for key in lora_weights[0].keys(): + if "lora_kqv_adapter" in key: + keys_to_update.append(key) + + for key in keys_to_update: + if "linear_in" in key: + for new_key in rename_keys(key): + lora_weights[0][new_key] = lora_weights[0][key] + print(new_key, lora_weights[0][new_key].shape) + elif "linear_out" in key: + srt = 0 + for new_key, size in zip(rename_keys(key), [q_size, k_size, v_size]): + lora_weights[0][new_key] = lora_weights[0][key][srt : srt + size] + print(new_key, lora_weights[0][new_key].shape) + srt = srt + size + + lora_weights[0].pop(key) + return lora_weights + + +def convert_lora(lora_nemo, save_path, hf_format=False): + with tempfile.TemporaryDirectory() as tmpdir: + NLPSaveRestoreConnector._unpack_nemo_file(lora_nemo, tmpdir) + config_file = f"{tmpdir}/model_config.yaml" + lora_config = OmegaConf.load(config_file) + tp_size = lora_config.tensor_model_parallel_size + pp_size = lora_config.pipeline_model_parallel_size + + lora_state_dict = [{}] * tp_size + + for pp in range(pp_size): + for tp in range(tp_size): + if tp_size == 1: + ckpt_file = f"{tmpdir}/model_weights.ckpt" + elif pp_size == 1: + ckpt_file = f"{tmpdir}/mp_rank_{tp:02d}/model_weights.ckpt" + else: + ckpt_file = f"{tmpdir}/tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt" + + l = torch.load(ckpt_file, map_location=torch.device('cpu')) + if pp == 0: + lora_state_dict[tp] = l + else: + # calculate layer offset + layer_offset = lora_config.num_layers // pp_size * pp + for key, value in l.items(): + new_key = replace_number_add_offset(key, layer_offset) + lora_state_dict[tp][new_key] = value + + with open_dict(lora_config): + lora_config.peft.lora_tuning.variant = "canonical" + with open(f"{tmpdir}/model_config.yaml", "w") as f: + OmegaConf.save(lora_config, f) + lora_state_dict = convert_qkv(lora_state_dict, lora_config) + lora_state_dict = convert_hto4h(lora_state_dict, lora_config) + # TODO: currently suport tp=1 + lora_state_dict = lora_state_dict[0] + if hf_format: + lora_state_dict = reformat_module_names_to_hf(lora_state_dict) + torch.save(lora_state_dict, f"{save_path}/model_weights_hf_formatted.pt") + else: + torch.save(lora_state_dict, f"{tmpdir}/model_weights.ckpt") + NLPSaveRestoreConnector._make_nemo_file_from_folder(save_path, tmpdir) + + return lora_state_dict, lora_config + + +def fix_for_O2(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if "model.module." 
not in k: + new_state_dict[k.replace('model.', 'model.module.')] = v + return new_state_dict + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--lora_path", + type=str, + default=None, + required=True, + help="Path to NeMo style (fused) lora checkpoint in .nemo file format", + ) + parser.add_argument( + "--output_path", + type=str, + default=None, + required=True, + help="Path to save the canonical (unfused) lora .nemo file.", + ) + parser.add_argument("--hf_format", action='store_true', help="saves tensors in huggingface naming format.") + parser.add_argument("--precision", type=str, default="16", help="Model precision") + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = get_args() + convert_lora(args.lora_path, args.output_path, args.hf_format) diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index 946acb614f11..a2e39628e4cb 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -809,6 +809,39 @@ def test_list_to_multichannel(self, num_channels, num_targets): # Check the list is converted back to the original signal assert (ASRAudioProcessor.list_to_multichannel(target_list) == golden_target).all() + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2]) + def test_processor_process_audio(self, num_channels): + """Test signal normalization in process_audio. + """ + num_samples = 1000 + num_examples = 30 + + signals = ['input_signal', 'target_signal', 'reference_signal'] + + for normalization_signal in [None] + signals: + # Create processor + processor = ASRAudioProcessor( + sample_rate=16000, random_offset=False, normalization_signal=normalization_signal + ) + + # Generate random signals + for n in range(num_examples): + example = {signal: torch.randn(num_channels, num_samples) for signal in signals} + processed_example = processor.process_audio(example) + + # Expected scale + if normalization_signal: + scale = 1.0 / (example[normalization_signal].abs().max() + processor.eps) + else: + scale = 1.0 + + # Make sure all signals are scaled as expected + for signal in signals: + assert torch.allclose( + processed_example[signal], example[signal] * scale + ), f'Failed example {n} signal {signal}' + @pytest.mark.unit def test_audio_collate_fn(self): """Test `_audio_collate_fn` diff --git a/tests/collections/asr/test_asr_losses.py b/tests/collections/asr/test_asr_losses.py index e09fd71e0892..e050e7cc07c3 100644 --- a/tests/collections/asr/test_asr_losses.py +++ b/tests/collections/asr/test_asr_losses.py @@ -17,7 +17,9 @@ import torch from nemo.collections.asr.losses.audio_losses import ( + MSELoss, SDRLoss, + calculate_mse_batch, calculate_sdr_batch, convolution_invariant_target, scale_invariant_target, @@ -271,7 +273,7 @@ def test_sdr_binary_mask(self, num_channels): estimate = target + noise # Limit calculation to masked samples - mask = _rng.integers(low=0, high=2, size=(batch_size, max_num_samples)) + mask = _rng.integers(low=0, high=2, size=(batch_size, num_channels, max_num_samples)) # Tensors for testing the loss tensor_estimate = torch.tensor(estimate) @@ -282,7 +284,9 @@ def test_sdr_binary_mask(self, num_channels): golden_sdr = 0 for b in range(batch_size): sdr = [ - calculate_sdr_numpy(estimate=estimate[b, m, mask[b, :] > 0], target=target[b, m, mask[b, :] > 0]) + calculate_sdr_numpy( + estimate=estimate[b, m, mask[b, m, :] > 0], target=target[b, m, mask[b, m, :] > 0] + ) for m in range(num_channels) ] sdr = 
np.mean(np.array(sdr)) @@ -467,3 +471,187 @@ def test_sdr_convolution_invariant(self, num_channels: int, filter_length: int): assert np.allclose( uut_sdr_loss.cpu().detach().numpy(), -golden_sdr, atol=atol ), f'SDRLoss not matching for example {n}' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 4]) + @pytest.mark.parametrize('ndim', [3, 4]) + def test_mse(self, num_channels: int, ndim: int): + """Test SDR calculation + """ + batch_size = 8 + num_samples = 50 + num_features = 123 + num_batches = 10 + random_seed = 42 + atol = 1e-6 + + signal_shape = ( + (batch_size, num_channels, num_features, num_samples) + if ndim == 4 + else (batch_size, num_channels, num_samples) + ) + + reduction_dim = (-2, -1) if ndim == 4 else -1 + + mse_loss = MSELoss(ndim=ndim) + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_batches): + + # Generate random signal + target = _rng.normal(size=signal_shape) + # Random noise + scaling + noise = _rng.uniform(low=0.01, high=1) * _rng.normal(size=signal_shape) + # Estimate + estimate = target + noise + + # DC bias for both + target += _rng.uniform(low=-1, high=1) + estimate += _rng.uniform(low=-1, high=1) + + # Tensors for testing the loss + tensor_estimate = torch.tensor(estimate) + tensor_target = torch.tensor(target) + + # Reference MSE + golden_mse = np.zeros((batch_size, num_channels)) + for b in range(batch_size): + for m in range(num_channels): + err = estimate[b, m, :] - target[b, m, :] + golden_mse[b, m] = np.mean(np.abs(err) ** 2, axis=reduction_dim) + + # Calculate MSE in torch + uut_mse = calculate_mse_batch(estimate=tensor_estimate, target=tensor_target) + + # Calculate MSE loss + uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target) + + # Compare torch SDR vs numpy + assert np.allclose( + uut_mse.cpu().detach().numpy(), golden_mse, atol=atol + ), f'MSE not matching for example {n}' + + # Compare SDR loss vs average of torch SDR + assert np.isclose(uut_mse_loss, uut_mse.mean(), atol=atol), f'MSELoss not matching for example {n}' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 4]) + @pytest.mark.parametrize('ndim', [3, 4]) + def test_mse_weighted(self, num_channels: int, ndim: int): + """Test SDR calculation with weighting for channels + """ + batch_size = 8 + num_samples = 50 + num_features = 123 + num_batches = 10 + random_seed = 42 + atol = 1e-6 + + signal_shape = ( + (batch_size, num_channels, num_features, num_samples) + if ndim == 4 + else (batch_size, num_channels, num_samples) + ) + + reduction_dim = (-2, -1) if ndim == 4 else -1 + + _rng = np.random.default_rng(seed=random_seed) + + channel_weight = _rng.uniform(low=0.01, high=1.0, size=num_channels) + channel_weight = channel_weight / np.sum(channel_weight) + mse_loss = MSELoss(weight=channel_weight, ndim=ndim) + + for n in range(num_batches): + + # Generate random signal + target = _rng.normal(size=signal_shape) + # Random noise + scaling + noise = _rng.uniform(low=0.001, high=10) * _rng.normal(size=target.shape) + # Estimate + estimate = target + noise + + # Tensors for testing the loss + tensor_estimate = torch.tensor(estimate) + tensor_target = torch.tensor(target) + + # Reference MSE + golden_mse = 0 + for b in range(batch_size): + mse = [ + np.mean(np.abs(estimate[b, m, :] - target[b, m, :]) ** 2, axis=reduction_dim) + for m in range(num_channels) + ] + # weighted sum + mse = np.sum(np.array(mse) * channel_weight) + golden_mse += mse + golden_mse /= batch_size # average over batch + + # Calculate MSE loss + 
uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target) + + # Compare + assert np.allclose( + uut_mse_loss.cpu().detach().numpy(), golden_mse, atol=atol + ), f'MSELoss not matching for example {n}' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 4]) + @pytest.mark.parametrize('ndim', [3, 4]) + def test_mse_input_length(self, num_channels: int, ndim: int): + """Test SDR calculation with input length. + """ + batch_size = 8 + max_num_samples = 50 + num_features = 123 + num_batches = 10 + random_seed = 42 + atol = 1e-6 + + signal_shape = ( + (batch_size, num_channels, num_features, max_num_samples) + if ndim == 4 + else (batch_size, num_channels, max_num_samples) + ) + + reduction_dim = (-2, -1) if ndim == 4 else -1 + + _rng = np.random.default_rng(seed=random_seed) + + mse_loss = MSELoss(ndim=ndim) + + for n in range(num_batches): + + # Generate random signal + target = _rng.normal(size=signal_shape) + # Random noise + scaling + noise = _rng.uniform(low=0.001, high=10) * _rng.normal(size=target.shape) + # Estimate + estimate = target + noise + + # Limit calculation to random input_length samples + input_length = _rng.integers(low=1, high=max_num_samples, size=batch_size) + + # Tensors for testing the loss + tensor_estimate = torch.tensor(estimate) + tensor_target = torch.tensor(target) + tensor_input_length = torch.tensor(input_length) + + # Reference MSE + golden_mse = 0 + for b, b_len in enumerate(input_length): + mse = [ + np.mean(np.abs(estimate[b, m, ..., :b_len] - target[b, m, ..., :b_len]) ** 2, axis=reduction_dim) + for m in range(num_channels) + ] + mse = np.mean(np.array(mse)) + golden_mse += mse + golden_mse /= batch_size # average over batch + + # Calculate MSE + uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target, input_length=tensor_input_length) + + # Compare + assert np.allclose( + uut_mse_loss.cpu().detach().numpy(), golden_mse, atol=atol + ), f'MSELoss not matching for example {n}' diff --git a/tests/collections/asr/test_audio_preprocessing.py b/tests/collections/asr/test_audio_preprocessing.py index b0875936a7f7..600b9fed44fa 100644 --- a/tests/collections/asr/test_audio_preprocessing.py +++ b/tests/collections/asr/test_audio_preprocessing.py @@ -155,7 +155,11 @@ def test_spec_to_audio(self, fft_length: int, num_channels: int): @pytest.mark.skipif(not HAVE_TORCHAUDIO, reason="Modules in this test require torchaudio") @pytest.mark.parametrize('fft_length', [128, 1024]) @pytest.mark.parametrize('num_channels', [1, 4]) - def test_audio_to_spectrogram_reconstruction(self, fft_length: int, num_channels: int): + @pytest.mark.parametrize('magnitude_power', [0.5, 1, 2]) + @pytest.mark.parametrize('scale', [0.1, 1.0]) + def test_audio_to_spectrogram_reconstruction( + self, fft_length: int, num_channels: int, magnitude_power: float, scale: float + ): """Test analysis and synthesis transform result in a perfect reconstruction. 
""" batch_size = 4 @@ -169,8 +173,12 @@ def test_audio_to_spectrogram_reconstruction(self, fft_length: int, num_channels hop_lengths = [fft_length // 2, fft_length // 4] for hop_length in hop_lengths: - audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length) - spec2audio = SpectrogramToAudio(fft_length=fft_length, hop_length=hop_length) + audio2spec = AudioToSpectrogram( + fft_length=fft_length, hop_length=hop_length, magnitude_power=magnitude_power, scale=scale + ) + spec2audio = SpectrogramToAudio( + fft_length=fft_length, hop_length=hop_length, magnitude_power=magnitude_power, scale=scale + ) for n in range(num_examples): x = _rng.normal(size=(batch_size, num_channels, num_samples)) diff --git a/tests/setup/__main__.py b/tests/setup/__main__.py index 289a2537e2f2..a08ccdaa1634 100644 --- a/tests/setup/__main__.py +++ b/tests/setup/__main__.py @@ -34,8 +34,8 @@ ) create_hf_model( - model_name_or_path="/home/TestData/nlp/meta-llama/Llama-2-7b-hf", - output_dir=os.path.join(args.save_dir, "megatron_llama/llama-ci-hf"), + model_name_or_path="/home/TestData/nlp/megatron_llama/llama-ci-hf", + output_dir=os.path.join(args.save_dir, "megatron_llama/llama-ci-hf-tiny"), config_updates={"hidden_size": 256, "num_attention_heads": 4, "num_hidden_layers": 2, "num_key_value_heads": 4}, overwrite=args.overwrite, )