diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index de250596da62..6f090bd34213 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -132,6 +132,9 @@ jobs:
apt-get update && apt-get install libsox-fmt-all -y && \
popd
+ # AMMO installation
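+ # AMMO provides the post-training quantization backend used by the new L2_PTQ_Llama2_* jobs added below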
+ pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir
+
# PyTorch Lightning version
python -c "import pytorch_lightning; print(pytorch_lightning.__version__)"
@@ -220,7 +223,26 @@ jobs:
- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
if: "failure()"
-
+ L0_Setup_Test_Data_And_Models:
+ needs: [cicd-test-container-setup]
+ runs-on: self-hosted-azure
+ container:
+ image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
+ options:
+ # --user 0:128
+ --device=/dev/nvidia0
+ --gpus all
+ --shm-size=8g
+ --env TRANSFORMERS_OFFLINE=0
+ --env HYDRA_FULL_ERROR=1
+ --volume /mnt/datadrive/TestData:/home/TestData
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - run: |
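+ # Assumption: tests.setup prepares the shared test data and models (e.g. the tiny Llama HF checkpoint used by later jobs)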
+ python -m tests.setup --save_dir /home/TestData/nlp
+ - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
+ if: "failure()"
## - name: L2: Multimodal Imagen Train
@@ -243,10 +265,9 @@ jobs:
uses: actions/checkout@v4
- run: |
CUDA_VISIBLE_DEVICES=0 python scripts/checkpoint_converters/convert_llama_hf_to_nemo.py \
- --input_name_or_path=/home/TestData/nlp/megatron_llama/llama-ci-hf \
- --output_path=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
+ --input_name_or_path=/home/TestData/nlp/megatron_llama/llama-ci-hf-tiny \
+ --output_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \
--precision=16
- rm -f /home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo
- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
if: "failure()"
@@ -322,6 +343,124 @@ jobs:
- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
if: "failure()"
+ L2_PTQ_Llama2_Export_Only:
+ needs: [cicd-test-container-setup]
+ runs-on: self-hosted-azure
+ container:
+ image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
+ options:
+ # --user 0:128
+ --device=/dev/nvidia0
+ --gpus all
+ --shm-size=8g
+ --env TRANSFORMERS_OFFLINE=0
+ --env HYDRA_FULL_ERROR=1
+ --volume /mnt/datadrive/TestData:/home/TestData
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - run: |
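+ # quantization.algorithm=null skips quantization, so this job only exercises the baseline export path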
+ python examples/nlp/language_modeling/megatron_llama_quantization.py \
+ model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \
+ quantization.algorithm=null \
+ model_save=/home/TestData/nlp/megatron_llama/ci_baseline
+
+ rm -rf /home/TestData/nlp/megatron_llama/ci_baseline
+ - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
+ if: "failure()"
+
+ L2_PTQ_Llama2_FP8:
+ needs: [cicd-test-container-setup]
+ runs-on: self-hosted-azure
+ container:
+ image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
+ options:
+ # --user 0:128
+ --device=/dev/nvidia0
+ --gpus all
+ --shm-size=8g
+ --env TRANSFORMERS_OFFLINE=0
+ --env HYDRA_FULL_ERROR=1
+ --volume /mnt/datadrive/TestData:/home/TestData
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - run: |
+ python examples/nlp/language_modeling/megatron_llama_quantization.py \
+ model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \
+ tensor_model_parallel_size=2 \
+ trainer.devices=2 \
+ quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \
+ quantization.algorithm=fp8 \
+ quantization.num_calib_size=8 \
+ inference.batch_size=2 \
+ export.inference_tensor_parallel=2 \
+ model_save=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo
+
+ rm -rf /home/TestData/nlp/megatron_llama/ci_fp8.qnemo
+ - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
+ if: "failure()"
+
+ L2_PTQ_Llama2_INT8_SQ:
+ needs: [cicd-test-container-setup]
+ runs-on: self-hosted-azure
+ container:
+ image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
+ options:
+ # --user 0:128
+ --device=/dev/nvidia0
+ --gpus all
+ --shm-size=8g
+ --env TRANSFORMERS_OFFLINE=0
+ --env HYDRA_FULL_ERROR=1
+ --volume /mnt/datadrive/TestData:/home/TestData
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - run: |
+ python examples/nlp/language_modeling/megatron_llama_quantization.py \
+ model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \
+ quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \
+ quantization.algorithm=int8_sq \
+ quantization.num_calib_size=8 \
+ inference.batch_size=2 \
+ model_save=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo
+
+ rm -rf /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo
+ - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
+ if: "failure()"
+
+ L2_PTQ_Llama2_INT4_AWQ:
+ needs: [cicd-test-container-setup]
+ runs-on: self-hosted-azure
+ container:
+ image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
+ options:
+ # --user 0:128
+ --device=/dev/nvidia0
+ --gpus all
+ --shm-size=8g
+ --env TRANSFORMERS_OFFLINE=0
+ --env HYDRA_FULL_ERROR=1
+ --volume /mnt/datadrive/TestData:/home/TestData
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - run: |
+ python examples/nlp/language_modeling/megatron_llama_quantization.py \
+ model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \
+ tensor_model_parallel_size=1 \
+ trainer.devices=1 \
+ quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \
+ quantization.algorithm=int4_awq \
+ quantization.num_calib_size=8 \
+ inference.batch_size=2 \
+ model_save=/home/TestData/nlp/megatron_llama/ci_int4_awq.qnemo
+
+ rm -rf /home/TestData/nlp/megatron_llama/ci_int4_awq.qnemo
+ - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
+ if: "failure()"
+
# L2: ASR dev run
ASR_dev_run_Speech_to_Text:
needs: [cicd-test-container-setup]
@@ -4664,7 +4803,7 @@ jobs:
--volume /mnt/datadrive/TestData:/home/TestData
steps:
- name: Checkout repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- run: |
rm -rf /home/TestData/nlp/megatron_ir/working_dir
diff --git a/README.rst b/README.rst
index 66b3a5806c2d..0b05bd0390f8 100644
--- a/README.rst
+++ b/README.rst
@@ -77,6 +77,31 @@ Latest News
+
+ Speech Recognition
+
+ New Standard for Speech Recognition and Translation from the NVIDIA NeMo Canary Model (2024/04/18)
+
+ The NeMo team just released Canary, a multilingual model that transcribes speech in English, Spanish, German, and French with punctuation and capitalization. Canary also provides bi-directional translation between English and the three other supported languages.
+
+
+
+
+ Pushing the Boundaries of Speech Recognition with NVIDIA NeMo Parakeet ASR Models (2024/04/18)
+
+ NVIDIA NeMo, an end-to-end platform for the development of multimodal generative AI models at scale anywhere—on any cloud and on-premises—released the Parakeet family of automatic speech recognition (ASR) models. These state-of-the-art ASR models, developed in collaboration with Suno.ai, transcribe spoken English with exceptional accuracy.
+
+
+
+
+ Turbocharge ASR Accuracy and Speed with NVIDIA NeMo Parakeet-TDT (2024/04/18)
+
+ NVIDIA NeMo, an end-to-end platform for developing multimodal generative AI models at scale anywhere—on any cloud and on-premises—recently released Parakeet-TDT. This new addition to the NeMo ASR Parakeet model family boasts better accuracy and 64% greater speed than the previous best model, Parakeet-RNNT-1.1B.
+
+
+
+
+
diff --git a/examples/audio_tasks/audio_to_audio_eval.py b/examples/audio_tasks/audio_to_audio_eval.py
index 4ac68dfc84e7..ab6623df298d 100644
--- a/examples/audio_tasks/audio_to_audio_eval.py
+++ b/examples/audio_tasks/audio_to_audio_eval.py
@@ -61,6 +61,7 @@
import json
import os
import tempfile
+from collections import defaultdict
from dataclasses import dataclass, field, is_dataclass
from typing import List, Optional
@@ -101,6 +102,9 @@ class AudioEvaluationConfig(process_audio.ProcessConfig):
# Metrics to calculate
metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi'])
+ # Return metric values for each example
+ return_values_per_example: bool = False
+
def get_evaluation_dataloader(config):
"""Prepare a dataloader for evaluation.
@@ -174,6 +178,9 @@ def main(cfg: AudioEvaluationConfig):
# Setup metrics
metrics = get_metrics(cfg)
+ if cfg.return_values_per_example and cfg.batch_size > 1:
+ raise ValueError('return_values_per_example is only supported for batch_size=1.')
+
# Processing
if not cfg.only_score_manifest:
# Process audio using the configured model and save in the output directory
@@ -236,6 +243,10 @@ def main(cfg: AudioEvaluationConfig):
num_files += 1
+ if cfg.max_utts is not None and num_files >= cfg.max_utts:
+ logging.info('Reached max_utts: %s', cfg.max_utts)
+ break
+
# Prepare dataloader
config = {
'manifest_filepath': temporary_manifest_filepath,
@@ -249,6 +260,8 @@ def main(cfg: AudioEvaluationConfig):
}
temporary_dataloader = get_evaluation_dataloader(config)
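+ # Per-example metric values; populated only when cfg.return_values_per_example is set (requires batch_size=1)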
+ metrics_value_per_example = defaultdict(list)
+
# Calculate metrics
for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'):
processed_signal, processed_length, target_signal, target_length = eval_batch
@@ -257,7 +270,9 @@ def main(cfg: AudioEvaluationConfig):
raise RuntimeError(f'Length mismatch.')
for name, metric in metrics.items():
- metric.update(preds=processed_signal, target=target_signal, input_length=target_length)
+ value = metric(preds=processed_signal, target=target_signal, input_length=target_length)
+ if cfg.return_values_per_example:
+ metrics_value_per_example[name].append(value.item())
# Convert to a dictionary with name: value
metrics_value = {name: metric.compute().item() for name, metric in metrics.items()}
@@ -277,6 +292,7 @@ def main(cfg: AudioEvaluationConfig):
# Inject the metric name and score into the config, and return the entire config
with open_dict(cfg):
cfg.metrics_value = metrics_value
+ cfg.metrics_value_per_example = dict(metrics_value_per_example)
return cfg
diff --git a/examples/audio_tasks/conf/beamforming.yaml b/examples/audio_tasks/conf/beamforming.yaml
index 18e04f0bd12a..3abc4f134e64 100644
--- a/examples/audio_tasks/conf/beamforming.yaml
+++ b/examples/audio_tasks/conf/beamforming.yaml
@@ -44,7 +44,6 @@ model:
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
fft_length: 512 # Length of the window and FFT for calculating spectrogram
hop_length: 256 # Hop length for calculating spectrogram
- power: null
decoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
diff --git a/examples/audio_tasks/conf/masking.yaml b/examples/audio_tasks/conf/masking.yaml
index c667bec53076..68adca116aa5 100644
--- a/examples/audio_tasks/conf/masking.yaml
+++ b/examples/audio_tasks/conf/masking.yaml
@@ -1,5 +1,3 @@
-# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer.
-#
name: "masking"
model:
@@ -44,7 +42,6 @@ model:
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
fft_length: 512 # Length of the window and FFT for calculating spectrogram
hop_length: 256 # Hop length for calculating spectrogram
- power: null
decoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
diff --git a/examples/audio_tasks/conf/predictive.yaml b/examples/audio_tasks/conf/predictive.yaml
new file mode 100644
index 000000000000..b141ba6fd1ee
--- /dev/null
+++ b/examples/audio_tasks/conf/predictive.yaml
@@ -0,0 +1,130 @@
+name: "predictive_model"
+
+model:
+ type: predictive
+ sample_rate: 16000
+ skip_nan_grad: false
+ num_outputs: 1
+ normalize_input: true # normalize the input signal to 0dBFS
+
+ train_ds:
+ manifest_filepath: ???
+ input_key: noisy_filepath
+ target_key: clean_filepath
+ audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration * sample_rate // encoder.hop_length = 256
+ random_offset: true
+ normalization_signal: input_signal
+ batch_size: 8 # batch size may be increased based on the available memory
+ shuffle: true
+ num_workers: 8
+ pin_memory: true
+
+ validation_ds:
+ manifest_filepath: ???
+ input_key: noisy_filepath
+ target_key: clean_filepath
+ batch_size: 8
+ shuffle: false
+ num_workers: 4
+ pin_memory: true
+
+ encoder:
+ _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
+ fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
+ hop_length: 128
+ magnitude_power: 0.5
+ scale: 0.33
+
+ decoder:
+ _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
+ fft_length: ${model.encoder.fft_length}
+ hop_length: ${model.encoder.hop_length}
+ magnitude_power: ${model.encoder.magnitude_power}
+ scale: ${model.encoder.scale}
+
+ estimator:
+ _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus
+ in_channels: 1 # single-channel noisy input
+ out_channels: 1 # single-channel estimate
+ num_res_blocks: 3 # increased number of res blocks
+ pad_time_to: 64 # pad to 64 frames for the time dimension
+ pad_dimension_to: 0 # no padding in the frequency dimension
+
+ loss:
+ _target_: nemo.collections.asr.losses.MSELoss # computed in the time domain
+
+ metrics:
+ val:
+ sisdr: # output SI-SDR
+ _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
+
+ optim:
+ name: adam
+ lr: 1e-4
+ # optimizer arguments
+ betas: [0.9, 0.999]
+ weight_decay: 0.0
+
+trainer:
+ devices: -1 # number of GPUs, -1 would use all available GPUs
+ num_nodes: 1
+ max_epochs: -1
+ max_steps: -1 # computed at runtime if not set
+ val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
+ accelerator: auto
+ strategy: ddp
+ accumulate_grad_batches: 1
+ gradient_clip_val: null
+ precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
+ log_every_n_steps: 25 # Interval of logging.
+ enable_progress_bar: true
+ num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; 0 disables it
+ check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
+ sync_batchnorm: true
+ enable_checkpointing: false # Provided by exp_manager
+ logger: false # Provided by exp_manager
+
+exp_manager:
+ exp_dir: null
+ name: ${name}
+
+ # use exponential moving average for model parameters
+ ema:
+ enable: true
+ decay: 0.999 # decay rate
+ cpu_offload: false # offload EMA parameters to CPU to save GPU memory
+ every_n_steps: 1 # how often to update EMA weights
+ validate_original_weights: false # if true, validate with the original weights instead of the EMA weights
+
+ # logging
+ create_tensorboard_logger: true
+
+ # checkpointing
+ create_checkpoint_callback: true
+ checkpoint_callback_params:
+ # in case of multiple validation sets, first one is used
+ monitor: val_sisdr
+ mode: max
+ save_top_k: 5
+ always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints
+
+ # early stopping
+ create_early_stopping_callback: true
+ early_stopping_callback_params:
+ monitor: val_sisdr
+ mode: max
+ min_delta: 0.0
+ patience: 20 # patience in terms of check_val_every_n_epoch
+ verbose: true
+ strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.
+
+ resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
+ # you need to set these two to true to continue the training
+ resume_if_exists: false
+ resume_ignore_no_checkpoint: false
+
+ # You may use this section to create a W&B logger
+ create_wandb_logger: false
+ wandb_logger_kwargs:
+ name: null
+ project: null
diff --git a/examples/audio_tasks/conf/score_based_generative.yaml b/examples/audio_tasks/conf/score_based_generative.yaml
new file mode 100644
index 000000000000..c0b36bd750a2
--- /dev/null
+++ b/examples/audio_tasks/conf/score_based_generative.yaml
@@ -0,0 +1,149 @@
+name: score_based_generative_model
+
+model:
+ type: score_based
+ sample_rate: 16000
+ skip_nan_grad: false
+ num_outputs: 1
+ normalize_input: true
+ max_utts_evaluation_metrics: 50 # metric calculation requires full inference and is slow, so it is limited to the first few files
+
+ train_ds:
+ manifest_filepath: ???
+ input_key: noisy_filepath
+ target_key: clean_filepath
+ audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration * sample_rate // encoder.hop_length = 256
+ random_offset: true
+ normalization_signal: input_signal
+ batch_size: 8 # batch size may be increased based on the available memory
+ shuffle: true
+ num_workers: 8
+ pin_memory: true
+
+ validation_ds:
+ manifest_filepath: ???
+ input_key: noisy_filepath
+ target_key: clean_filepath
+ normalize_input: false # load data as is for validation; the model will normalize it for inference
+ batch_size: 4
+ shuffle: false
+ num_workers: 4
+ pin_memory: true
+
+ encoder:
+ _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
+ fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
+ hop_length: 128
+ magnitude_power: 0.5
+ scale: 0.33
+
+ decoder:
+ _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
+ fft_length: ${model.encoder.fft_length}
+ hop_length: ${model.encoder.hop_length}
+ magnitude_power: ${model.encoder.magnitude_power}
+ scale: ${model.encoder.scale}
+
+ estimator:
+ _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus
+ in_channels: 2 # concatenation of single-channel perturbed and noisy
+ out_channels: 1 # single-channel score estimate
+ conditioned_on_time: true
+ num_res_blocks: 3 # increased number of res blocks
+ pad_time_to: 64 # pad to 64 frames for the time dimension
+ pad_dimension_to: 0 # no padding in the frequency dimension
+
+ sde:
+ _target_: nemo.collections.asr.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE
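+ # Assumed semantics: stiffness of the OU drift, minimum/maximum noise standard deviation, and number of discretization steps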
+ stiffness: 1.5
+ std_min: 0.05
+ std_max: 0.5
+ num_steps: 1000
+
+ sampler:
+ _target_: nemo.collections.asr.parts.submodules.diffusion.PredictorCorrectorSampler
+ predictor: reverse_diffusion
+ corrector: annealed_langevin_dynamics
+ num_steps: 50
+ num_corrector_steps: 1
+ snr: 0.5
+
+ loss:
+ _target_: nemo.collections.asr.losses.MSELoss
+ ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)
+
+ metrics:
+ val:
+ sisdr: # output SI-SDR
+ _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
+
+ optim:
+ name: adam
+ lr: 1e-4
+ # optimizer arguments
+ betas: [0.9, 0.999]
+ weight_decay: 0.0
+
+trainer:
+ devices: -1 # number of GPUs, -1 would use all available GPUs
+ num_nodes: 1
+ max_epochs: -1
+ max_steps: -1 # computed at runtime if not set
+ val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
+ accelerator: auto
+ strategy: ddp
+ accumulate_grad_batches: 1
+ gradient_clip_val: null
+ precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
+ log_every_n_steps: 25 # Interval of logging.
+ enable_progress_bar: true
+ num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; 0 disables it
+ check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
+ sync_batchnorm: true
+ enable_checkpointing: false # Provided by exp_manager
+ logger: false # Provided by exp_manager
+
+exp_manager:
+ exp_dir: null
+ name: ${name}
+
+ # use exponential moving average for model parameters
+ ema:
+ enable: true
+ decay: 0.999 # decay rate
+ cpu_offload: false # offload EMA parameters to CPU to save GPU memory
+ every_n_steps: 1 # how often to update EMA weights
+ validate_original_weights: false # if true, validate with the original weights instead of the EMA weights
+
+ # logging
+ create_tensorboard_logger: true
+
+ # checkpointing
+ create_checkpoint_callback: true
+ checkpoint_callback_params:
+ # in case of multiple validation sets, first one is used
+ monitor: val_sisdr
+ mode: max
+ save_top_k: 5
+ always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints
+
+ # early stopping
+ create_early_stopping_callback: true
+ early_stopping_callback_params:
+ monitor: val_sisdr
+ mode: max
+ min_delta: 0.0
+ patience: 20 # patience in terms of check_val_every_n_epoch
+ verbose: true
+ strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.
+
+ resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
+ # you need to set these two to true to continue the training
+ resume_if_exists: false
+ resume_ignore_no_checkpoint: false
+
+ # You may use this section to create a W&B logger
+ create_wandb_logger: false
+ wandb_logger_kwargs:
+ name: null
+ project: null
diff --git a/examples/audio_tasks/speech_enhancement.py b/examples/audio_tasks/speech_enhancement.py
index 250d212d2a25..33a25c1c107c 100644
--- a/examples/audio_tasks/speech_enhancement.py
+++ b/examples/audio_tasks/speech_enhancement.py
@@ -26,25 +26,64 @@
PyTorch Lightning Trainer arguments and args of the model and the optimizer can be added or overridden from CLI
"""
+from enum import Enum
+
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
-from nemo.collections.asr.models import EncMaskDecAudioToAudioModel
+from nemo.collections.asr.models.enhancement_models import (
+ EncMaskDecAudioToAudioModel,
+ PredictiveAudioToAudioModel,
+ ScoreBasedGenerativeAudioToAudioModel,
+)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
+class ModelType(str, Enum):
+ """Enumeration with the available model types.
+ """
+
+ MaskBased = 'mask_based'
+ Predictive = 'predictive'
+ ScoreBased = 'score_based'
+
+
+def get_model_class(model_type: ModelType):
+ """Get model class for a given model type.
+ """
+ if model_type == ModelType.MaskBased:
+ return EncMaskDecAudioToAudioModel
+ elif model_type == ModelType.Predictive:
+ return PredictiveAudioToAudioModel
+ elif model_type == ModelType.ScoreBased:
+ return ScoreBasedGenerativeAudioToAudioModel
+ else:
+ raise ValueError(f'Unknown model type: {model_type}')
+
+
@hydra_runner(config_path="./conf", config_name="masking")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}')
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
- model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer)
- # Initialize the weights of the model from another model, if provided via config
+ # Get model class
+ model_type = cfg.model.get('type')
+ if model_type is None:
+ model_type = ModelType.MaskBased
+ logging.warning('model_type not found in config. Using default: %s', model_type)
+
+ logging.info('Get class for model type: %s', model_type)
+ model_class = get_model_class(model_type)
+
+ logging.info('Instantiate model %s', model_class.__name__)
+ model = model_class(cfg=cfg.model, trainer=trainer)
+
+ logging.info('Initialize the weights of the model from another model, if provided via config')
model.maybe_init_from_pretrained_checkpoint(cfg)
# Train the model
diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml
index 8ce009d5458f..dff963590864 100644
--- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml
+++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml
@@ -49,8 +49,8 @@ model:
precision: ${trainer.precision}
# specify micro_batch_size, global_batch_size, and model parallelism
# gradient accumulation will be done automatically based on data_parallel_size
- micro_batch_size: 1 # limited by GPU memory
- global_batch_size: 1 # will use more micro batches to reach global batch size
+ micro_batch_size: 16 # limited by GPU memory
+ global_batch_size: 16 # will use more micro batches to reach global batch size
native_amp_init_scale: 65536.0 # Init scale for grad scaler used at fp16
@@ -97,15 +97,15 @@ model:
unet_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel
from_pretrained: #/ckpts/nemo-v1-2.ckpt
- from_NeMo: True #Must be specified when from pretrained is not None, False means loading unet from HF ckpt
+ from_NeMo: False # Must be specified when from_pretrained is not None; False means the UNet is loaded from a HF checkpoint
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions:
- - 4
- - 2
- - 1
+ - 4
+ - 2
+ - 1
num_res_blocks: 2
channel_mult:
- 1
@@ -121,6 +121,7 @@ model:
use_flash_attention: True
unet_precision: fp32
resblock_gn_groups: 32
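+ # Assumption: when True, the UNet runs with Transformer Engine FP8 kernels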
+ use_te_fp8: False
first_stage_config:
_target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL
@@ -140,22 +141,22 @@ model:
- 4
- 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
- _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder
- restore_from_path: /ckpts/openai.nemo
+ _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
+ version: openai/clip-vit-large-patch14
device: cuda
- freeze: True
- layer: "last"
- # For compatibility of history version that uses HF clip model
- # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
- # version: openai/clip-vit-large-patch14
- # device: cuda
- # max_length: 77
+ max_length: 77
+ # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder
+ # restore_from_path: /ckpts/openai-old.nemo
+ # device: cuda
+ # freeze: True
+ # layer: "last"
+
# miscellaneous
@@ -163,7 +164,7 @@ model:
resume_from_checkpoint: null # manually set the checkpoint file to load from
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
- ddp_overlap: True # True for using PyTorch DDP overlap.
+ ddp_overlap: False # True for using PyTorch DDP overlap.
optim:
name: fused_adam
@@ -191,7 +192,7 @@ model:
synthetic_data_length: 10000
train:
dataset_path:
- - /datasets/coyo/test.pkl
+ - /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl
augmentations:
resize_smallest_side: 512
center_crop_h_w: 512, 512
diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py
index f1e5e2872ea7..58e9e6e64470 100644
--- a/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py
+++ b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py
@@ -28,6 +28,9 @@ def model_cfg_modifier(model_cfg):
model_cfg.unet_config.use_flash_attention = False
model_cfg.unet_config.from_pretrained = None
model_cfg.first_stage_config.from_pretrained = None
+ model_cfg.first_stage_config._target_ = (
+ 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL'
+ )
torch.backends.cuda.matmul.allow_tf32 = True
trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference(
diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml
new file mode 100644
index 000000000000..b37d64a325e5
--- /dev/null
+++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml
@@ -0,0 +1,204 @@
+# An example model that works with this config is "https://huggingface.co/yuvalkirstain/PickScore_v1"
+model:
+ precision: 32
+ # specify micro_batch_size, global_batch_size, and model parallelism
+ # gradient accumulation will be done automatically based on data_parallel_size
+ micro_batch_size: 32 # limited by GPU memory
+ global_batch_size: 32 # will use more micro batches to reach global batch size
+ tensor_model_parallel_size: 1 # intra-layer model parallelism
+ pipeline_model_parallel_size: 1 # inter-layer model parallelism
+ virtual_pipeline_model_parallel_size: null # interleaved pipeline
+
+ restore_from_pretrained: null # used in fine-tuning
+ # multimodal configs
+ output_dim: 1024
+ # As the number of devices used to train increases, so does the space complexity of
+ # the logit matrix. Using a naïve all-gather scheme, space complexity will be
+ # `O(n^2)`. Instead, complexity may become effectively linear if the flags
+ # `--gather-with-grad` and `--local-loss` are used. This alteration yields numerically
+ # identical results to the naïve method.
+ local_loss: False # calculate loss w/ local features @ global (instead of realizing full global @ global matrix)
+ gather_with_grad: True # enable full distributed gradient for feature gather, set this to False may cause convergence issue
+
+ vision:
+ precision: 32
+ # vision configs
+ patch_dim: 14
+ img_h: 224
+ img_w: 224
+ image_mean: null
+ image_std: null
+ num_channels: 3
+ drop_patch_rate: 0.0
+ drop_path_rate: 0.0
+ global_average_pool: False
+ output_dim: ${model.output_dim}
+ class_token_length: 1
+ preprocess_layernorm: True # apply layer norm to embedded tokens
+
+ # model architecture
+ encoder_seq_length: 196
+ max_position_embeddings: ${.encoder_seq_length}
+ position_embedding_type: learned_parameters
+ num_layers: 32
+ hidden_size: 1280
+ ffn_hidden_size: 5120 # Transformer FFN hidden size. Usually 4 * hidden_size.
+ num_attention_heads: 16
+ init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.
+ use_scaled_init_method: True # use scaled residuals initialization
+ hidden_dropout: 0. # Dropout probability for hidden state transformer.
+ attention_dropout: 0.
+ kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
+ apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
+ normalization: layernorm # Type of normalization layers
+ layernorm_epsilon: 1e-5
+ do_layer_norm_weight_decay: False # True means weight decay on all params
+ pre_process: True # add embedding
+ post_process: True # add pooler
+ persist_layer_norm: True # Use of persistent fused layer norm kernel.
+
+ ## Activation Checkpointing
+ activations_checkpoint_granularity: null # 'selective' or 'full'
+ activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
+ activations_checkpoint_num_layers: null # not used with 'selective'
+ sequence_parallel: False
+
+ # precision
+ native_amp_init_scale: 4294967296 # 2 ** 32
+ native_amp_growth_interval: 1000
+ hysteresis: 2 # Gradient scale hysteresis
+ fp32_residual_connection: False # Move residual connections to fp32
+ fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
+
+ # model fusions
+ masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with its mask.
+ bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
+
+ use_cpu_initialization: False # Init weights on the CPU (slow for large models)
+ onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
+ gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism.
+ openai_gelu: False
+ bias_activation_fusion: False
+ megatron_legacy: True
+ activation: gelu
+
+
+
+ text:
+ precision: 32
+ # text configs
+ output_dim: ${model.output_dim}
+
+ # model architecture
+ encoder_seq_length: 77
+ max_position_embeddings: ${.encoder_seq_length}
+ position_embedding_type: learned_parameters
+ num_layers: 24
+ hidden_size: 1024
+ ffn_hidden_size: 4096 # Transformer FFN hidden size. Usually 4 * hidden_size.
+ num_attention_heads: 16
+ init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.
+ use_scaled_init_method: True # use scaled residuals initialization
+ hidden_dropout: 0. # Dropout probability for hidden state transformer.
+ attention_dropout: 0.
+ kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
+ apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
+ normalization: layernorm # Type of normalization layers
+ layernorm_epsilon: 1e-5
+ do_layer_norm_weight_decay: False # True means weight decay on all params
+ pre_process: True # add embedding
+ post_process: True # add pooler
+ persist_layer_norm: True # Use of persistent fused layer norm kernel.
+
+ ## Activation Checkpointing
+ activations_checkpoint_granularity: null # 'selective' or 'full'
+ activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
+ activations_checkpoint_num_layers: null # not used with 'selective'
+ num_micro_batches_with_partial_activation_checkpoints: null
+ activations_checkpoint_layers_per_pipeline: null
+ sequence_parallel: False
+
+ # precision
+ native_amp_init_scale: 4294967296 # 2 ** 32
+ native_amp_growth_interval: 1000
+ hysteresis: 2 # Gradient scale hysteresis
+ fp32_residual_connection: False # Move residual connections to fp32
+ fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
+
+ # model fusions
+ masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with its mask.
+ bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
+
+ use_cpu_initialization: False # Init weights on the CPU (slow for large models)
+ onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
+ gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism.
+ openai_gelu: False
+ bias_activation_fusion: False
+ megatron_legacy: True
+
+ transformer_engine: False
+ fp8: False # enables fp8 in TransformerLayer forward
+ fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3
+ fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID
+ fp8_margin: 0 # scaling margin
+ fp8_interval: 1 # scaling update interval
+ fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor
+ fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
+ use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.
+ activation: gelu
+
+ # Megatron O2-style half-precision
+ megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
+ grad_allreduce_chunk_size_mb: 125
+ grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce
+
+ # miscellaneous
+ seed: 1234
+ resume_from_checkpoint: null # manually set the checkpoint file to load from
+ apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
+ gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
+
+ tokenizer:
+ library: 'huggingface'
+ type: 'openai/clip-vit-large-patch14'
+ model: null
+ vocab_file: null
+ merge_file: null
+ delimiter: null # only used for tabular tokenizer
+ sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers.
+ make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
+
+ data:
+ num_workers: 8
+ train:
+ dataset_path: # List of paths to pkl files or tar files
+ - /datasets/coyo/test.pkl
+ validation: # List of paths to pkl files or tar files
+ dataset_path:
+ - /datasets/coyo/test.pkl
+ webdataset:
+ infinite_sampler: False
+ local_root_path: /datasets/coyo
+
+ imagenet_val: null # Path to imagenet val set for conducting zero shot evaluation.
+
+ # Nsys profiling options
+ nsys_profile:
+ enabled: False
+ start_step: 10 # Global batch to start profiling
+ end_step: 10 # Global batch to end profiling
+ ranks: [ 0 ] # Global rank IDs to profile
+ gen_shape: False # Generate model and kernel details including input shapes
+
+ optim:
+ name: fused_adam
+ lr: 1e-3
+ weight_decay: 0.2
+ betas:
+ - 0.9
+ - 0.98
+ sched:
+ name: CosineAnnealing
+ warmup_steps: 2000
+ constant_steps: 0
+ min_lr: 1e-5
\ No newline at end of file
diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml
index 40347f317fbb..6517b62010b4 100644
--- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml
+++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml
@@ -101,6 +101,7 @@ model:
position_embedding_strategy: null # used only when weight_tying is True
lora_tuning:
+ variant: "nemo" # can be "nemo" or "canonical"
target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2)
adapter_dim: 32
alpha: ${model.peft.lora_tuning.adapter_dim}
diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml
index 67d43eb303f4..592eed6c4420 100644
--- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml
+++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml
@@ -89,6 +89,7 @@ model:
position_embedding_strategy: null # used only when weight_tying is True
lora_tuning:
+ variant: "nemo" # can be either "canonical" or "nemo"
target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2)
adapter_dim: 32
adapter_dropout: 0.0
diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/asr/data/audio_to_audio.py
index a3c6dd0cc1b3..4f4727239a4b 100644
--- a/nemo/collections/asr/data/audio_to_audio.py
+++ b/nemo/collections/asr/data/audio_to_audio.py
@@ -130,13 +130,19 @@ class ASRAudioProcessor:
sample_rate: sample rate used for all audio signals
random_offset: If `True`, offset will be randomized when loading a subsegment
from a file.
+ normalization_signal: Normalize all audio with a factor that ensures the signal
+ `example[normalization_signal]` in `process` is in range [-1, 1].
+ All other audio signals are scaled by the same factor. Default is
+ `None`, corresponding to no normalization.
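+ eps: Small positive value added to the normalization scale to avoid division by zero.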
"""
def __init__(
- self, sample_rate: float, random_offset: bool,
+ self, sample_rate: float, random_offset: bool, normalization_signal: Optional[str] = None, eps: float = 1e-8,
):
self.sample_rate = sample_rate
self.random_offset = random_offset
+ self.normalization_signal = normalization_signal
+ self.eps = eps
self.sync_setup = None
self.async_setup = None
@@ -314,7 +320,20 @@ def process_audio(self, audio: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso
Returns:
An ordered dictionary of signals and their tensors.
"""
- # Currently, not doing any processing of the loaded signals.
+ if self.normalization_signal:
+ # Normalize all audio with a factor that ensures the normalization signal is in range [-1, 1].
+ norm_scale = audio[self.normalization_signal].abs().max()
+
+ # Do not normalize embeddings
+ skip_signals = self.embedding_setup.signals if self.embedding_setup is not None else []
+
+ # Normalize audio signals
+ for signal in audio:
+ if signal not in skip_signals:
+ # All audio signals are scaled by the same factor.
+ # This ensures that the relative level between signals is preserved.
+ audio[signal] = audio[signal] / (norm_scale + self.eps)
+
return audio
def load_sync_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
@@ -812,6 +831,9 @@ class AudioToTargetDataset(BaseAudioDataset):
If `None`, all channels will be loaded.
target_channel_selector: Optional, select subset of channels from each input audio file.
If `None`, all channels will be loaded.
+ normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1].
+ All audio signals are scaled by the same factor. Supported values are `None` (no normalization),
+ 'input_signal', 'target_signal'.
"""
def __init__(
@@ -827,6 +849,7 @@ def __init__(
max_utts: Optional[int] = None,
input_channel_selector: Optional[int] = None,
target_channel_selector: Optional[int] = None,
+ normalization_signal: Optional[str] = None,
):
audio_to_manifest_key = {
'input_signal': input_key,
@@ -841,7 +864,9 @@ def __init__(
max_number=max_utts,
)
- audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
+ audio_processor = ASRAudioProcessor(
+ sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal,
+ )
audio_processor.sync_setup = SignalSetup(
signals=['input_signal', 'target_signal'],
duration=audio_duration,
@@ -932,6 +957,9 @@ class AudioToTargetWithReferenceDataset(BaseAudioDataset):
from input and target.
reference_duration: Optional, can be used to set a fixed duration of the reference utterance. If `None`,
complete audio file will be loaded.
+ normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1].
+ All audio signals are scaled by the same factor. Supported values are `None` (no normalization),
+ 'input_signal', 'target_signal', 'reference_signal'.
"""
def __init__(
@@ -951,6 +979,7 @@ def __init__(
reference_channel_selector: Optional[int] = None,
reference_is_synchronized: bool = True,
reference_duration: Optional[float] = None,
+ normalization_signal: Optional[str] = None,
):
audio_to_manifest_key = {
'input_signal': input_key,
@@ -966,7 +995,9 @@ def __init__(
max_number=max_utts,
)
- audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
+ audio_processor = ASRAudioProcessor(
+ sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal,
+ )
if reference_is_synchronized:
audio_processor.sync_setup = SignalSetup(
@@ -1063,6 +1094,9 @@ class AudioToTargetWithEmbeddingDataset(BaseAudioDataset):
If `None`, all channels will be loaded.
target_channel_selector: Optional, select subset of channels from each input audio file.
If `None`, all channels will be loaded.
+ normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1].
+ All audio signals are scaled by the same factor. Supported values are `None` (no normalization),
+ 'input_signal', 'target_signal'.
"""
def __init__(
@@ -1079,6 +1113,7 @@ def __init__(
max_utts: Optional[int] = None,
input_channel_selector: Optional[int] = None,
target_channel_selector: Optional[int] = None,
+ normalization_signal: Optional[str] = None,
):
audio_to_manifest_key = {
'input_signal': input_key,
@@ -1094,7 +1129,9 @@ def __init__(
max_number=max_utts,
)
- audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
+ audio_processor = ASRAudioProcessor(
+ sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal,
+ )
audio_processor.sync_setup = SignalSetup(
signals=['input_signal', 'target_signal'],
duration=audio_duration,
diff --git a/nemo/collections/asr/data/audio_to_audio_dataset.py b/nemo/collections/asr/data/audio_to_audio_dataset.py
index b296d64b1f2a..46e47020fda0 100644
--- a/nemo/collections/asr/data/audio_to_audio_dataset.py
+++ b/nemo/collections/asr/data/audio_to_audio_dataset.py
@@ -36,6 +36,7 @@ def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDat
max_utts=config.get('max_utts', 0),
input_channel_selector=config.get('input_channel_selector', None),
target_channel_selector=config.get('target_channel_selector', None),
+ normalization_signal=config.get('normalization_signal', None),
)
return dataset
@@ -65,6 +66,7 @@ def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.A
reference_channel_selector=config.get('reference_channel_selector', None),
reference_is_synchronized=config.get('reference_is_synchronized', True),
reference_duration=config.get('reference_duration', None),
+ normalization_signal=config.get('normalization_signal', None),
)
return dataset
@@ -91,5 +93,6 @@ def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.A
max_utts=config.get('max_utts', 0),
input_channel_selector=config.get('input_channel_selector', None),
target_channel_selector=config.get('target_channel_selector', None),
+ normalization_signal=config.get('normalization_signal', None),
)
return dataset
diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py
index 3e50cea1d692..c03f7a48ffe3 100644
--- a/nemo/collections/asr/losses/__init__.py
+++ b/nemo/collections/asr/losses/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
-from nemo.collections.asr.losses.audio_losses import SDRLoss
+from nemo.collections.asr.losses.audio_losses import MSELoss, SDRLoss
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.losses.lattice_losses import LatticeLoss
from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss
diff --git a/nemo/collections/asr/losses/audio_losses.py b/nemo/collections/asr/losses/audio_losses.py
index 62ce4a9f7edd..b0214375a713 100644
--- a/nemo/collections/asr/losses/audio_losses.py
+++ b/nemo/collections/asr/losses/audio_losses.py
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -21,31 +21,33 @@
from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.collections.asr.parts.utils.audio_utils import toeplitz
from nemo.core.classes import Loss, Typing, typecheck
-from nemo.core.neural_types import AudioSignal, LengthsType, LossType, MaskType, NeuralType
+from nemo.core.neural_types import AudioSignal, LengthsType, LossType, MaskType, NeuralType, VoidType
from nemo.utils import logging
-__all__ = ['SDRLoss']
+__all__ = ['SDRLoss', 'MSELoss']
-def temporal_mean(
+def calculate_mean(
input: torch.Tensor,
input_length: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
+ dim: Union[int, Tuple[int]] = -1,
keepdim: bool = False,
eps: float = 1e-10,
) -> torch.Tensor:
- """Calculate mean along temporal dimension with optionally
+ """Calculate mean along dimension `dim` with optionally
averaging only over valid samples (based on the input length).
Args:
- input: Batch of signals, shape (B, C, T)
+ input: Batch of signals, for example of shape (B, C, T) or (B, C, D, T)
input_length: Optional, length of each example in the batch, shape (B,)
- mask: Optional, temporal mask for each example in the batch, shape (B, T)
+ mask: Optional, temporal mask for each example in the batch, same shape as the input signal
+ dim: dimension or dimensions to reduce
keepdim: Whether to keep the temporal dimension
eps: Regularization to avoid division by zero
Returns:
- (B, C, 1) if keepdim=True, otherwise (B, C)
+ Mean over dimensions `dim`.
"""
if input_length is not None:
if mask is not None:
@@ -53,17 +55,18 @@ def temporal_mean(
'Argument `input_length` is mutually exclusive with `mask`. Both cannot be used at the same time.'
)
# Construct a binary mask
- mask = make_seq_mask_like(lengths=input_length, like=input, time_dim=-1, valid_ones=True).squeeze(1)
+ mask = make_seq_mask_like(lengths=input_length, like=input, time_dim=-1, valid_ones=True)
+ mask = mask.expand_as(input)
if mask is None:
# No length information, assume all samples are valid
- mean = torch.mean(input, dim=-1, keepdim=keepdim)
+ mean = torch.mean(input, dim=dim, keepdim=keepdim)
else:
# Average using temporal mask
- mean = mask.unsqueeze(1) * input
- mean = torch.sum(mean, axis=-1, keepdim=keepdim)
- normalization = torch.sum(mask, axis=-1, keepdim=keepdim)
- mean = mean / (normalization.unsqueeze(1) + eps)
+ mean = mask * input
+ mean = torch.sum(mean, dim=dim, keepdim=keepdim)
+ normalization = torch.sum(mask, dim=dim, keepdim=keepdim)
+ mean = mean / (normalization + eps)
return mean
@@ -101,16 +104,17 @@ def scale_invariant_target(
)
# Construct a binary mask
- mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1)
+ mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True)
+ mask = mask.expand_as(estimate)
- estimate_dot_target = temporal_mean(estimate * target, mask=mask, keepdim=True, eps=eps)
- target_pow = temporal_mean(torch.abs(target) ** 2, mask=mask, keepdim=True, eps=eps)
+ estimate_dot_target = calculate_mean(estimate * target, mask=mask, dim=-1, keepdim=True, eps=eps)
+ target_pow = calculate_mean(torch.abs(target) ** 2, mask=mask, dim=-1, keepdim=True, eps=eps)
scale = estimate_dot_target / (target_pow + eps)
target_scaled = scale * target
# Mask to keep only the valid samples
if mask is not None:
- target_scaled = mask.unsqueeze(1) * target_scaled
+ target_scaled = mask * target_scaled
return target_scaled
@@ -162,12 +166,13 @@ def convolution_invariant_target(
)
# Construct a binary mask
- mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1)
+ mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True)
+ mask = mask.expand_as(estimate)
# Apply a mask, if available
if mask is not None:
- estimate = mask.unsqueeze(1) * estimate
- target = mask.unsqueeze(1) * target
+ estimate = mask * estimate
+ target = mask * target
# Calculate filtered target
input_shape = estimate.shape
@@ -207,7 +212,7 @@ def convolution_invariant_target(
# Mask to keep only the valid samples
if mask is not None:
- target_filt = mask.unsqueeze(1) * target_filt
+ target_filt = mask * target_filt
return target_filt
@@ -261,11 +266,12 @@ def calculate_sdr_batch(
)
# Construct a binary mask
- mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1)
+ mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True)
+ mask = mask.expand_as(estimate)
if remove_mean:
- estimate = estimate - temporal_mean(estimate, mask=mask, keepdim=True, eps=eps)
- target = target - temporal_mean(target, mask=mask, keepdim=True, eps=eps)
+ estimate = estimate - calculate_mean(estimate, mask=mask, dim=-1, keepdim=True, eps=eps)
+ target = target - calculate_mean(target, mask=mask, dim=-1, keepdim=True, eps=eps)
if scale_invariant or (convolution_invariant and convolution_filter_length == 1):
target = scale_invariant_target(estimate=estimate, target=target, mask=mask, eps=eps)
@@ -276,8 +282,8 @@ def calculate_sdr_batch(
distortion = estimate - target
- target_pow = temporal_mean(torch.abs(target) ** 2, mask=mask, eps=eps)
- distortion_pow = temporal_mean(torch.abs(distortion) ** 2, mask=mask, eps=eps)
+ target_pow = calculate_mean(torch.abs(target) ** 2, mask=mask, dim=-1, eps=eps)
+ distortion_pow = calculate_mean(torch.abs(distortion) ** 2, mask=mask, dim=-1, eps=eps)
if sdr_max is not None:
distortion_pow = distortion_pow + 10 ** (-sdr_max / 10) * target_pow
@@ -353,7 +359,7 @@ def input_types(self):
"estimate": NeuralType(signal_shape, AudioSignal()),
"target": NeuralType(signal_shape, AudioSignal()),
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
- "mask": NeuralType(('B', 'T'), MaskType(), optional=True),
+ "mask": NeuralType(('B', 'C', 'T'), MaskType(), optional=True),
}
@property
@@ -376,10 +382,10 @@ def forward(
perform averaging across channels (weighting optional), and apply reduction across the batch.
Args:
- estimate: Batch of signals, shape (B, T, C)
- target: Batch of signals, shape (B, T, C)
+ estimate: Batch of signals, shape (B, C, T)
+ target: Batch of signals, shape (B, C, T)
input_length: Batch of lengths, shape (B,)
- mask: Batch of temporal masks, shape (B, T)
+ mask: Batch of temporal masks for each channel, shape (B, C, T)
Returns:
Scalar loss.
@@ -410,3 +416,161 @@ def forward(
sdr = self.reduce(sdr)
return -sdr
+
+
+def calculate_mse_batch(
+ estimate: torch.Tensor,
+ target: torch.Tensor,
+ input_length: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """Calculate MSE per channel.
+
+ MSE = ||estimate - target||_2^2 / input_length
+
+ Args:
+ estimate: estimated signal, shape (B, C, T) or (B, C, D, T)
+ target: target signal, shape (B, C, T) or (B, C, D, T)
+ input_length: Optional, length of valid samples, shape (B,)
+ mask: Optional, temporal mask, same shape as signals
+
+ Returns:
+ MSE for each channel, shape (B, C)
+ """
+ assert (
+ estimate.shape == target.shape
+ ), f'Estimate shape ({estimate.shape}) not matching target shape ({target.shape})'
+
+ if input_length is not None:
+ if mask is not None:
+ raise RuntimeError(
+ 'Argument `input_length` is mutually exclusive with `mask`. Both cannot be used at the same time.'
+ )
+
+ # Construct a binary mask
+ mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True)
+ mask = mask.expand_as(estimate)
+
+ # error
+ err = estimate - target
+
+ # dimensions for averaging
+ if estimate.ndim == 3:
+ # average across time
+ dim = -1
+ elif estimate.ndim == 4:
+ # average across time and features
+ dim = (-2, -1)
+ else:
+ raise RuntimeError(f'Unexpected dimension of the input: {estimate.shape}')
+
+ # calculate masked mean
+ mse = calculate_mean(torch.abs(err) ** 2, mask=mask, dim=dim)
+
+ return mse
+
+
+class MSELoss(Loss, Typing):
+ """
+ Computes MSE loss with weighted average across channels.
+
+ Args:
+ weight: weight for loss of each output channel, used for averaging the loss across channels. Defaults to `None` (equal weight for each channel).
+ reduction: batch reduction. Defaults to `mean` over the batch.
+ ndim: Number of dimensions for the input signal
+ """
+
+ def __init__(
+ self, weight: Optional[List[float]] = None, reduction: str = 'mean', ndim: int = 3,
+ ):
+ super().__init__()
+
+ # weight buffer
+ if weight is not None:
+ if any([w <= 0 for w in weight]):
+ raise ValueError(f'Weight must be positive! Current value: {weight}')
+ elif not np.isclose(sum(weight), 1, atol=1e-6):
+ raise ValueError(f'Weight should add to one, current weight: {weight}')
+ weight = torch.tensor(weight).reshape(1, -1)
+ logging.info('Channel weight set to %s', weight)
+ self.register_buffer('weight', weight)
+ self.weight: Optional[torch.Tensor]
+
+ # Batch reduction
+ self.reduction = reduction
+ if reduction == 'mean':
+ self.reduce = torch.mean
+ else:
+ raise ValueError(f'Unexpected reduction mode {reduction}.')
+
+ # Input dimension
+ self.ndim = ndim
+
+ if self.ndim == 3:
+ # Time-domain input
+ self.signal_shape = ('B', 'C', 'T')
+ elif self.ndim == 4:
+ # Spectral-domain input
+ self.signal_shape = ('B', 'C', 'D', 'T')
+ else:
+ raise ValueError(f'Unexpected input dimension: {self.ndim}')
+
+ logging.debug('Initialized %s with', self.__class__.__name__)
+ logging.debug('\tweight: %s', self.weight)
+ logging.debug('\treduction: %s', self.reduction)
+ logging.debug('\tndim: %s', self.ndim)
+ logging.debug('\tsignal_shape: %s', self.signal_shape)
+
+ @property
+ def input_types(self):
+ """Input types definitions for MSELoss.
+ """
+ return {
+ "estimate": NeuralType(self.signal_shape, VoidType()),
+ "target": NeuralType(self.signal_shape, VoidType()),
+ "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ "mask": NeuralType(self.signal_shape, MaskType(), optional=True),
+ }
+
+ @property
+ def output_types(self):
+ """Output types definitions for MSELoss.
+ loss:
+ NeuralType(None)
+ """
+ return {"loss": NeuralType(elements_type=LossType())}
+
+ @typecheck()
+ def forward(
+ self,
+ estimate: torch.Tensor,
+ target: torch.Tensor,
+ input_length: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """For input batch of multi-channel signals, calculate MSE between estimate and target for each channel,
+ perform averaging across channels (weighting optional), and apply reduction across the batch.
+
+ Args:
+ estimate: Estimate of the target signal
+ target: Target signal
+ input_length: Length of each example in the batch
+ mask: Mask for each signal
+
+ Returns:
+ Scalar loss.
+ """
+ mse = calculate_mse_batch(estimate=estimate, target=target, input_length=input_length, mask=mask,)
+
+ # channel averaging
+ if self.weight is None:
+ mse = torch.mean(mse, dim=1)
+ else:
+ # weighting across channels
+ mse = mse * self.weight
+ mse = torch.sum(mse, dim=1)
+
+ # reduction
+ mse = self.reduce(mse)
+
+ return mse
diff --git a/nemo/collections/asr/metrics/audio.py b/nemo/collections/asr/metrics/audio.py
index 5e8c2915e3fa..db63ac19c098 100644
--- a/nemo/collections/asr/metrics/audio.py
+++ b/nemo/collections/asr/metrics/audio.py
@@ -57,6 +57,7 @@ class AudioMetricWrapper(Metric):
"""
full_state_update: bool = False
+ num_examples: torch.Tensor
def __init__(
self, metric: Metric, channel: Optional[int] = None, metric_using_batch_averaging: Optional[bool] = None
@@ -74,6 +75,7 @@ def __init__(
self._metric = metric
self._channel = channel
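+        # Track the number of examples accumulated in this metric state. Models can read this,
+        # e.g., to limit how many utterances are used for metric evaluation.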
+ self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum')
logging.debug('Setup metric %s, channel %s', metric, str(channel))
def _select_channel(self, preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -144,6 +146,8 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, input_length: Option
for b_preds, b_target in self._trim_inputs(preds=preds, target=target, input_length=input_length):
self._metric.update(preds=b_preds, target=b_target)
+ self.num_examples += preds.size(0)
+
def compute(self) -> torch.Tensor:
"""Compute the underlying metric.
"""
@@ -179,6 +183,9 @@ def forward(
def reset(self) -> None:
"""Reset the underlying metric.
"""
+ # reset the internal states
+ super().reset()
+ # reset the underlying metric
self._metric.reset()
def __repr__(self) -> str:
diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py
index 019c57f9c4e3..23c759afc80d 100644
--- a/nemo/collections/asr/models/__init__.py
+++ b/nemo/collections/asr/models/__init__.py
@@ -23,7 +23,11 @@
from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
-from nemo.collections.asr.models.enhancement_models import EncMaskDecAudioToAudioModel
+from nemo.collections.asr.models.enhancement_models import (
+ EncMaskDecAudioToAudioModel,
+ PredictiveAudioToAudioModel,
+ ScoreBasedGenerativeAudioToAudioModel,
+)
from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
from nemo.collections.asr.models.k2_sequence_models import (
diff --git a/nemo/collections/asr/models/audio_to_audio_model.py b/nemo/collections/asr/models/audio_to_audio_model.py
index 49364843e8b8..094dbc38b72a 100644
--- a/nemo/collections/asr/models/audio_to_audio_model.py
+++ b/nemo/collections/asr/models/audio_to_audio_model.py
@@ -12,15 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+import os
+import tempfile
from abc import ABC, abstractmethod
-from typing import List, Union
+from typing import Dict, List, Optional, Union
import hydra
+import librosa
+import soundfile as sf
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
+from tqdm import tqdm
+from nemo.collections.asr.data import audio_to_audio_dataset
+from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
+from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config
from nemo.collections.asr.metrics.audio import AudioMetricWrapper
+from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
+from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes import ModelPT
from nemo.utils import logging, model_utils
@@ -158,23 +169,384 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'test')
- @abstractmethod
+ @torch.no_grad()
def process(
- self, paths2audio_files: List[str], output_dir: str, batch_size: int = 4
- ) -> List[Union[str, List[str]]]:
+ self,
+ paths2audio_files: List[str],
+ output_dir: str,
+ batch_size: int = 1,
+ num_workers: Optional[int] = None,
+ input_channel_selector: Optional[ChannelSelectorType] = None,
+ ) -> List[str]:
+ """
+ Process audio files provided in paths2audio_files.
+ Processed signals will be saved in output_dir.
+
+ Args:
+ paths2audio_files: (a list) of paths to audio files. \
+ Recommended length per file is between 5 and 25 seconds. \
+ But it is possible to pass a few hours long file if enough GPU memory is available.
+            output_dir: Output directory for saving the processed files.
+ batch_size: (int) batch size to use during inference.
+                A larger batch size improves throughput but uses more memory.
+ num_workers: Number of workers for the dataloader
+ input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.
+
+        Returns:
+            Paths to the processed audio files, saved in `output_dir`.
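+
+        Example:
+            A minimal sketch with hypothetical file paths, where `model` is an instance of this class:
+
+                processed_paths = model.process(
+                    paths2audio_files=['/data/noisy_001.wav'], output_dir='/data/enhanced'
+                )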
+ """
+ if paths2audio_files is None or len(paths2audio_files) == 0:
+            return []
+
+ if num_workers is None:
+ num_workers = min(batch_size, os.cpu_count() - 1)
+
+ # Output
+ paths2processed_files = []
+
+ # Model's mode and device
+ mode = self.training
+ device = next(self.parameters()).device
+
+ try:
+ # Switch model to evaluation mode
+ self.eval()
+ # Freeze weights
+ self.freeze()
+
+ logging_level = logging.get_verbosity()
+ logging.set_verbosity(logging.WARNING)
+
+ # Processing
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Save temporary manifest
+ temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json')
+ with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp:
+ for audio_file in paths2audio_files:
+ entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)}
+ fp.write(json.dumps(entry) + '\n')
+
+ config = {
+ 'manifest_filepath': temporary_manifest_filepath,
+ 'input_key': 'input_filepath',
+ 'input_channel_selector': input_channel_selector,
+ 'batch_size': min(batch_size, len(paths2audio_files)),
+ 'num_workers': num_workers,
+ }
+
+ # Create output dir if necessary
+ if not os.path.isdir(output_dir):
+ os.makedirs(output_dir)
+
+ # DataLoader for the input files
+ temporary_dataloader = self._setup_process_dataloader(config)
+
+ # Indexing of the original files, used to form the output file name
+ file_idx = 0
+
+ # Process batches
+ for test_batch in tqdm(temporary_dataloader, desc="Processing"):
+ input_signal = test_batch[0]
+ input_length = test_batch[1]
+
+ # Expand channel dimension, if necessary
+ # For consistency, the model uses multi-channel format, even if the channel dimension is 1
+ if input_signal.ndim == 2:
+ input_signal = input_signal.unsqueeze(1)
+
+ processed_batch, _ = self.forward(
+ input_signal=input_signal.to(device), input_length=input_length.to(device)
+ )
+
+ for example_idx in range(processed_batch.size(0)):
+ # This assumes the data loader is not shuffling files
+ file_name = os.path.basename(paths2audio_files[file_idx])
+ # Prepare output file
+ output_file = os.path.join(output_dir, f'processed_{file_name}')
+ # Crop the output signal to the actual length
+ output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy()
+ # Write audio
+ sf.write(output_file, output_signal.T, self.sample_rate, 'float')
+ # Update the file counter
+ file_idx += 1
+ # Save processed file
+ paths2processed_files.append(output_file)
+
+ del test_batch
+ del processed_batch
+
+ finally:
+ # set mode back to its original value
+ self.train(mode=mode)
+ if mode is True:
+ self.unfreeze()
+ logging.set_verbosity(logging_level)
+
+ return paths2processed_files
+
+ def _setup_dataloader_from_config(self, config: Optional[Dict]):
+
+ if config.get("use_lhotse", False):
+ return get_lhotse_dataloader_from_config(
+ config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset()
+ )
+
+ is_concat = config.get('is_concat', False)
+ if is_concat:
+ raise NotImplementedError('Concat not implemented')
+
+ # TODO: Consider moving `inject` from `audio_to_text_dataset` to a utility module?
+ # Automatically inject args from model config to dataloader config
+ inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
+
+ # Instantiate tarred dataset loader or normal dataset loader
+ if config.get('is_tarred', False):
+ raise NotImplementedError('Tarred datasets not supported')
+
+ if 'manifest_filepath' in config and config['manifest_filepath'] is None:
+ logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
+ return None
+
+ dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config)
+
+ if hasattr(dataset, 'collate_fn'):
+ collate_fn = dataset.collate_fn
+ elif hasattr(dataset.datasets[0], 'collate_fn'):
+ # support datasets that are lists of entries
+ collate_fn = dataset.datasets[0].collate_fn
+ else:
+ # support datasets that are lists of lists
+ collate_fn = dataset.datasets[0].datasets[0].collate_fn
+
+ return torch.utils.data.DataLoader(
+ dataset=dataset,
+ batch_size=config['batch_size'],
+ collate_fn=collate_fn,
+ drop_last=config.get('drop_last', False),
+ shuffle=config['shuffle'],
+ num_workers=config.get('num_workers', 0),
+ pin_memory=config.get('pin_memory', False),
+ )
+
+ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
+ """
+ Sets up the training data loader via a Dict-like object.
+
+ Args:
+ train_data_config: A config that contains the information regarding construction
+ of a training dataset.
+
+ Supported Datasets:
+ - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
+ """
+ if 'shuffle' not in train_data_config:
+ train_data_config['shuffle'] = True
+
+ # preserve config
+ self._update_dataset_config(dataset_name='train', config=train_data_config)
+
+ self._train_dl = self._setup_dataloader_from_config(config=train_data_config)
+
+ if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
+ raise NotImplementedError('Tarred datasets not supported')
+
+ def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
+ """
+ Sets up the validation data loader via a Dict-like object.
+
+ Args:
+ val_data_config: A config that contains the information regarding construction
+ of a validation dataset.
+
+ Supported Datasets:
+ - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
+ """
+ if 'shuffle' not in val_data_config:
+ val_data_config['shuffle'] = False
+
+ # preserve config
+ self._update_dataset_config(dataset_name='validation', config=val_data_config)
+
+ self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)
+
+ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
+ """
+ Sets up the test data loader via a Dict-like object.
+
+ Args:
+ test_data_config: A config that contains the information regarding construction
+ of a test dataset.
+
+ Supported Datasets:
+ - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
+ """
+ if 'shuffle' not in test_data_config:
+ test_data_config['shuffle'] = False
+
+ # preserve config
+ self._update_dataset_config(dataset_name='test', config=test_data_config)
+
+ self._test_dl = self._setup_dataloader_from_config(config=test_data_config)
+
+ def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
+ """Prepare a dataloader for processing files.
+
+ Args:
+ config: A python dictionary which contains the following keys:
+ manifest_filepath: path to a manifest file
+ input_key: key with audio filepaths in the manifest
+ input_channel_selector: Optional, used to select a subset of channels from input audio files
+ batch_size: batch size for the dataloader
+ num_workers: number of workers for the dataloader
+
+ Returns:
+ A pytorch DataLoader for the given manifest filepath.
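+
+        Example:
+            An illustrative config, with hypothetical paths:
+
+                config = {
+                    'manifest_filepath': '/tmp/process_manifest.json',
+                    'input_key': 'input_filepath',
+                    'input_channel_selector': None,
+                    'batch_size': 4,
+                    'num_workers': 2,
+                }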
+ """
+ dl_config = {
+ 'manifest_filepath': config['manifest_filepath'],
+ 'sample_rate': self.sample_rate,
+ 'input_key': config['input_key'],
+ 'input_channel_selector': config.get('input_channel_selector', None),
+ 'target_key': None,
+ 'target_channel_selector': None,
+ 'batch_size': config['batch_size'],
+ 'shuffle': False,
+ 'num_workers': config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)),
+ 'pin_memory': True,
+ }
+
+ temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config))
+ return temporary_dataloader
+
+ @staticmethod
+ def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor:
+ """Trim or pad the output to match the batch length.
+
+ Args:
+ input: tensor with shape (B, C, T)
+ batch_length: int
+
+ Returns:
+ Tensor with shape (B, C, T), where T matches the
+ batch length.
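+
+        Example:
+            An illustrative sketch (the time dimension is cropped or zero-padded):
+
+                x = torch.randn(1, 1, 5)
+                AudioToAudioModel.match_batch_length(input=x, batch_length=3).shape  # (1, 1, 3)
+                AudioToAudioModel.match_batch_length(input=x, batch_length=8).shape  # (1, 1, 8)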
+ """
+ input_length = input.size(-1)
+ pad_length = batch_length - input_length
+ pad = (0, pad_length)
+ # pad with zeros or crop
+ return torch.nn.functional.pad(input, pad, 'constant', 0)
+
@classmethod
def list_available_models(cls) -> 'List[PretrainedModelInfo]':
diff --git a/nemo/collections/asr/models/enhancement_models.py b/nemo/collections/asr/models/enhancement_models.py
index b80c357364aa..b765ae0fddad 100644
--- a/nemo/collections/asr/models/enhancement_models.py
+++ b/nemo/collections/asr/models/enhancement_models.py
@@ -16,6 +16,8 @@
import tempfile
from typing import Dict, List, Optional, Union
+import einops
+import hydra
import librosa
import soundfile as sf
import torch
@@ -23,17 +25,13 @@
from pytorch_lightning import Trainer
from tqdm import tqdm
-from nemo.collections.asr.data import audio_to_audio_dataset
-from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
-from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config
+
from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel
-from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
-from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes.common import PretrainedModelInfo, typecheck
-from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType
+from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType
from nemo.utils import logging
-__all__ = ['EncMaskDecAudioToAudioModel']
+__all__ = ['EncMaskDecAudioToAudioModel', 'ScoreBasedGenerativeAudioToAudioModel', 'PredictiveAudioToAudioModel']
class EncMaskDecAudioToAudioModel(AudioToAudioModel):
@@ -69,10 +67,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
logging.debug('Mixture consistency not used')
self.mixture_consistency = None
- # Future enhancement:
- # If subclasses need to modify the config before calling super()
- # Check ASRBPE* classes do with their mixin
-
# Setup augmentation
if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None:
logging.debug('Using channel augmentation')
@@ -84,254 +78,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Setup optional Optimization flags
self.setup_optimization_flags()
- @torch.no_grad()
- def process(
- self,
- paths2audio_files: List[str],
- output_dir: str,
- batch_size: int = 1,
- num_workers: Optional[int] = None,
- input_channel_selector: Optional[ChannelSelectorType] = None,
- ) -> List[str]:
- """
- Process audio files provided in paths2audio_files.
- Processed signals will be saved in output_dir.
-
- Args:
- paths2audio_files: (a list) of paths to audio files. \
- Recommended length per file is between 5 and 25 seconds. \
- But it is possible to pass a few hours long file if enough GPU memory is available.
- output_dir:
- batch_size: (int) batch size to use during inference.
- Bigger will result in better throughput performance but would use more memory.
- num_workers: Number of workers for the dataloader
- input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.
-
- Returns:
- """
- if paths2audio_files is None or len(paths2audio_files) == 0:
- return {}
-
- if num_workers is None:
- num_workers = min(batch_size, os.cpu_count() - 1)
-
- # Output
- paths2processed_files = []
-
- # Model's mode and device
- mode = self.training
- device = next(self.parameters()).device
-
- try:
- # Switch model to evaluation mode
- self.eval()
- # Freeze weights
- self.freeze()
-
- logging_level = logging.get_verbosity()
- logging.set_verbosity(logging.WARNING)
-
- # Processing
- with tempfile.TemporaryDirectory() as tmpdir:
- # Save temporary manifest
- temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json')
- with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp:
- for audio_file in paths2audio_files:
- entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)}
- fp.write(json.dumps(entry) + '\n')
-
- config = {
- 'manifest_filepath': temporary_manifest_filepath,
- 'input_key': 'input_filepath',
- 'input_channel_selector': input_channel_selector,
- 'batch_size': min(batch_size, len(paths2audio_files)),
- 'num_workers': num_workers,
- }
-
- # Create output dir if necessary
- if not os.path.isdir(output_dir):
- os.makedirs(output_dir)
-
- # DataLoader for the input files
- temporary_dataloader = self._setup_process_dataloader(config)
-
- # Indexing of the original files, used to form the output file name
- file_idx = 0
-
- # Process batches
- for test_batch in tqdm(temporary_dataloader, desc="Processing"):
- input_signal = test_batch[0]
- input_length = test_batch[1]
-
- # Expand channel dimension, if necessary
- # For consistency, the model uses multi-channel format, even if the channel dimension is 1
- if input_signal.ndim == 2:
- input_signal = input_signal.unsqueeze(1)
-
- processed_batch, _ = self.forward(
- input_signal=input_signal.to(device), input_length=input_length.to(device)
- )
-
- for example_idx in range(processed_batch.size(0)):
- # This assumes the data loader is not shuffling files
- file_name = os.path.basename(paths2audio_files[file_idx])
- # Prepare output file
- output_file = os.path.join(output_dir, f'processed_{file_name}')
- # Crop the output signal to the actual length
- output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy()
- # Write audio
- sf.write(output_file, output_signal.T, self.sample_rate, 'float')
- # Update the file counter
- file_idx += 1
- # Save processed file
- paths2processed_files.append(output_file)
-
- del test_batch
- del processed_batch
-
- finally:
- # set mode back to its original value
- self.train(mode=mode)
- if mode is True:
- self.unfreeze()
- logging.set_verbosity(logging_level)
-
- return paths2processed_files
-
- def _setup_dataloader_from_config(self, config: Optional[Dict]):
-
- if config.get("use_lhotse", False):
- return get_lhotse_dataloader_from_config(
- config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset()
- )
-
- is_concat = config.get('is_concat', False)
- if is_concat:
- raise NotImplementedError('Concat not implemented')
-
- # TODO: Consider moving `inject` from `audio_to_text_dataset` to a utility module?
- # Automatically inject args from model config to dataloader config
- inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
-
- # Instantiate tarred dataset loader or normal dataset loader
- if config.get('is_tarred', False):
- raise NotImplementedError('Tarred datasets not supported')
-
- if 'manifest_filepath' in config and config['manifest_filepath'] is None:
- logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
- return None
-
- dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config)
-
- if hasattr(dataset, 'collate_fn'):
- collate_fn = dataset.collate_fn
- elif hasattr(dataset.datasets[0], 'collate_fn'):
- # support datasets that are lists of entries
- collate_fn = dataset.datasets[0].collate_fn
- else:
- # support datasets that are lists of lists
- collate_fn = dataset.datasets[0].datasets[0].collate_fn
-
- return torch.utils.data.DataLoader(
- dataset=dataset,
- batch_size=config['batch_size'],
- collate_fn=collate_fn,
- drop_last=config.get('drop_last', False),
- shuffle=config['shuffle'],
- num_workers=config.get('num_workers', 0),
- pin_memory=config.get('pin_memory', False),
- )
-
- def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
- """
- Sets up the training data loader via a Dict-like object.
-
- Args:
- train_data_config: A config that contains the information regarding construction
- of a training dataset.
-
- Supported Datasets:
- - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
- """
- if 'shuffle' not in train_data_config:
- train_data_config['shuffle'] = True
-
- # preserve config
- self._update_dataset_config(dataset_name='train', config=train_data_config)
-
- self._train_dl = self._setup_dataloader_from_config(config=train_data_config)
-
- if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
- raise NotImplementedError('Tarred datasets not supported')
-
- def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
- """
- Sets up the validation data loader via a Dict-like object.
-
- Args:
- val_data_config: A config that contains the information regarding construction
- of a validation dataset.
-
- Supported Datasets:
- - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
- """
- if 'shuffle' not in val_data_config:
- val_data_config['shuffle'] = False
-
- # preserve config
- self._update_dataset_config(dataset_name='validation', config=val_data_config)
-
- self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)
-
- def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
- """
- Sets up the test data loader via a Dict-like object.
-
- Args:
- test_data_config: A config that contains the information regarding construction
- of a test dataset.
-
- Supported Datasets:
- - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
- """
- if 'shuffle' not in test_data_config:
- test_data_config['shuffle'] = False
-
- # preserve config
- self._update_dataset_config(dataset_name='test', config=test_data_config)
-
- self._test_dl = self._setup_dataloader_from_config(config=test_data_config)
-
- def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
- """Prepare a dataloader for processing files.
-
- Args:
- config: A python dictionary which contains the following keys:
- manifest_filepath: path to a manifest file
- input_key: key with audio filepaths in the manifest
- input_channel_selector: Optional, used to select a subset of channels from input audio files
- batch_size: batch size for the dataloader
- num_workers: number of workers for the dataloader
-
- Returns:
- A pytorch DataLoader for the given manifest filepath.
- """
- dl_config = {
- 'manifest_filepath': config['manifest_filepath'],
- 'sample_rate': self.sample_rate,
- 'input_key': config['input_key'],
- 'input_channel_selector': config.get('input_channel_selector', None),
- 'target_key': None,
- 'target_channel_selector': None,
- 'batch_size': config['batch_size'],
- 'shuffle': False,
- 'num_workers': config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)),
- 'pin_memory': True,
- }
-
- temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config))
- return temporary_dataloader
-
@property
def input_types(self) -> Dict[str, NeuralType]:
return {
@@ -350,23 +96,6 @@ def output_types(self) -> Dict[str, NeuralType]:
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
- def match_batch_length(self, input: torch.Tensor, batch_length: int):
- """Trim or pad the output to match the batch length.
-
- Args:
- input: tensor with shape (B, C, T)
- batch_length: int
-
- Returns:
- Tensor with shape (B, C, T), where T matches the
- batch length.
- """
- input_length = input.size(-1)
- pad_length = batch_length - input_length
- pad = (0, pad_length)
- # pad with zeros or crop
- return torch.nn.functional.pad(input, pad, 'constant', 0)
-
@typecheck()
def forward(self, input_signal, input_length=None):
"""
@@ -380,6 +109,7 @@ def forward(self, input_signal, input_length=None):
sequences.
Returns:
+ Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
batch_length = input_signal.size(-1)
@@ -414,12 +144,11 @@ def training_step(self, batch, batch_idx):
else:
input_signal, input_length, target_signal, _ = batch
- # Expand channel dimension, if necessary
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
- input_signal = input_signal.unsqueeze(1)
+ input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
- target_signal = target_signal.unsqueeze(1)
+ target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Apply channel augmentation
if self.training and self.channel_augmentation is not None:
@@ -449,12 +178,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
else:
input_signal, input_length, target_signal, _ = batch
- # Expand channel dimension, if necessary
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
- input_signal = input_signal.unsqueeze(1)
+ input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
- target_signal = target_signal.unsqueeze(1)
+ target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Process input
processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
@@ -485,3 +213,406 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
results = []
return results
+
+
+class PredictiveAudioToAudioModel(AudioToAudioModel):
+    """This model directly estimates the coefficients
+    in the encoded domain by applying a neural model.
+ """
+
+ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+ super().__init__(cfg=cfg, trainer=trainer)
+ self.sample_rate = self._cfg.sample_rate
+
+ # Setup processing modules
+ self.encoder = self.from_config_dict(self._cfg.encoder)
+ self.decoder = self.from_config_dict(self._cfg.decoder)
+
+ # Neural estimator
+ self.estimator = self.from_config_dict(self._cfg.estimator)
+
+ # Normalization
+ self.normalize_input = self._cfg.get('normalize_input', False)
+
+ # Term added to the denominator to improve numerical stability
+ self.eps = self._cfg.get('eps', 1e-8)
+
+ # Setup optional Optimization flags
+ self.setup_optimization_flags()
+
+ logging.debug('Initialized %s', self.__class__.__name__)
+ logging.debug('\tnormalize_input: %s', self.normalize_input)
+ logging.debug('\teps: %s', self.eps)
+
+ @property
+ def input_types(self) -> Dict[str, NeuralType]:
+ return {
+ "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
+ "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ }
+
+ @property
+ def output_types(self) -> Dict[str, NeuralType]:
+ return {
+ "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
+ "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ }
+
+ @typecheck()
+ def forward(self, input_signal, input_length=None):
+ """Forward pass of the model.
+
+ Args:
+ input_signal: time-domain signal
+ input_length: valid length of each example in the batch
+
+ Returns:
+ Output signal `output` in the time domain and the length of the output signal `output_length`.
+ """
+ batch_length = input_signal.size(-1)
+
+ if self.normalize_input:
+ # max for each example in the batch
+ norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
+ # scale input signal
+ input_signal = input_signal / (norm_scale + self.eps)
+
+ # Encoder
+ encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
+
+ # Backbone
+ estimated, estimated_length = self.estimator(input=encoded, input_length=encoded_length)
+
+ # Decoder
+ output, output_length = self.decoder(input=estimated, input_length=estimated_length)
+
+ if self.normalize_input:
+ # rescale to the original scale
+ output = output * norm_scale
+
+ # Trim or pad the estimated signal to match input length
+ output = self.match_batch_length(input=output, batch_length=batch_length)
+ return output, output_length
+
+ # PTL-specific methods
+ def training_step(self, batch, batch_idx):
+
+ if isinstance(batch, dict):
+ # lhotse batches are dictionaries
+ input_signal = batch['input_signal']
+ input_length = batch['input_length']
+ target_signal = batch['target_signal']
+ else:
+ input_signal, input_length, target_signal, _ = batch
+
+ # For consistency, the model uses multi-channel format, even if the channel dimension is 1
+ if input_signal.ndim == 2:
+ input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
+ if target_signal.ndim == 2:
+ target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
+
+ # Estimate the signal
+ output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
+
+ # Calculate the loss
+ loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length)
+
+ # Logs
+ self.log('train_loss', loss)
+ self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
+ self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
+
+ return loss
+
+ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
+
+ if isinstance(batch, dict):
+ # lhotse batches are dictionaries
+ input_signal = batch['input_signal']
+ input_length = batch['input_length']
+ target_signal = batch['target_signal']
+ else:
+ input_signal, input_length, target_signal, _ = batch
+
+ # For consistency, the model uses multi-channel format, even if the channel dimension is 1
+ if input_signal.ndim == 2:
+ input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
+ if target_signal.ndim == 2:
+ target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
+
+ # Estimate the signal
+ output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
+
+ # Prepare output
+ loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length)
+
+ # Update metrics
+ if hasattr(self, 'metrics') and tag in self.metrics:
+ # Update metrics for this (tag, dataloader_idx)
+ for name, metric in self.metrics[tag][dataloader_idx].items():
+ metric.update(preds=output_signal, target=target_signal, input_length=input_length)
+
+ # Log global step
+ self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
+
+ return {f'{tag}_loss': loss}
+
+
+class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel):
+    """This model uses a score-based diffusion process to generate
+    an encoded representation of the enhanced signal.
+
+ The model consists of the following blocks:
+ - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform)
+ - estimator: neural model, estimates a score for the diffusion process
+ - sde: stochastic differential equation (SDE) defining the forward and reverse diffusion process
+ - sampler: sampler for the reverse diffusion process, estimates coefficients of the target signal
+ - decoder: transforms sampler output into the time domain (synthesis transform)
+ """
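+
+    At inference time, `forward` chains the encoder, the sampler (reverse diffusion) and the decoder.
+    During training, `_step` perturbs the encoded target using the SDE's perturbation kernel and trains
+    the estimator with a score matching loss; the sampler is not used in that step.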
+
+ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+ super().__init__(cfg=cfg, trainer=trainer)
+ self.sample_rate = self._cfg.sample_rate
+
+ # Setup processing modules
+ self.encoder = self.from_config_dict(self._cfg.encoder)
+ self.decoder = self.from_config_dict(self._cfg.decoder)
+
+ # Neural score estimator
+ self.estimator = self.from_config_dict(self._cfg.estimator)
+
+ # SDE
+ self.sde = self.from_config_dict(self._cfg.sde)
+
+ # Sampler
+ if 'sde' in self._cfg.sampler:
+ raise ValueError('SDE should be defined in the model config, not in the sampler config')
+ if 'score_estimator' in self._cfg.sampler:
+ raise ValueError('Score estimator should be defined in the model config, not in the sampler config')
+
+ self.sampler = hydra.utils.instantiate(self._cfg.sampler, sde=self.sde, score_estimator=self.estimator)
+
+ # Normalization
+ self.normalize_input = self._cfg.get('normalize_input', False)
+
+ # Metric evaluation
+ self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics')
+
+ if self.max_utts_evaluation_metrics is not None:
+ logging.warning(
+                'Metrics will be evaluated only on the first %d examples of each evaluation dataset.',
+ self.max_utts_evaluation_metrics,
+ )
+
+ # Term added to the denominator to improve numerical stability
+ self.eps = self._cfg.get('eps', 1e-8)
+
+ # Setup optional Optimization flags
+ self.setup_optimization_flags()
+
+ logging.debug('Initialized %s', self.__class__.__name__)
+ logging.debug('\tnormalize_input: %s', self.normalize_input)
+ logging.debug('\teps: %s', self.eps)
+
+ @property
+ def input_types(self) -> Dict[str, NeuralType]:
+ return {
+ "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
+ "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ }
+
+ @property
+ def output_types(self) -> Dict[str, NeuralType]:
+ return {
+ "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
+ "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ }
+
+ @typecheck()
+ @torch.inference_mode()
+ def forward(self, input_signal, input_length=None):
+ """Forward pass of the model.
+
+        Forward pass of the model applies the following steps:
+ - encoder to obtain the encoded representation of the input signal
+ - sampler to generate the estimated coefficients of the target signal
+ - decoder to transform the sampler output into the time domain
+
+ Args:
+ input_signal: Tensor that represents a batch of raw audio signals,
+                of shape [B, T] or [B, C, T]. T here represents timesteps, with 1 second of audio represented as
+                `self.sample_rate` number of floating point values.
+            input_length: Vector of length B, that contains the individual lengths of the audio
+                sequences.
+
+ Returns:
+ Output signal `output` in the time domain and the length of the output signal `output_length`.
+ """
+ batch_length = input_signal.size(-1)
+
+ if self.normalize_input:
+ # max for each example in the batch
+ norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
+ # scale input signal
+ input_signal = input_signal / (norm_scale + self.eps)
+
+ # Encoder
+ encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
+
+ # Sampler
+ generated, generated_length = self.sampler(
+ prior_mean=encoded, score_condition=encoded, state_length=encoded_length
+ )
+
+ # Decoder
+ output, output_length = self.decoder(input=generated, input_length=generated_length)
+
+ if self.normalize_input:
+ # rescale to the original scale
+ output = output * norm_scale
+
+ # Trim or pad the estimated signal to match input length
+ output = self.match_batch_length(input=output, batch_length=batch_length)
+ return output, output_length
+
+ @typecheck(
+ input_types={
+ "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
+ "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
+ "input_length": NeuralType(tuple('B'), LengthsType()),
+ },
+ output_types={"loss": NeuralType(None, LossType()),},
+ )
+ def _step(self, target_signal, input_signal, input_length=None):
+ """Randomly generate a time step for each example in the batch, estimate
+ the score and calculate the loss value.
+
+        Note that this step does not involve the sampler.
+ """
+ batch_size = target_signal.size(0)
+
+ if self.normalize_input:
+ # max for each example in the batch
+ norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
+ # scale input signal
+ input_signal = input_signal / (norm_scale + self.eps)
+ # scale the target signal
+ target_signal = target_signal / (norm_scale + self.eps)
+
+ # Apply encoder to both target and the input
+ input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length)
+ target_enc, _ = self.encoder(input=target_signal, input_length=input_length)
+
+ # Generate random time steps
+ sde_time = self.sde.generate_time(size=batch_size, device=input_enc.device)
+
+ # Get the mean and the variance of the perturbation kernel
+ pk_mean, pk_std = self.sde.perturb_kernel_params(state=target_enc, prior_mean=input_enc, time=sde_time)
+
+ # Generate a random sample from a standard normal distribution
+ z_norm = torch.randn_like(input_enc)
+
+ # Prepare perturbed data
+ perturbed_enc = pk_mean + pk_std * z_norm
+
+ # Score is conditioned on the perturbed data and the input
+ estimator_input = torch.cat([perturbed_enc, input_enc], dim=-3)
+
+ # Estimate the score using the neural estimator
+ # SDE time is used to inform the estimator about the current time step
+ # Note:
+ # - some implementations use `score = -self._raw_dnn_output(x, t, y)`
+        # - this sign convention appears to be unimportant, and is an artifact of transferring code from the original repo by Song et al.
+ score_est, score_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=sde_time)
+
+ # Score loss weighting as in Section 4.2 in http://arxiv.org/abs/1907.05600
+ score_est = score_est * pk_std
+ score_ref = -z_norm
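+        # Note: for this Gaussian perturbation kernel, the exact score of the perturbed data is
+        # -(perturbed_enc - pk_mean) / pk_std**2 = -z_norm / pk_std, so scaling the estimate by
+        # pk_std and regressing it against -z_norm corresponds to a std^2-weighted score matching loss.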
+
+ # Score matching loss on the normalized scores
+ loss = self.loss(estimate=score_est, target=score_ref, input_length=score_len)
+
+ return loss
+
+ # PTL-specific methods
+ def training_step(self, batch, batch_idx):
+
+ if isinstance(batch, dict):
+ # lhotse batches are dictionaries
+ input_signal = batch['input_signal']
+ input_length = batch['input_length']
+ target_signal = batch['target_signal']
+ else:
+ input_signal, input_length, target_signal, _ = batch
+
+ # For consistency, the model uses multi-channel format, even if the channel dimension is 1
+ if input_signal.ndim == 2:
+ input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
+ if target_signal.ndim == 2:
+ target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
+
+ # Calculate the loss
+ loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)
+
+ # Logs
+ self.log('train_loss', loss)
+ self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
+ self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
+
+ return loss
+
+ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
+
+ if isinstance(batch, dict):
+ # lhotse batches are dictionaries
+ input_signal = batch['input_signal']
+ input_length = batch['input_length']
+ target_signal = batch['target_signal']
+ else:
+ input_signal, input_length, target_signal, _ = batch
+
+ # For consistency, the model uses multi-channel format, even if the channel dimension is 1
+ if input_signal.ndim == 2:
+ input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
+ if target_signal.ndim == 2:
+ target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
+
+ # Calculate loss
+ loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)
+
+ # Update metrics
+ update_metrics = False
+ if self.max_utts_evaluation_metrics is None:
+ # Always update if max is not configured
+ update_metrics = True
+ # Number of examples to process
+ num_examples = input_signal.size(0) # batch size
+ else:
+ # Check how many examples have been used for metric calculation
+ first_metric_name = next(iter(self.metrics[tag][dataloader_idx]))
+ num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples
+ # Update metrics if some examples were not processed
+ update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics
+ # Number of examples to process
+ num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0))
+
+ if update_metrics:
+ # Generate output signal
+ output_signal, _ = self.forward(
+ input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
+ )
+
+ # Update metrics
+ if hasattr(self, 'metrics') and tag in self.metrics:
+ # Update metrics for this (tag, dataloader_idx)
+ for name, metric in self.metrics[tag][dataloader_idx].items():
+ metric.update(
+ preds=output_signal,
+ target=target_signal[:num_examples, ...],
+ input_length=input_length[:num_examples],
+ )
+
+ # Log global step
+ self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
+
+ return {f'{tag}_loss': loss}
diff --git a/nemo/collections/asr/modules/audio_modules.py b/nemo/collections/asr/modules/audio_modules.py
index 82cfbefeb8d9..67a923099cde 100644
--- a/nemo/collections/asr/modules/audio_modules.py
+++ b/nemo/collections/asr/modules/audio_modules.py
@@ -17,7 +17,7 @@
import numpy as np
import torch
-from nemo.collections.asr.losses.audio_losses import temporal_mean
+from nemo.collections.asr.losses.audio_losses import calculate_mean
from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder
from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.collections.asr.parts.submodules.multichannel_modules import (
@@ -39,6 +39,7 @@
'MaskReferenceChannel',
'MaskBasedBeamformer',
'MaskBasedDereverbWPE',
+ 'MixtureConsistencyProjection',
]
@@ -158,7 +159,7 @@ def get_mean_time_channel(input: torch.Tensor, input_length: Optional[torch.Tens
mean = torch.mean(input, dim=(-1, -3), keepdim=True)
else:
# temporal mean
- mean = temporal_mean(input, input_length, keepdim=True)
+ mean = calculate_mean(input, input_length, dim=-1, keepdim=True)
# channel mean
mean = torch.mean(mean, dim=-3, keepdim=True)
@@ -186,7 +187,7 @@ def get_mean_std_time_channel(
mean = cls.get_mean_time_channel(input, input_length)
std = (input - mean).pow(2)
# temporal mean
- std = temporal_mean(std, input_length, keepdim=True)
+ std = calculate_mean(std, input_length, dim=-1, keepdim=True)
# channel mean
std = torch.mean(std, dim=-3, keepdim=True)
# final value
diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py
index cc5312403255..643bc4a69d69 100644
--- a/nemo/collections/asr/modules/audio_preprocessing.py
+++ b/nemo/collections/asr/modules/audio_preprocessing.py
@@ -709,9 +709,11 @@ class AudioToSpectrogram(NeuralModule):
hop_length: length of hops/shifts of the sliding window
power: exponent for magnitude spectrogram. Default `None` will
return a complex-valued spectrogram
+ magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power.
+ scale: Positive scaling of the spectrogram.
"""
- def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = None):
+ def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0):
if not HAVE_TORCHAUDIO:
logging.error('Could not import torchaudio. Some features might not work.')
@@ -726,12 +728,26 @@ def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = No
raise ValueError(f'fft_length = {fft_length} must be divisible by 2')
self.stft = torchaudio.transforms.Spectrogram(
- n_fft=fft_length, hop_length=hop_length, power=power, pad_mode='constant'
+ n_fft=fft_length, hop_length=hop_length, power=None, pad_mode='constant'
)
# number of subbands
self.F = fft_length // 2 + 1
+ if magnitude_power <= 0:
+ raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}')
+ self.magnitude_power = magnitude_power
+
+ if scale <= 0:
+ raise ValueError(f'Scale needs to be positive: current value {scale}')
+ self.scale = scale
+
+ logging.debug('Initialized %s with:', self.__class__.__name__)
+ logging.debug('\tfft_length: %s', fft_length)
+ logging.debug('\thop_length: %s', hop_length)
+ logging.debug('\tmagnitude_power: %s', magnitude_power)
+ logging.debug('\tscale: %s', scale)
+
@property
def num_subbands(self) -> int:
return self.F
@@ -776,6 +792,14 @@ def forward(
with torch.cuda.amp.autocast(enabled=False):
output = self.stft(input.float())
+ if self.magnitude_power != 1:
+ # apply power on the magnitude
+ output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle())
+
+ if self.scale != 1:
+ # apply scaling of the coefficients
+ output = self.scale * output
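+            # Note: SpectrogramToAudio applies the inverse mapping (divide by scale, then
+            # magnitude^(1/magnitude_power)), so the analysis/synthesis pair stays consistent.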
+
if input_length is not None:
# Mask padded frames
output_length = self.get_output_length(input_length=input_length)
@@ -810,11 +834,11 @@ class SpectrogramToAudio(NeuralModule):
Args:
fft_length: length of FFT
hop_length: length of hops/shifts of the sliding window
- power: exponent for magnitude spectrogram. Default `None` will
- return a complex-valued spectrogram
+ magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power).
+ scale: Spectrogram will be scaled with 1/scale before the inverse transform.
"""
- def __init__(self, fft_length: int, hop_length: int):
+ def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0):
if not HAVE_TORCHAUDIO:
logging.error('Could not import torchaudio. Some features might not work.')
@@ -834,6 +858,20 @@ def __init__(self, fft_length: int, hop_length: int):
self.F = fft_length // 2 + 1
+ if magnitude_power <= 0:
+ raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}')
+ self.magnitude_power = magnitude_power
+
+ if scale <= 0:
+ raise ValueError(f'Scale needs to be positive: current value {scale}')
+ self.scale = scale
+
+ logging.debug('Initialized %s with:', self.__class__.__name__)
+ logging.debug('\tfft_length: %s', fft_length)
+ logging.debug('\thop_length: %s', hop_length)
+ logging.debug('\tmagnitude_power: %s', magnitude_power)
+ logging.debug('\tscale: %s', scale)
+
@property
def num_subbands(self) -> int:
return self.F
@@ -875,7 +913,16 @@ def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = No
# iSTFT output (B, C, T)
with torch.cuda.amp.autocast(enabled=False):
- output = self.istft(input.cfloat())
+ output = input.cfloat()
+
+ if self.scale != 1:
+ # apply 1/scale on the coefficients
+ output = output / self.scale
+
+ if self.magnitude_power != 1:
+ # apply 1/power on the magnitude
+ output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle())
+ output = self.istft(output)
if input_length is not None:
# Mask padded samples
diff --git a/nemo/collections/asr/parts/submodules/diffusion.py b/nemo/collections/asr/parts/submodules/diffusion.py
new file mode 100644
index 000000000000..db3d30f49701
--- /dev/null
+++ b/nemo/collections/asr/parts/submodules/diffusion.py
@@ -0,0 +1,1310 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from abc import ABC, abstractmethod
+from typing import Dict, Optional, Sequence, Tuple, Type
+
+import einops
+import einops.layers.torch
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from nemo.collections.common.parts.utils import activation_registry
+from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
+from nemo.core.classes import NeuralModule, typecheck
+from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType
+from nemo.utils import logging
+
+__all__ = [
+ 'OrnsteinUhlenbeckVarianceExplodingSDE',
+ 'SpectrogramNoiseConditionalScoreNetworkPlusPlus',
+ 'NoiseConditionalScoreNetworkPlusPlus',
+ 'PredictorCorrectorSampler',
+]
+
+
+class StochasticDifferentialEquation(NeuralModule, ABC):
+ """Base class for stochastic differential equations.
+ """
+
+ def __init__(self, time_min: float, time_max: float, num_steps: int):
+ super().__init__()
+
+ # min and max time
+ if time_min <= 0:
+ raise ValueError(f'time_min should be positive, current value {time_min}')
+
+ if time_max <= time_min:
+ raise ValueError(f'time_max should be larger than time_min, current max {time_max} and min {time_min}')
+
+ self.time_min = time_min
+ self.time_max = time_max
+
+ # number of steps
+ if num_steps <= 0:
+ raise ValueError(f'num_steps needs to be positive: current value {num_steps}')
+
+ self.num_steps = num_steps
+
+ @property
+ def dt(self) -> float:
+ """Time step for this SDE.
+ This denotes the step size between `0` and `self.time_max` when using `self.num_steps`.
+ """
+ return self.time_max / self.num_steps
+
+ @property
+ def time_delta(self) -> float:
+ """Time range for this SDE.
+        """Length of the valid time interval [time_min, time_max] for this SDE.
+ return self.time_max - self.time_min
+
+ def generate_time(self, size: int, device: torch.device) -> torch.Tensor:
+ """Generate random time steps in the valid range.
+
+ Time steps are generated between `self.time_min` and `self.time_max`.
+
+ Args:
+ size: number of samples
+ device: device to use
+
+ Returns:
+ A tensor of floats with shape (size,)
+ """
+ time = torch.rand(size, device=device) * self.time_delta + self.time_min
+ return time
+
+ @abstractmethod
+ def coefficients(self, state: torch.Tensor, time: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ state: tensor of shape (B, C, D, T)
+ time: tensor of shape (B,)
+
+ Returns:
+ Tuple with drift and diffusion coefficients.
+ """
+ pass
+
+ @typecheck(
+ input_types={"prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()),},
+ output_types={"sample": NeuralType(('B', 'C', 'D', 'T'), VoidType()),},
+ )
+ @abstractmethod
+ def prior_sampling(self, prior_mean: torch.Tensor) -> torch.Tensor:
+ """Generate a sample from the prior distribution p_T.
+
+ Args:
+ prior_mean: Mean of the prior distribution
+
+ Returns:
+ A sample from the prior distribution.
+ """
+ pass
+
+ def discretize(
+ self, *, state: torch.Tensor, time: torch.Tensor, state_length: Optional[torch.Tensor] = None, **kwargs
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Assume we have the following SDE:
+
+ dx = drift(x, t) * dt + diffusion(x, t) * dwt
+
+ where `wt` is the standard Wiener process.
+
+ We assume the following discretization:
+
+ new_state = current_state + total_drift + total_diffusion * z_norm
+
+ where `z_norm` is sampled from normal distribution with zero mean and unit variance.
+
+ Args:
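+        In the implementation below, `total_drift = drift_coefficient * dt` and
+        `total_diffusion = diffusion_coefficient * sqrt(dt)`.
+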
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ state_length: length of the valid time steps for each example in the batch, shape (B,)
+ **kwargs: other parameters
+
+ Returns:
+ Drift and diffusion.
+ """
+ # Get coefficients
+ drift_coefficient, diffusion_coefficient = self.coefficients(
+ state=state, time=time, state_length=state_length, **kwargs
+ )
+
+ # Discretized drift
+ drift = drift_coefficient * self.dt
+
+ # Note:
+ # Scale with sqrt(dt) because z_norm is sampled from a normal distribution with zero mean and
+ # unit variance and dwt is normally distributed with zero mean and variance dt
+ diffusion = diffusion_coefficient * np.sqrt(self.dt)
+
+ return drift, diffusion
+
+ @abstractmethod
+ def copy(self):
+ """Create a copy of this SDE.
+ """
+ pass
+
+ def __repr__(self):
+ desc = f'{self.__class__.__name__}(time_min={self.time_min}, time_max={self.time_max}, num_steps={self.num_steps})'
+ desc += f'\n\tdt: {self.dt}'
+ desc += f'\n\ttime_delta: {self.time_delta}'
+ return desc
+
+
+class OrnsteinUhlenbeckVarianceExplodingSDE(StochasticDifferentialEquation):
+ """This class implements the Ornstein-Uhlenbeck SDE with variance exploding noise schedule.
+
+ The SDE is given by:
+
+ dx = theta * (y - x) dt + g(t) dw
+
+ where `theta` is the stiffness parameter and `g(t)` is the diffusion coefficient:
+
+ g(t) = std_min * (std_max/std_min)^t * sqrt(2 * log(std_max/std_min))
+
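+    The corresponding perturbation kernel is Gaussian with mean
+    exp(-theta * t) * x_0 + (1 - exp(-theta * t)) * y; its mean and standard deviation are
+    implemented in `perturb_kernel_mean` and `perturb_kernel_std` below.
+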
+ References:
+ Richter et al., Speech Enhancement and Dereverberation with Diffusion-based Generative Models, Tr. ASLP 2023
+ """
+
+ def __init__(
+ self,
+ stiffness: float,
+ std_min: float,
+ std_max: float,
+ num_steps: int = 100,
+ time_min: float = 3e-2,
+ time_max: float = 1.0,
+ eps: float = 1e-8,
+ ):
+ super().__init__(time_min=time_min, time_max=time_max, num_steps=num_steps)
+
+ # Small regularization
+ if eps <= 0:
+ raise ValueError(f'eps should be positive, current value {eps}')
+ self.eps = eps
+
+        # stiffness
+ self.stiffness = stiffness
+
+ # noise schedule
+ if std_min <= 0:
+ raise ValueError(f'std_min should be positive, current value {std_min}')
+
+ if std_max <= std_min:
+ raise ValueError(f'std_max should be larger than std_min, current max {std_max} and min {std_min}')
+
+ self.std_min = std_min
+ self.std_max = std_max
+
+ logging.debug('Initialized %s with', self.__class__.__name__)
+ logging.debug('\tstiffness: %s', self.stiffness)
+ logging.debug('\tstd_min: %s', self.std_min)
+ logging.debug('\tstd_max: %s', self.std_max)
+ logging.debug('\tnum_steps: %s', self.num_steps)
+ logging.debug('\ttime_min: %s', self.time_min)
+ logging.debug('\ttime_max: %s', self.time_max)
+ logging.debug('\teps: %s', self.eps)
+
+ @property
+ def std_ratio(self) -> float:
+ return self.std_max / (self.std_min + self.eps)
+
+ @property
+ def log_std_ratio(self) -> float:
+ return np.log(self.std_ratio + self.eps)
+
+ @typecheck(
+ input_types={
+ "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "time": NeuralType(tuple('B'), FloatType()),
+ },
+ output_types={"mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()),},
+ )
+ def perturb_kernel_mean(self, state: torch.Tensor, prior_mean: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
+ """Return the mean of the perturbation kernel for this SDE.
+
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ prior_mean: mean of the prior distribution
+ time: current time of the process, shape (B,)
+
+ Returns:
+ A tensor of shape (B, C, D, T)
+ """
+ # exponential weighting
+ weight = torch.exp(-self.stiffness * time)
+
+ # view as [B, C, D, T]
+ weight = weight.view(-1, 1, 1, 1)
+
+ # closed-form mean
+ mean = weight * state + (1 - weight) * prior_mean
+
+ return mean
+
+ @typecheck(
+ input_types={"time": NeuralType(tuple('B'), FloatType()),},
+ output_types={"std": NeuralType(tuple('B'), FloatType()),},
+ )
+ def perturb_kernel_std(self, time: torch.Tensor) -> torch.Tensor:
+ """Return the standard deviation of the perturbation kernel for this SDE.
+
+ Note that the standard deviation depends on the time and the noise schedule,
+ which is parametrized using `self.stiffness`, `self.std_min` and `self.std_max`.
+
+ Args:
+ time: current time of the process, shape (B,)
+
+ Returns:
+ A tensor of shape (B,)
+ """
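+        # Closed-form variance of the OU-VE perturbation kernel
+        # (see Richter et al., referenced in the class docstring, for the derivation)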
+ var = (self.std_min ** 2) * self.log_std_ratio
+ var *= torch.pow(self.std_ratio, 2 * time) - torch.exp(-2 * self.stiffness * time)
+ var /= self.stiffness + self.log_std_ratio
+ std = torch.sqrt(var)
+ return std
+
+ @typecheck(
+ input_types={
+ "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "time": NeuralType(tuple('B'), FloatType()),
+ },
+ output_types={
+ "mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
+ "std": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
+ },
+ )
+    def perturb_kernel_params(
+        self, state: torch.Tensor, prior_mean: torch.Tensor, time: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Return the mean and standard deviation of the perturbation kernel for this SDE.
+
+        Args:
+            state: current state of the process, shape (B, C, D, T)
+            prior_mean: mean of the prior distribution
+            time: current time of the process, shape (B,)
+
+        Returns:
+            Mean with shape (B, C, D, T) and standard deviation with shape (B, 1, 1, 1), broadcastable to the state.
+        """
+ assert torch.all(time <= self.time_max)
+ assert torch.all(time >= self.time_min)
+
+ # compute the mean
+ mean = self.perturb_kernel_mean(state=state, prior_mean=prior_mean, time=time)
+
+ # compute the standard deviation
+ std = self.perturb_kernel_std(time=time)
+ # view as [B, C, D, T]
+ std = std.view(-1, 1, 1, 1)
+
+ return mean, std
+
+ @typecheck(
+ input_types={
+ "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "time": NeuralType(tuple('B'), VoidType()),
+ "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "state_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ },
+ output_types={
+ "drift_coefficient": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
+ "diffusion_coefficient": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
+ },
+ )
+ def coefficients(
+ self,
+ state: torch.Tensor,
+ time: torch.Tensor,
+ prior_mean: torch.Tensor,
+ state_length: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute drift and diffusion coefficients for this SDE.
+
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ prior_mean: mean of the prior distribution
+ state_length: length of the valid time steps for each example in the batch
+
+ Returns:
+ Drift and diffusion coefficients.
+ """
+ # Drift coefficient
+ drift_coefficient = self.stiffness * (prior_mean - state)
+
+ # Diffusion coefficient
+ diffusion_coefficient = self.std_min * torch.pow(self.std_ratio, time) * np.sqrt(2 * self.log_std_ratio)
+ # View in the same shape as the state
+ diffusion_coefficient = diffusion_coefficient.view(-1, *([1] * (state.dim() - 1)))
+
+ if state_length is not None:
+ drift_coefficient = mask_sequence_tensor(drift_coefficient, state_length)
+ diffusion_coefficient = mask_sequence_tensor(diffusion_coefficient, state_length)
+
+ return drift_coefficient, diffusion_coefficient
+
+ def prior_sampling(self, prior_mean: torch.Tensor) -> torch.Tensor:
+ """Generate a sample from the prior distribution p_T.
+
+ Args:
+ prior_mean: Mean of the prior distribution
+ """
+ # Final time step for all samples in the batch
+ time = self.time_max * torch.ones(prior_mean.shape[0], device=prior_mean.device)
+
+ # Compute the std of the prior distribution
+ std = self.perturb_kernel_std(time=time)
+
+ # view as [B, C, D, T]
+ std = std.view(-1, 1, 1, 1)
+
+ # Generate a sample from a normal distribution centered at prior_mean
+ sample = prior_mean + torch.randn_like(prior_mean) * std
+
+ return sample
+
+ def copy(self):
+ return OrnsteinUhlenbeckVarianceExplodingSDE(
+ stiffness=self.stiffness,
+ std_min=self.std_min,
+ std_max=self.std_max,
+ num_steps=self.num_steps,
+ time_min=self.time_min,
+ time_max=self.time_max,
+ eps=self.eps,
+ )
+
+ def __repr__(self):
+ desc = f'{self.__class__.__name__}(stiffness={self.stiffness}, std_min={self.std_min}, std_max={self.std_max}, num_steps={self.num_steps}, time_min={self.time_min}, time_max={self.time_max}, eps={self.eps})'
+ desc += f'\n\tdt: {self.dt}'
+ desc += f'\n\ttime_delta: {self.time_delta}'
+ desc += f'\n\tstd_ratio: {self.std_ratio}'
+ desc += f'\n\tlog_std_ratio: {self.log_std_ratio}'
+
+ return desc
+
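+# A minimal usage sketch for the SDE above (hypothetical values; shapes follow the
+# (B, C, D, T) convention used throughout this file):
+#
+#   sde = OrnsteinUhlenbeckVarianceExplodingSDE(stiffness=1.5, std_min=0.05, std_max=0.5)
+#   y = torch.randn(4, 1, 256, 100)                # prior mean, e.g., a noisy spectrogram
+#   x_T = sde.prior_sampling(prior_mean=y)         # sample from the prior p_T
+#   t = sde.generate_time(size=4, device=y.device) # random times in [time_min, time_max]
+#   mean, std = sde.perturb_kernel_params(state=x_T, prior_mean=y, time=t)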
+
+class ReverseStochasticDifferentialEquation(StochasticDifferentialEquation):
+    def __init__(self, *, sde: StochasticDifferentialEquation, score_estimator: NeuralModule):
+ """Use the forward SDE and a score estimator to define the reverse SDE.
+
+ Args:
+ sde: forward SDE
+ score_estimator: neural score estimator
+ """
+ super().__init__(time_min=sde.time_min, time_max=sde.time_max, num_steps=sde.num_steps)
+ self.score_estimator = score_estimator
+ self.forward_sde = sde
+
+ logging.debug('Initialized %s', self.__class__.__name__)
+
+ def coefficients(
+ self,
+ state: torch.Tensor,
+ time: torch.Tensor,
+ score_condition: Optional[torch.Tensor] = None,
+ state_length: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute drift and diffusion coefficients for the reverse SDE.
+
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ """
+ raise NotImplementedError('Coefficients not necessary for the reverse SDE.')
+
+ def prior_sampling(self, shape: torch.Size, device: torch.device) -> torch.Tensor:
+ """Prior sampling is not necessary for the reverse SDE.
+ """
+ raise NotImplementedError('Prior sampling not necessary for the reverse SDE.')
+
+ def discretize(
+ self,
+ *,
+ state: torch.Tensor,
+ time: torch.Tensor,
+ score_condition: Optional[torch.Tensor] = None,
+ state_length: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Discretize the reverse SDE.
+
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ score_condition: condition for the score estimator
+ state_length: length of the valid time steps for each example in the batch
+ **kwargs: other parameters for discretization of the forward SDE
+ """
+ # Drift and diffusion from the forward SDE
+ forward_drift, forward_diffusion = self.forward_sde.discretize(state=state, time=time, **kwargs)
+
+        # Input for the score estimator:
+ # - if no condition is provided, use the state
+ # - if a condition is provided, concatenate the state and the condition along the channel dimension
+ score_input = state if score_condition is None else torch.cat([state, score_condition], dim=1)
+
+ # Estimate score
+ score, _ = self.score_estimator(input=score_input, input_length=state_length, condition=time)
+
+ # Adjust drift
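+        # (reverse-time drift: the forward drift minus the squared diffusion times the estimated
+        # score, cf. Song et al., 2021; both terms are already in discretized form here)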
+ drift = forward_drift - forward_diffusion.pow(2) * score
+
+ # Adjust diffusion
+ diffusion = forward_diffusion
+
+ if state_length is not None:
+ drift = mask_sequence_tensor(drift, state_length)
+ diffusion = mask_sequence_tensor(diffusion, state_length)
+
+ return drift, diffusion
+
+ def copy(self):
+ return ReverseStochasticDifferentialEquation(sde=self.forward_sde.copy(), score_estimator=self.score_estimator)
+
+ def __repr__(self):
+ desc = f'{self.__class__.__name__}(sde={self.forward_sde}, score_estimator={self.score_estimator})'
+ return desc
+
+
+class SpectrogramNoiseConditionalScoreNetworkPlusPlus(NeuralModule):
+    """This model handles complex-valued inputs by stacking the real and imaginary components.
+    The stacked tensor is processed using NCSN++, and the output is projected to generate the real
+    and imaginary components of the output channels.
+
+ Args:
+ in_channels: number of input complex-valued channels
+ out_channels: number of output complex-valued channels
+ """
+
+ def __init__(self, *, in_channels: int = 1, out_channels: int = 1, **kwargs):
+ super().__init__()
+
+ # Number of input signals for this estimator
+ if in_channels < 1:
+ raise ValueError(
+                f'Number of input channels needs to be greater than or equal to one, current value {in_channels}'
+ )
+
+ self.in_channels = in_channels
+
+ # Number of output signals for this estimator
+ if out_channels < 1:
+ raise ValueError(
+                f'Number of output channels needs to be greater than or equal to one, current value {out_channels}'
+ )
+
+ self.out_channels = out_channels
+
+ # Instantiate noise conditional score network NCSN++
+ ncsnpp_params = kwargs.copy()
+ ncsnpp_params['in_channels'] = ncsnpp_params['out_channels'] = 2 * self.in_channels # stack real and imag
+ self.ncsnpp = NoiseConditionalScoreNetworkPlusPlus(**ncsnpp_params)
+
+ # Output projection to generate real and imaginary components of the output channels
+ self.output_projection = torch.nn.Conv2d(
+ in_channels=2 * self.in_channels, out_channels=2 * self.out_channels, kernel_size=1
+ )
+
+ logging.debug('Initialized %s with', self.__class__.__name__)
+ logging.debug('\tin_channels: %s', self.in_channels)
+ logging.debug('\tout_channels: %s', self.out_channels)
+
+ @property
+ def input_types(self) -> Dict[str, NeuralType]:
+        """Returns definitions of module input ports.
+ """
+ return {
+ "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
+ "input_length": NeuralType(('B',), LengthsType(), optional=True),
+ "condition": NeuralType(('B',), FloatType(), optional=True),
+ }
+
+ @property
+ def output_types(self) -> Dict[str, NeuralType]:
+ """Returns definitions of module output ports.
+ """
+ return {
+ "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
+ "output_length": NeuralType(('B',), LengthsType(), optional=True),
+ }
+
+ @typecheck()
+ def forward(self, input, input_length=None, condition=None):
+ # Stack real and imaginary components
+ B, C_in, D, T = input.shape
+
+ if C_in != self.in_channels:
+ raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}')
+
+ # Stack real and imaginary parts
+ input_real_imag = torch.stack([input.real, input.imag], dim=2)
+ input = einops.rearrange(input_real_imag, 'B C RI F T -> B (C RI) F T')
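+        # The resulting tensor has 2 * in_channels real-valued channels, ordered as
+        # [ch0_real, ch0_imag, ch1_real, ch1_imag, ...]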
+
+ # Process using NCSN++
+ output, output_length = self.ncsnpp(input=input, input_length=input_length, condition=condition)
+
+ # Output projection
+ output = self.output_projection(output)
+
+ # Convert to complex-valued signal
+ output = output.reshape(B, 2, self.out_channels, D, T)
+ # Move real/imag dimension to the end
+ output = output.permute(0, 2, 3, 4, 1)
+ output = torch.view_as_complex(output.contiguous())
+
+ return output, output_length
+
+
+class NoiseConditionalScoreNetworkPlusPlus(NeuralModule):
+ """Implementation of Noise Conditional Score Network (NCSN++) architecture.
+
+ References:
+ - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021
+ - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018
+ """
+
+ def __init__(
+ self,
+ nonlinearity: str = "swish",
+ in_channels: int = 2, # number of channels in the input image
+ out_channels: int = 2, # number of channels in the output image
+ channels: Sequence[int] = (128, 128, 256, 256, 256), # number of channels at start + at every resolution
+ num_res_blocks: int = 2,
+ num_resolutions: int = 4,
+ init_scale: float = 1e-5,
+ conditioned_on_time: bool = False,
+ fourier_embedding_scale: float = 16.0,
+ dropout_rate: float = 0.0,
+ pad_time_to: Optional[int] = None,
+ pad_dimension_to: Optional[int] = None,
+ **_,
+ ):
+ # Network topology is a flavor of UNet, example chart for num_resolutions=4
+ #
+ # 1: Image → Image/2 → Image/4 → Image/8
+ # ↓ ↓ ↓ ↓
+ # 2: Hidden → Hidden/2 → Hidden/4 → Hidden/8
+ # ↓ ↓ ↓ ↓
+ # 3: Hidden ← Hidden/2 ← Hidden/4 ← Hidden/8
+ # ↓ ↓ ↓ ↓
+ # 4: Image ← Image/2 ← Image/4 ← Image/8
+
+        # Horizontal arrows in (1) are downsampling
+        # Vertical arrows from (1) to (2) are channel upconversions
+        #
+        # Horizontal arrows in (2) are blocks with downsampling where necessary
+        # Horizontal arrows in (3) are blocks with upsampling where necessary
+        #
+        # Vertical arrows from (2) to (3) are sum (skip) connections, scaled by 1 / sqrt(2)
+        # Vertical arrows from (3) to (4) are channel downconversions
+        # Horizontal arrows in (4) are upsampling and addition
+ super().__init__()
+
+ # same nonlinearity is used throughout the whole network
+ self.activation: torch.nn.Module = activation_registry[nonlinearity]()
+ self.init_scale: float = init_scale
+
+ self.downsample = torch.nn.Upsample(scale_factor=0.5, mode="bilinear")
+ self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.channels = channels
+ self.num_res_blocks = num_res_blocks
+ self.num_resolutions = num_resolutions
+ self.conditioned_on_time = conditioned_on_time
+
+ # padding setup
+ self.pad_time_to = pad_time_to or 2 ** self.num_resolutions
+ self.pad_dimension_to = pad_dimension_to or 2 ** self.num_resolutions
+
+ if self.conditioned_on_time:
+ self.time_embedding = torch.nn.Sequential(
+ GaussianFourierProjection(embedding_size=self.channels[0], scale=fourier_embedding_scale),
+ torch.nn.Linear(self.channels[0] * 2, self.channels[0] * 4),
+ self.activation,
+ torch.nn.Linear(self.channels[0] * 4, self.channels[0] * 4),
+ )
+
+ self.input_pyramid = torch.nn.ModuleList()
+ for ch in self.channels[:-1]:
+ self.input_pyramid.append(torch.nn.Conv2d(in_channels=self.in_channels, out_channels=ch, kernel_size=1))
+
+ # each block takes an image and outputs an image
+ # possibly changes number of channels
+ # output blocks ("reverse" path of the unet) reuse outputs of input blocks ("forward" path)
+        # so great care must be taken with the in/out channels of each block
+ # resolutions are handled in `forward`
+ block_params = {
+ "activation": self.activation,
+ "dropout_rate": dropout_rate,
+ "init_scale": self.init_scale,
+ "diffusion_step_embedding_dim": channels[0] * 4 if self.conditioned_on_time else None,
+ }
+ self.input_blocks = torch.nn.ModuleList()
+ for in_ch, out_ch in zip(self.channels[:-1], self.channels[1:]):
+ for n in range(num_res_blocks):
+ block = ResnetBlockBigGANPlusPlus(in_ch=in_ch if n == 0 else out_ch, out_ch=out_ch, **block_params)
+ self.input_blocks.append(block)
+
+ self.output_blocks = torch.nn.ModuleList()
+ for in_ch, out_ch in zip(reversed(self.channels[1:]), reversed(self.channels[:-1])):
+ for n in reversed(range(num_res_blocks)):
+ block = ResnetBlockBigGANPlusPlus(in_ch=in_ch, out_ch=out_ch if n == 0 else in_ch, **block_params)
+ self.output_blocks.append(block)
+
+ self.projection_blocks = torch.nn.ModuleList()
+ for ch in self.channels[:-1]:
+ self.projection_blocks.append(torch.nn.Conv2d(ch, out_channels, kernel_size=1))
+
+ assert len(self.input_pyramid) == self.num_resolutions
+ assert len(self.input_blocks) == self.num_resolutions * self.num_res_blocks
+ assert len(self.output_blocks) == self.num_resolutions * self.num_res_blocks
+ assert len(self.projection_blocks) == self.num_resolutions
+
+ self.init_weights_()
+
+ logging.debug('Initialized %s with', self.__class__.__name__)
+ logging.debug('\tin_channels: %s', self.in_channels)
+ logging.debug('\tout_channels: %s', self.out_channels)
+ logging.debug('\tchannels: %s', self.channels)
+ logging.debug('\tnum_res_blocks: %s', self.num_res_blocks)
+ logging.debug('\tnum_resolutions: %s', self.num_resolutions)
+ logging.debug('\tconditioned_on_time: %s', self.conditioned_on_time)
+ logging.debug('\tpad_time_to: %s', self.pad_time_to)
+ logging.debug('\tpad_dimension_to: %s', self.pad_dimension_to)
+
+ def init_weights_(self):
+ for module in self.modules():
+ if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+
+ # torch.nn submodules with scaled init
+ for module in self.projection_blocks:
+ torch.nn.init.xavier_uniform_(module.weight, gain=self.init_scale)
+
+ # non-torch.nn submodules can have their own init schemes
+ for module in self.modules():
+ if module is self:
+ continue
+
+ if hasattr(module, "init_weights_"):
+ module.init_weights_()
+
+ @typecheck(
+ input_types={"input": NeuralType(('B', 'C', 'D', 'T')),},
+ output_types={"output": NeuralType(('B', 'C', 'D', 'T')),},
+ )
+ def pad_input(self, input: torch.Tensor) -> torch.Tensor:
+ """Pad input tensor to match the required dimensions across `T` and `D`.
+ """
+ *_, D, T = input.shape
+ output = input
+
+ # padding across time
+ if T % self.pad_time_to != 0:
+ output = F.pad(output, (0, self.pad_time_to - T % self.pad_time_to))
+
+ # padding across dimension
+ if D % self.pad_dimension_to != 0:
+ output = F.pad(output, (0, 0, 0, self.pad_dimension_to - D % self.pad_dimension_to))
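+        # e.g., with the default num_resolutions=4 and no explicit pad_* arguments,
+        # both D and T are padded up to the next multiple of 16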
+
+ return output
+
+ @property
+ def input_types(self) -> Dict[str, NeuralType]:
+        """Returns definitions of module input ports.
+ """
+ return {
+ "input": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "input_length": NeuralType(('B',), LengthsType(), optional=True),
+ "condition": NeuralType(('B',), FloatType(), optional=True),
+ }
+
+ @property
+ def output_types(self) -> Dict[str, NeuralType]:
+ """Returns definitions of module output ports.
+ """
+ return {
+ "output": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "output_length": NeuralType(('B',), LengthsType(), optional=True),
+ }
+
+ @typecheck()
+ def forward(
+ self, *, input: torch.Tensor, input_length: Optional[torch.Tensor], condition: Optional[torch.Tensor] = None
+ ):
+ """Forward pass of the model.
+
+ Args:
+            input: input tensor, shape (B, C, D, T)
+ input_length: length of the valid time steps for each example in the batch, shape (B,)
+ condition: scalar condition (time) for the model, will be embedded using `self.time_embedding`
+ """
+ assert input.shape[1] == self.in_channels
+
+ # apply padding at the input
+ *_, D, T = input.shape
+ input = self.pad_input(input=input)
+
+ if input_length is None:
+ # assume all time frames are valid
+ input_length = torch.LongTensor([input.shape[-1]] * input.shape[0]).to(input.device)
+
+ lengths = input_length
+
+ if condition is not None:
+ if len(condition.shape) != 1:
+ raise ValueError(
+                    f"Expected condition to be a 1-dim tensor, got a {len(condition.shape)}-dim tensor of shape {tuple(condition.shape)}"
+ )
+ if condition.shape[0] != input.shape[0]:
+ raise ValueError(
+ f"Condition {tuple(condition.shape)} and input {tuple(input.shape)} should match along the batch dimension"
+ )
+
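+            # embed the (scalar) log-time using Gaussian Fourier features followed by an MLP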
+ condition = self.time_embedding(torch.log(condition))
+
+ # downsample and project input image to add later in the downsampling path
+ pyramid = [input]
+ for resolution_num in range(self.num_resolutions - 1):
+ pyramid.append(self.downsample(pyramid[-1]))
+ pyramid = [block(image) for image, block in zip(pyramid, self.input_pyramid)]
+
+ # downsampling path
+ history = []
+ hidden = torch.zeros_like(pyramid[0])
+ input_blocks = iter(self.input_blocks)
+ for resolution_num, image in enumerate(pyramid):
+ hidden = (hidden + image) / math.sqrt(2.0)
+ hidden = mask_sequence_tensor(hidden, lengths)
+
+ for _ in range(self.num_res_blocks):
+ hidden = next(input_blocks)(hidden, condition)
+ hidden = mask_sequence_tensor(hidden, lengths)
+ history.append(hidden)
+
+ final_resolution = resolution_num == self.num_resolutions - 1
+ if not final_resolution:
+ hidden = self.downsample(hidden)
+ lengths = (lengths / 2).ceil().long()
+
+ # upsampling path
+ to_project = []
+ for residual, block in zip(reversed(history), self.output_blocks):
+ if hidden.shape != residual.shape:
+ to_project.append(hidden)
+ hidden = self.upsample(hidden)
+ lengths = (lengths * 2).long()
+
+ hidden = (hidden + residual) / math.sqrt(2.0)
+ hidden = block(hidden, condition)
+ hidden = mask_sequence_tensor(hidden, lengths)
+
+ to_project.append(hidden)
+
+ # projecting to images
+ images = []
+ for tensor, projection in zip(to_project, reversed(self.projection_blocks)):
+ image = projection(tensor)
+ images.append(F.interpolate(image, size=input.shape[-2:])) # TODO write this loop using self.upsample
+
+ result = sum(images)
+
+ assert result.shape[-2:] == input.shape[-2:]
+
+ # remove padding
+ result = result[:, :, :D, :T]
+ return result, input_length
+
+
+class GaussianFourierProjection(NeuralModule):
+ """Gaussian Fourier embeddings for input scalars.
+
+ The input scalars are typically time or noise levels.
+ """
+
+ def __init__(self, embedding_size: int = 256, scale: float = 1.0):
+ super().__init__()
+ self.W = torch.nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
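+        # Fixed (non-trainable) random frequencies; the output embedding has size 2 * embedding_size
+        # (sine and cosine components)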
+
+ @property
+ def input_types(self) -> Dict[str, NeuralType]:
+        """Returns definitions of module input ports.
+ """
+ return {
+ "input": NeuralType(('B',), FloatType()),
+ }
+
+ @property
+ def output_types(self) -> Dict[str, NeuralType]:
+ """Returns definitions of module output ports.
+ """
+ return {
+ "output": NeuralType(('B', 'D'), VoidType()),
+ }
+
+ def forward(self, input):
+ x_proj = input[:, None] * self.W[None, :] * 2 * math.pi
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+
+class ResnetBlockBigGANPlusPlus(torch.nn.Module):
+ """Implementation of a ResNet block for the BigGAN model.
+
+ References:
+ - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021
+ - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018
+ """
+
+ def __init__(
+ self,
+ activation: torch.nn.Module,
+ in_ch: int,
+ out_ch: int,
+ diffusion_step_embedding_dim: Optional[int] = None,
+ init_scale: float = 1e-5,
+ dropout_rate: float = 0.1,
+ in_num_groups: Optional[int] = None,
+ out_num_groups: Optional[int] = None,
+ eps: float = 1e-6,
+ ):
+ """
+ Args:
+ activation (torch.nn.Module): activation layer (ReLU, SiLU, etc)
+ in_ch (int): number of channels in the input image
+            out_ch (int): number of channels in the output image
+            diffusion_step_embedding_dim (int, optional): dimension of diffusion timestep embedding. Defaults to None (no embedding).
+            dropout_rate (float, optional): dropout rate. Defaults to 0.1.
+            init_scale (float, optional): scaling for weight initialization. Defaults to 1e-5.
+ in_num_groups (int, optional): num_groups in the first GroupNorm. Defaults to min(in_ch // 4, 32)
+ out_num_groups (int, optional): num_groups in the second GroupNorm. Defaults to min(out_ch // 4, 32)
+ eps (float, optional): eps parameter of GroupNorms. Defaults to 1e-6.
+ """
+ super().__init__()
+ in_num_groups = in_num_groups or min(in_ch // 4, 32)
+ out_num_groups = out_num_groups or min(out_ch // 4, 32)
+
+ self.init_scale = init_scale
+
+ self.input_block = torch.nn.Sequential(
+ torch.nn.GroupNorm(num_groups=in_num_groups, num_channels=in_ch, eps=eps), activation,
+ )
+
+ self.middle_conv = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
+ if diffusion_step_embedding_dim is not None:
+ self.diffusion_step_projection = torch.nn.Sequential(
+ activation,
+ torch.nn.Linear(diffusion_step_embedding_dim, out_ch),
+ einops.layers.torch.Rearrange("batch dim -> batch dim 1 1"),
+ )
+
+ self.output_block = torch.nn.Sequential(
+ torch.nn.GroupNorm(num_groups=out_num_groups, num_channels=out_ch, eps=eps),
+ activation,
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1),
+ )
+
+ if in_ch != out_ch:
+ self.residual_projection = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)
+
+ self.act = activation
+ self.in_ch = in_ch
+ self.out_ch = out_ch
+
+ self.init_weights_()
+
+ def init_weights_(self):
+ """Weight initialization
+ """
+ for module in self.modules():
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+
+ # a single Conv2d is initialized with gain
+ torch.nn.init.xavier_uniform_(self.output_block[-1].weight, gain=self.init_scale)
+
+ def forward(self, x: torch.Tensor, diffusion_time_embedding: Optional[torch.Tensor] = None):
+ """Forward pass of the model.
+
+ Args:
+ x: input tensor
+ diffusion_time_embedding: embedding of the diffusion time step
+
+ Returns:
+ Output tensor
+ """
+ h = self.input_block(x)
+ h = self.middle_conv(h)
+
+ if diffusion_time_embedding is not None:
+ h = h + self.diffusion_step_projection(diffusion_time_embedding)
+
+ h = self.output_block(h)
+
+ if x.shape != h.shape: # matching number of channels
+ x = self.residual_projection(x)
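+        # residual connection scaled by 1 / sqrt(2) to keep the output variance approximately constant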
+ return (x + h) / math.sqrt(2.0)
+
+
+class PredictorCorrectorSampler(NeuralModule):
+ """Predictor-Corrector sampler for the reverse SDE.
+
+ Args:
+ sde: forward SDE
+ score_estimator: neural score estimator
+ predictor: predictor for the reverse process
+ corrector: corrector for the reverse process
+ num_steps: number of time steps for the reverse process
+ num_corrector_steps: number of corrector steps
+ time_max: maximum time
+ time_min: minimum time
+ snr: SNR for Annealed Langevin Dynamics
+ output_type: type of the output ('state' for the final state, or 'mean' for the mean of the final state)
+
+ References:
+ - Song et al., Score-based generative modeling through stochastic differential equations, 2021
+ """
+
+ def __init__(
+ self,
+ sde,
+ score_estimator,
+ predictor: str = 'reverse_diffusion',
+ corrector: str = 'annealed_langevin_dynamics',
+ num_steps: int = 50,
+ num_corrector_steps: int = 1,
+ time_max: Optional[float] = None,
+ time_min: Optional[float] = None,
+ snr: float = 0.5,
+ output_type: str = 'mean',
+ ):
+ super().__init__()
+ # Create a copy of SDE
+ self.sde = sde.copy()
+
+ # Update SDE parameters for sampling
+ if time_max is not None:
+ self.sde.time_max = time_max
+ logging.info('sde.time_max set to: %s', self.sde.time_max)
+
+ if time_min is not None:
+ self.sde.time_min = time_min
+ logging.info('sde.time_min set to: %s', self.sde.time_min)
+
+ self.sde.num_steps = num_steps
+ logging.info('sde.num_steps set to: %s', self.sde.num_steps)
+
+ # Update local values
+ self.time_max = self.sde.time_max
+ self.time_min = self.sde.time_min
+ self.num_steps = self.sde.num_steps
+
+ # Predictor setup
+ if predictor == 'reverse_diffusion':
+ self.predictor = ReverseDiffusionPredictor(sde=self.sde, score_estimator=score_estimator)
+ else:
+ raise RuntimeError(f'Unexpected predictor: {predictor}')
+
+ # Corrector setup
+ if corrector == 'annealed_langevin_dynamics':
+ self.corrector = AnnealedLangevinDynamics(
+ sde=self.sde, score_estimator=score_estimator, snr=snr, num_steps=num_corrector_steps
+ )
+ else:
+ raise RuntimeError(f'Unexpected corrector: {corrector}')
+
+ if output_type not in ['mean', 'state']:
+ raise ValueError(f'Unexpected output type: {output_type}')
+ self.output_type = output_type
+
+ logging.debug('Initialized %s with', self.__class__.__name__)
+ logging.debug('\tpredictor: %s', predictor)
+ logging.debug('\tcorrector: %s', corrector)
+ logging.debug('\tnum_steps: %s', self.num_steps)
+ logging.debug('\ttime_min: %s', self.time_min)
+ logging.debug('\ttime_max: %s', self.time_max)
+ logging.debug('\tnum_corrector_steps: %s', num_corrector_steps)
+ logging.debug('\tsnr: %s', snr)
+ logging.debug('\toutput_type: %s', self.output_type)
+
+ @typecheck(
+ input_types={
+ "prior_mean": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
+ "score_condition": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType(), optional=True),
+ "state_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ },
+ output_types={
+ "sample": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
+ "state_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ },
+ )
+ @torch.inference_mode()
+ def forward(
+        self,
+        prior_mean: torch.Tensor,
+        score_condition: Optional[torch.Tensor] = None,
+        state_length: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Takes prior (noisy) mean and generates a sample by solving the reverse SDE.
+
+ Args:
+ prior_mean: mean for the prior distribution, e.g., noisy observation
+ score_condition: conditioning for the score estimator
+ state_length: length of the valid time steps for each example in the batch
+
+ Returns:
+            Generated `sample` and the corresponding `state_length`.
+ """
+ # Sample from the prior distribution
+ state = self.sde.prior_sampling(prior_mean=prior_mean)
+
+ if state_length is not None:
+ state = mask_sequence_tensor(state, state_length)
+
+ # Time steps for evaluation
+ time_steps = torch.linspace(self.time_max, self.time_min, self.num_steps, device=state.device)
+
+ # Sampling
+ for t in time_steps:
+ # time steps for the whole batch
+ time = t * torch.ones(state.shape[0], device=state.device)
+
+ # corrector step
+ state, _ = self.corrector(
+ state=state, time=time, score_condition=score_condition, state_length=state_length
+ )
+
+ # predictor step
+ state, state_mean = self.predictor(
+ state=state,
+ time=time,
+ score_condition=score_condition,
+ prior_mean=prior_mean,
+ state_length=state_length,
+ )
+
+ # Final output
+ if self.output_type == 'state':
+ sample = state
+ elif self.output_type == 'mean':
+ sample = state_mean
+ else:
+ raise RuntimeError(f'Unexpected output type: {self.output_type}')
+
+ if state_length is not None:
+ sample = mask_sequence_tensor(sample, state_length)
+
+ return sample, state_length
+
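+# A hypothetical end-to-end sketch wiring the sampler to the components defined in this file
+# (names, channel counts, and the `noisy_spec`/`lengths` tensors below are illustrative, not taken from a recipe):
+#
+#   estimator = SpectrogramNoiseConditionalScoreNetworkPlusPlus(
+#       in_channels=2, out_channels=1, conditioned_on_time=True
+#   )
+#   sde = OrnsteinUhlenbeckVarianceExplodingSDE(stiffness=1.5, std_min=0.05, std_max=0.5)
+#   sampler = PredictorCorrectorSampler(sde=sde, score_estimator=estimator, num_steps=50)
+#   sample, sample_length = sampler(
+#       prior_mean=noisy_spec, score_condition=noisy_spec, state_length=lengths
+#   )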
+
+class Predictor(torch.nn.Module, ABC):
+ """Predictor for the reverse process.
+
+ Args:
+ sde: forward SDE
+ score_estimator: neural score estimator
+ """
+
+ def __init__(self, sde, score_estimator):
+ super().__init__()
+ self.reverse_sde = ReverseStochasticDifferentialEquation(sde=sde, score_estimator=score_estimator)
+
+ @abstractmethod
+ @torch.inference_mode()
+ def forward(
+ self,
+ *,
+ state: torch.Tensor,
+ time: torch.Tensor,
+ score_condition: Optional[torch.Tensor] = None,
+ state_length: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ """Predict the next state of the reverse process.
+
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ score_condition: conditioning for the score estimator
+ state_length: length of the valid time steps for each example in the batch
+
+ Returns:
+ New state and mean.
+ """
+ pass
+
+
+class ReverseDiffusionPredictor(Predictor):
+ """Predict the next state of the reverse process using the reverse diffusion process.
+
+ Args:
+ sde: forward SDE
+ score_estimator: neural score estimator
+ """
+
+ def __init__(self, sde, score_estimator):
+ super().__init__(sde=sde, score_estimator=score_estimator)
+
+ @torch.inference_mode()
+ def forward(self, *, state, time, score_condition=None, state_length=None, **kwargs):
+ """Predict the next state of the reverse process using the reverse diffusion process.
+
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ score_condition: conditioning for the score estimator
+ state_length: length of the valid time steps for each example in the batch
+
+ Returns:
+ New state and mean of the diffusion process.
+ """
+ drift, diffusion = self.reverse_sde.discretize(
+ state=state, time=time, score_condition=score_condition, state_length=state_length, **kwargs
+ )
+
+ # Generate a random sample from a standard normal distribution
+ z_norm = torch.randn_like(state)
+
+ # Compute the mean of the next state
+ mean = state - drift
+
+ # Compute new state by sampling
+ new_state = mean + diffusion * z_norm
+
+ if state_length is not None:
+ new_state = mask_sequence_tensor(new_state, state_length)
+ mean = mask_sequence_tensor(mean, state_length)
+
+ return new_state, mean
+
+
+class Corrector(NeuralModule, ABC):
+ """Corrector for the reverse process.
+
+ Args:
+ sde: forward SDE
+ score_estimator: neural score estimator
+ snr: SNR for Annealed Langevin Dynamics
+ num_steps: number of steps for the corrector
+ """
+
+ def __init__(
+ self,
+        sde: StochasticDifferentialEquation,
+        score_estimator: NeuralModule,
+ snr: float,
+ num_steps: int,
+ ):
+ super().__init__()
+ self.sde = sde
+ self.score_estimator = score_estimator
+ self.snr = snr
+ self.num_steps = num_steps
+
+ logging.debug('Initialized %s with', self.__class__.__name__)
+ logging.debug('\tsnr: %s', snr)
+ logging.debug('\tnum_steps: %s', num_steps)
+
+ @abstractmethod
+ @typecheck(
+ input_types={
+ "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
+ "time": NeuralType(tuple('B'), FloatType()),
+ "score_condition": NeuralType(('B', 'C', 'D', 'T'), VoidType(), optional=True),
+ "state_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+ },
+ output_types={"state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),},
+ )
+ @torch.inference_mode()
+ def forward(self, state, time, score_condition=None, state_length=None):
+ """
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ score_condition: conditioning for the score estimator
+ state_length: length of the valid time steps for each example in the batch
+
+ Returns:
+ New state and mean.
+ """
+ pass
+
+
+class AnnealedLangevinDynamics(Corrector):
+ """Annealed Langevin Dynamics for the reverse process.
+
+ References:
+ - Song et al., Score-based generative modeling through stochastic differential equations, 2021
+ """
+
+ def __init__(self, sde, **kwargs):
+ if not isinstance(sde, OrnsteinUhlenbeckVarianceExplodingSDE):
+ raise ValueError(f'Expected an instance of OrnsteinUhlenbeckVarianceExplodingSDE, got {type(sde)}')
+ super().__init__(sde=sde, **kwargs)
+
+ @torch.inference_mode()
+ def forward(self, state, time, score_condition=None, state_length=None):
+ """Correct the state using Annealed Langevin Dynamics.
+
+ Args:
+ state: current state of the process, shape (B, C, D, T)
+ time: current time of the process, shape (B,)
+ score_condition: conditioning for the score estimator
+ state_length: length of the valid time steps for each example in the batch
+
+ Returns:
+ New state and mean of the diffusion process.
+
+ References:
+ Alg. 4 in http://arxiv.org/abs/2011.13456
+ """
+ # Compute the standard deviation of the diffusion process
+ std = self.sde.perturb_kernel_std(time=time)
+ # View as [B, 1, 1, 1]
+ std = std.view(-1, *([1] * (state.dim() - 1)))
+
+ for i in range(self.num_steps):
+ # prepare input for the score estimator, concatenate conditioning along the channel dimension
+ score_input = state if score_condition is None else torch.cat([state, score_condition], dim=1)
+
+ # calculate the score
+ score, _ = self.score_estimator(input=score_input, input_length=state_length, condition=time)
+
+ # generate a sample from a standard normal distribution
+ z_norm = torch.randn_like(state)
+
+ # compute the step size
+            # note: this is slightly different from the paper, where the ratio ||z_norm||_2 / ||score||_2 is used in place of std
+ step_size = 2 * (self.snr * std).pow(2)
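+            # e.g., snr=0.5 and std=0.1 give a step size of 2 * (0.05)^2 = 5e-3 (illustrative values only)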
+
+ # update the mean
+ mean = state + step_size * score
+
+ # update the state
+ state = mean + z_norm * torch.sqrt(step_size * 2)
+
+ if state_length is not None:
+ state = mask_sequence_tensor(state, state_length)
+ mean = mask_sequence_tensor(mean, state_length)
+
+ return state, mean
diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py
index 7023f57652b5..6ea4314ab71f 100644
--- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py
+++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py
@@ -1674,7 +1674,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# megatron_amp_O2 is not yet supported in diffusion models
self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)
- if self.cfg.precision in ['16', 16, 'bf16']:
+ if self.megatron_amp_O2 and self.cfg.precision in ['16', 16, 'bf16']:
self.model_parallel_config.enable_autocast = False
if not hasattr(self.cfg.unet_config, 'unet_precision') or not '16' in str(
self.cfg.unet_config.unet_precision
diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py
index c92980d904f6..3fcab2127f4f 100644
--- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py
+++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
+import os
from inspect import isfunction
import torch
@@ -21,6 +22,13 @@
from torch import einsum, nn
from torch._dynamo import disable
+if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1":
+ from nemo.gn_native import GroupNormNormlization as GroupNorm
+else:
+ from apex.contrib.group_norm import GroupNorm
+
+from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP
+
from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
AdapterName,
@@ -96,13 +104,19 @@ def forward(self, x):
class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
- project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
- self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
+ if use_te:
+ activation = 'gelu' if not glu else 'geglu'
+ # TODO: more parameters to be confirmed, dropout, seq_length
+ self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,)
+ else:
+ norm = nn.LayerNorm(dim)
+ project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+ self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
@@ -225,10 +239,15 @@ def __init__(
dropout=0.0,
use_flash_attention=False,
lora_network_alpha=None,
+ use_te=False,
):
super().__init__()
self.inner_dim = dim_head * heads
+ if context_dim is None:
+ self.is_self_attn = True
+ else:
+ self.is_self_attn = False # cross-attention
context_dim = default(context_dim, query_dim)
# make attention part be aware of self-attention/cross-attention
self.context_dim = context_dim
@@ -238,10 +257,19 @@ def __init__(
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
+ self.use_te = use_te
+ if use_te:
+            return_layernorm_output = self.is_self_attn
+ self.norm_to_q = LayerNormLinear(
+ query_dim, self.inner_dim, bias=False, return_layernorm_output=return_layernorm_output
+ )
+ else:
+ self.norm = nn.LayerNorm(query_dim)
+ self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False)
+
self.to_out = nn.Sequential(
LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout)
)
@@ -262,8 +290,18 @@ def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_cr
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
- q = self.to_q(x)
- context = default(context, x)
+ if self.use_te:
+ q_out = self.norm_to_q(x)
+ if self.is_self_attn:
+ q, ln_out = q_out
+ context = default(context, ln_out)
+ else:
+ q = q_out
+ context = default(context, x)
+ else:
+ x = self.norm(x)
+ q = self.to_q(x)
+ context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
@@ -351,6 +389,7 @@ def __init__(
use_flash_attention=False,
disable_self_attn=False,
lora_network_alpha=None,
+ use_te=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
@@ -362,8 +401,9 @@ def __init__(
use_flash_attention=use_flash_attention,
context_dim=context_dim if self.disable_self_attn else None,
lora_network_alpha=lora_network_alpha,
+ use_te=use_te,
) # is a self-attention
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
@@ -372,10 +412,8 @@ def __init__(
dropout=dropout,
use_flash_attention=use_flash_attention,
lora_network_alpha=lora_network_alpha,
+ use_te=use_te,
) # is self-attn if context is none
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
- self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
@@ -397,15 +435,15 @@ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_at
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
x = (
self.attn1(
- self.norm1(x),
+ x,
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
)
+ x
)
- x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
- x = self.ff(self.norm3(x)) + x
+ x = self.attn2(x, context=context, additional_tokens=additional_tokens) + x
+ x = self.ff(x) + x
return x
@@ -431,6 +469,7 @@ def __init__(
use_checkpoint=False,
use_flash_attention=False,
lora_network_alpha=None,
+ use_te=False,
):
super().__init__()
logging.info(
@@ -473,6 +512,7 @@ def __init__(
use_flash_attention=use_flash_attention,
disable_self_attn=disable_self_attn,
lora_network_alpha=lora_network_alpha,
+ use_te=use_te,
)
for d in range(depth)
]
diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
index 5ff0f6aa8a8a..b610f921a22a 100644
--- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
+++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
+import os
+import re
from abc import abstractmethod
from collections.abc import Iterable
+from contextlib import nullcontext
from functools import partial
from typing import Iterable
@@ -22,6 +25,9 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
+
+# FP8 related import
+import transformer_engine
from apex.contrib.group_norm import GroupNorm
from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer
@@ -62,6 +68,34 @@ def convert_module_to_fp32(module, enable_norm_layers=False):
convert_module_to_dtype(module, torch.float32, enable_norm_layers)
+def convert_module_to_fp8(model):
+ def _set_module(model, submodule_key, module):
+ tokens = submodule_key.split('.')
+ sub_tokens = tokens[:-1]
+ cur_mod = model
+ for s in sub_tokens:
+ cur_mod = getattr(cur_mod, s)
+ setattr(cur_mod, tokens[-1], module)
+
+ import copy
+
+ from transformer_engine.pytorch.module import Linear as te_Linear
+
+ for n, v in model.named_modules():
+ if isinstance(v, torch.nn.Linear):
+ # if n in ['class_embed', 'bbox_embed.layers.0', 'bbox_embed.layers.1', 'bbox_embed.layers.2']: continue
+ logging.info(f'[INFO] Replace Linear: {n}, weight: {v.weight.shape}')
+ if v.bias is None:
+ is_bias = False
+ else:
+ is_bias = True
+ newlinear = te_Linear(v.in_features, v.out_features, bias=is_bias)
+ newlinear.weight = copy.deepcopy(v.weight)
+ if v.bias is not None:
+ newlinear.bias = copy.deepcopy(v.bias)
+ _set_module(model, n, newlinear)
+
+
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
@@ -553,6 +587,7 @@ def __init__(
unet_precision: str = "fp32",
lora_network_alpha=None,
timesteps=1000,
+ use_te_fp8: bool = False,
):
super().__init__()
from omegaconf.listconfig import ListConfig
@@ -663,6 +698,7 @@ def __init__(
input_block_chans = [model_channels]
ch = model_channels
ds = 1
+ self.use_te_fp8 = use_te_fp8
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
@@ -713,6 +749,7 @@ def __init__(
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
lora_network_alpha=lora_network_alpha,
+ use_te=self.use_te_fp8,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -778,6 +815,7 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
+ use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
),
ResBlock(
@@ -844,6 +882,7 @@ def __init__(
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
lora_network_alpha=lora_network_alpha,
+ use_te=self.use_te_fp8,
)
)
if level and i == self.num_res_blocks[level]:
@@ -899,6 +938,34 @@ def __init__(
self.convert_to_fp16()
elif unet_precision == 'fp16':
self.convert_to_fp16(enable_norm_layers=True)
+ elif self.use_te_fp8:
+ assert unet_precision != 'fp16', "fp8 training can't work with fp16 O2 amp recipe"
+ convert_module_to_fp8(self)
+
+ fp8_margin = int(os.getenv("FP8_MARGIN", '0'))
+ fp8_interval = int(os.getenv("FP8_INTERVAL", '1'))
+ fp8_format = os.getenv("FP8_FORMAT", "hybrid")
+ fp8_amax_history_len = int(os.getenv("FP8_HISTORY_LEN", '1024'))
+ fp8_amax_compute_algo = os.getenv("FP8_COMPUTE_ALGO", 'max')
+ fp8_wgrad = os.getenv("FP8_WGRAD", '1') == '1'
+
+ fp8_format_dict = {
+ 'hybrid': transformer_engine.common.recipe.Format.HYBRID,
+ 'e4m3': transformer_engine.common.recipe.Format.E4M3,
+ }
+ fp8_format = fp8_format_dict[fp8_format]
+
+ self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
+ margin=fp8_margin,
+ interval=fp8_interval,
+ fp8_format=fp8_format,
+ amax_history_len=fp8_amax_history_len,
+ amax_compute_algo=fp8_amax_compute_algo,
+ override_linear_precision=(False, False, not fp8_wgrad),
+ )
+ old_state_dict = self.state_dict()
+ new_state_dict = self.te_fp8_key_mapping(old_state_dict)
+ self.load_state_dict(new_state_dict, strict=False)
self.unet_precision = unet_precision
@@ -1000,8 +1067,65 @@ def _sdxl_embedding_mapping(self, sdxl_dict):
res_dict[new_key_] = value_
return res_dict
+ def _legacy_unet_ckpt_mapping(self, unet_dict):
+ new_dict = {}
+ key_map = {
+ 'transformer_blocks.0.norm1.weight': 'transformer_blocks.0.attn1.norm.weight',
+ 'transformer_blocks.0.norm1.bias': 'transformer_blocks.0.attn1.norm.bias',
+ 'transformer_blocks.0.norm2.weight': 'transformer_blocks.0.attn2.norm.weight',
+ 'transformer_blocks.0.norm2.bias': 'transformer_blocks.0.attn2.norm.bias',
+ 'transformer_blocks.0.norm3.weight': 'transformer_blocks.0.ff.net.0.weight',
+ 'transformer_blocks.0.norm3.bias': 'transformer_blocks.0.ff.net.0.bias',
+ 'transformer_blocks.0.ff.net.0.proj.weight': 'transformer_blocks.0.ff.net.1.proj.weight',
+ 'transformer_blocks.0.ff.net.0.proj.bias': 'transformer_blocks.0.ff.net.1.proj.bias',
+ 'transformer_blocks.0.ff.net.2.weight': 'transformer_blocks.0.ff.net.3.weight',
+ 'transformer_blocks.0.ff.net.2.bias': 'transformer_blocks.0.ff.net.3.bias',
+ }
+
+ pattern = re.compile(r'(input_blocks|output_blocks)\.[\d\w]+\.[\d\w]+\.')
+ pattern_middle_block = re.compile(r'middle_block\.[\d\w]+\.')
+ for old_key, value in unet_dict.items():
+ match = pattern.match(old_key)
+ match_middle = pattern_middle_block.match(old_key)
+ if match or match_middle:
+ prefix = match.group(0) if match else match_middle.group(0)
+ suffix = old_key.split('.', 3)[-1] if match else old_key.split('.', 2)[-1]
+ if suffix in key_map:
+ new_key = prefix + key_map[suffix]
+ new_dict[new_key] = value
+ else:
+ new_dict[old_key] = value
+ else:
+ new_dict[old_key] = value
+
+ return new_dict
+
+ def te_fp8_key_mapping(self, unet_dict):
+ new_state_dict = {}
+ for key in unet_dict.keys():
+ if 'extra_state' in key:
+ continue
+
+ ### LayerNormLinear
+ # norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias}
+ # norm_to_q.weight -> to_q.weight
+ new_key = key.replace('attn1.norm.', 'attn1.norm_to_q.layer_norm_')
+ new_key = new_key.replace('attn1.to_q.weight', 'attn1.norm_to_q.weight',)
+ new_key = new_key.replace('attn2.norm.', 'attn2.norm_to_q.layer_norm_')
+ new_key = new_key.replace('attn2.to_q.weight', 'attn2.norm_to_q.weight',)
+
+ ### LayerNormMLP
+ # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias}
+ # ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias}
+ # ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias}
+ new_key = new_key.replace('ff.net.0.', 'ff.net.layer_norm_')
+ new_key = new_key.replace('ff.net.1.proj.', 'ff.net.fc1_')
+ new_key = new_key.replace('ff.net.3.', 'ff.net.fc2_')
+
+ new_state_dict[new_key] = unet_dict[key]
+ return new_state_dict
+
def _state_key_mapping(self, state_dict: dict):
- import re
res_dict = {}
input_dict = {}
@@ -1027,13 +1151,7 @@ def _state_key_mapping(self, state_dict: dict):
mid_dict = self._mid_blocks_mapping(mid_dict)
other_dict = self._other_blocks_mapping(other_dict)
sdxl_dict = self._sdxl_embedding_mapping(sdxl_dict)
- # key_list = state_dict.keys()
- # key_str = " ".join(key_list)
- # for key_, val_ in state_dict.items():
- # key_ = key_.replace("down_blocks", "input_blocks")\
- # .replace("up_blocks", 'output_blocks')
- # res_dict[key_] = val_
res_dict.update(input_dict)
res_dict.update(output_dict)
res_dict.update(mid_dict)
@@ -1046,6 +1164,7 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from
state_dict = self._strip_unet_key_prefix(state_dict)
if not from_NeMo:
state_dict = self._state_key_mapping(state_dict)
+ state_dict = self._legacy_unet_ckpt_mapping(state_dict)
model_state_dict = self.state_dict()
loaded_keys = [k for k in state_dict.keys()]
@@ -1151,7 +1270,7 @@ def convert_to_fp16(self, enable_norm_layers=False):
"""
self.apply(lambda module: convert_module_to_fp16(module=module, enable_norm_layers=enable_norm_layers))
- def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
@@ -1170,7 +1289,6 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
-
if self.unet_precision == "fp16-mixed" or self.unet_precision == "fp16":
x = x.type(torch.float16)
if context is not None:
@@ -1197,6 +1315,13 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
else:
return self.out(h)
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ with transformer_engine.pytorch.fp8_autocast(
+ enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe,
+ ) if self.use_te_fp8 else nullcontext():
+ out = self._forward(x, timesteps, context, y, **kwargs)
+ return out
+
class EncoderUNetModel(nn.Module):
"""
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
index a5e886f3b479..16ded8e2c682 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
@@ -37,6 +37,8 @@
LoraDenseAttentionAdapterConfig,
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
+ LoraUnfusedHto4HAdapterConfig,
+ LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
PromptEncoderAdapterConfig,
@@ -67,7 +69,12 @@ def mcore_register_adapters(self):
Setup NeMo LoRA or IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types(
- [LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_]
+ [
+ LoraUnfusedKQVAdapterConfig._target_,
+ LoraKQVAdapterConfig._target_,
+ LoraDenseAttentionAdapterConfig._target_,
+ InfusedAdapterConfig._target_,
+ ]
)
self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp
if (
@@ -102,12 +109,20 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
# LoRA logic
if self.is_adapter_available():
+ lora_adapter = None
lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
+ lora_unfused_kqv_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_KQV_ADAPTER)
if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
+ lora_adapter = lora_kqv_adapter
+ if lora_unfused_kqv_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_KQV_ADAPTER]['enabled']:
+ assert lora_adapter is None, "Expected only one of lora_kqv_adapter or lora_unfused_kqv_adapter"
+ lora_adapter = lora_unfused_kqv_adapter
+
+ if lora_adapter:
if layernorm_output is not None:
- lora_mixed_qkv = lora_kqv_adapter(layernorm_output)
+ lora_mixed_qkv = lora_adapter(layernorm_output)
else:
- lora_mixed_qkv = lora_kqv_adapter(hidden_states)
+ lora_mixed_qkv = lora_adapter(hidden_states)
mixed_qkv = mixed_qkv + lora_mixed_qkv
@@ -251,7 +266,12 @@ def mcore_register_adapters(self):
Setup NeMo IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types(
- [LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_]
+ [
+ LoraUnfusedHto4HAdapterConfig._target_,
+ LoraHto4HAdapterConfig._target_,
+ Lora4HtoHAdapterConfig._target_,
+ MLPInfusedAdapterConfig._target_,
+ ]
) # only self attn (packed qkv) for now
self.linear_fc1.return_layernorm_output = True # need layernorm output for lora mlp
if (
@@ -274,9 +294,17 @@ def forward(self, hidden_states):
# LoRA logic
if self.is_adapter_available():
- lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
- if lora_linear_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
- lora_output = lora_linear_fc1_adapter(layernorm_output)
+ lora_adapter = None
+ lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
+ lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER)
+ if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
+ lora_adapter = lora_fc1_adapter
+ if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']:
+ assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER"
+ lora_adapter = lora_unfused_fc1_adapter
+
+ if lora_adapter:
+ lora_output = lora_adapter(layernorm_output)
intermediate_parallel = intermediate_parallel + lora_output
if self.config.bias_activation_fusion:
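In both mixins above, the fused and unfused adapters are meant to be mutually exclusive per projection; a hypothetical helper (not part of this PR) that captures the same selection logic:

def pick_single_enabled_adapter(candidates):
    """Return the single enabled adapter from (adapter, enabled) pairs, or None if none is enabled."""
    enabled = [adapter for adapter, is_enabled in candidates if adapter is not None and is_enabled]
    assert len(enabled) <= 1, "Expected at most one enabled LoRA adapter per projection"
    return enabled[0] if enabled else None

# e.g. lora_adapter = pick_single_enabled_adapter(
#     [(lora_kqv_adapter, kqv_enabled), (lora_unfused_kqv_adapter, unfused_enabled)]
# )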
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
index 5037bb1b3634..2a5372d11ab5 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
@@ -75,11 +75,13 @@ class AdapterName(str, enum.Enum):
POST_ATTN_ADAPTER = 'adapter_2'
PTUNING_ADAPTER = "ptuning_adapter"
LORA_KQV_ADAPTER = "lora_kqv_adapter"
+ LORA_UNFUSED_KQV_ADAPTER = "lora_unfused_kqv_adapter"
LORA_KV_ADAPTER = "lora_kv_adapter"
LORA_Q_ADAPTER = "lora_q_adapter"
MM_LINEAR_ADAPTER = "mm_linear_adapter"
LORA_DENSE_ATTENTION_ADAPTER = "lora_dense_attention_adapter"
LORA_Hto4H_ADAPTER = "lora_hto4h_adapter"
+ LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter"
LORA_4HtoH_ADAPTER = "lora_4htoh_adapter"
MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter"
PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter"
@@ -457,6 +459,183 @@ class Lora4HtoHAdapterConfig(ParallelLinearAdapterConfig):
input_is_parallel: bool = True
+class LoraUnfusedHto4HAdapter(nn.Module, AdapterModuleUtil):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ dim: int,
+ activation: str = 'swish',
+ norm_position: Optional[str] = 'post',
+ norm_type: Optional[str] = 'mixedfusedlayernorm',
+ column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise.
+ row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
+ gather_output: bool = True,
+ input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
+ dropout: float = 0.0,
+ model_parallel_config: Optional[ModelParallelConfig] = None,
+ alpha: float | None = None,
+ dropout_position: str = 'post',
+ a2a_experimental: bool = False, # TODO: should rename this or make it a default feature
+ **kwargs,
+ ):
+ super().__init__()
+ self.gate_adapter = ParallelLinearAdapter(
+ in_features,
+ out_features // 2,
+ dim,
+ activation,
+ norm_position,
+ norm_type,
+ column_init_method,
+ row_init_method,
+ gather_output,
+ input_is_parallel,
+ dropout,
+ model_parallel_config,
+ alpha,
+ dropout_position,
+ a2a_experimental,
+ )
+ self.up_adapter = ParallelLinearAdapter(
+ in_features,
+ out_features // 2,
+ dim,
+ activation,
+ norm_position,
+ norm_type,
+ column_init_method,
+ row_init_method,
+ gather_output,
+ input_is_parallel,
+ dropout,
+ model_parallel_config,
+ alpha,
+ dropout_position,
+ a2a_experimental,
+ )
+
+ def forward(self, x):
+ gate_x = self.gate_adapter(x)
+ up_x = self.up_adapter(x)
+ x = torch.concat([gate_x, up_x], dim=2)
+ return x
+
+
+@dataclass
+class LoraUnfusedHto4HAdapterConfig(ParallelLinearAdapterConfig):
+ _target_: str = "{0}.{1}".format(LoraUnfusedHto4HAdapter.__module__, LoraUnfusedHto4HAdapter.__name__)
+
+
+class LoraUnfusedKQVAdapter(nn.Module, AdapterModuleUtil):
+ def __init__(
+ self,
+ in_features: int,
+ dim: int,
+ activation: str = 'swish',
+ norm_position: Optional[str] = 'post',
+ norm_type: Optional[str] = 'mixedfusedlayernorm',
+ column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise.
+ row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
+ gather_output: bool = True,
+ input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
+ dropout: float = 0.0,
+ model_parallel_config: Optional[ModelParallelConfig] = None,
+ alpha: float | None = None,
+ dropout_position: str = 'post',
+ a2a_experimental: bool = False, # TODO: should rename this or make it a default feature
+ num_query_groups: Optional[int] = None,
+ kv_channels: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__()
+ if num_query_groups is not None and kv_channels is not None:
+ out_features = kv_channels * num_query_groups
+ else:
+ out_features = in_features
+
+ self.q_adapter = ParallelLinearAdapter(
+ in_features,
+ in_features,
+ dim,
+ activation,
+ norm_position,
+ norm_type,
+ column_init_method,
+ row_init_method,
+ gather_output,
+ input_is_parallel,
+ dropout,
+ model_parallel_config,
+ alpha,
+ dropout_position,
+ a2a_experimental,
+ )
+
+ self.k_adapter = ParallelLinearAdapter(
+ in_features,
+ out_features,
+ dim,
+ activation,
+ norm_position,
+ norm_type,
+ column_init_method,
+ row_init_method,
+ gather_output,
+ input_is_parallel,
+ dropout,
+ model_parallel_config,
+ alpha,
+ dropout_position,
+ a2a_experimental,
+ )
+ self.v_adapter = ParallelLinearAdapter(
+ in_features,
+ out_features,
+ dim,
+ activation,
+ norm_position,
+ norm_type,
+ column_init_method,
+ row_init_method,
+ gather_output,
+ input_is_parallel,
+ dropout,
+ model_parallel_config,
+ alpha,
+ dropout_position,
+ a2a_experimental,
+ )
+
+ def forward(self, x):
+ qx = self.q_adapter(x)
+ kx = self.k_adapter(x)
+ vx = self.v_adapter(x)
+ x = torch.concat([qx, kx, vx], dim=2)
+ return x
+
+
+@dataclass
+class LoraUnfusedKQVAdapterConfig(AdapterConfig):
+ in_features: int
+ dim: int
+ activation: str = 'swish'
+ norm_position: Optional[str] = 'post'
+ norm_type: Optional[str] = 'mixedfusedlayernorm'
+ column_init_method: str = 'xavier'
+ row_init_method: str = 'zero'
+ gather_output: bool = True
+ input_is_parallel: bool = False
+ dropout: float = 0.0
+ dropout_position: str = 'post'
+ alpha: float | None = None
+ network_alpha: int | None = None
+ a2a_experimental: bool = False
+ num_query_groups: Optional[int] = None
+ kv_channels: Optional[int] = None
+ _target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__)
+
+
class PromptEncoderAdapter(nn.Module, AdapterModuleUtil):
"""
The Tensor Parallel MLP prompt encoder network that is used to generate the virtual
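As a sanity check on the shapes produced by LoraUnfusedKQVAdapter, here is a small sketch with plain nn.Linear stand-ins and no tensor parallelism; the hidden size, rank, and GQA settings are made up for illustration:

import torch
import torch.nn as nn

hidden_size, rank = 256, 8
num_query_groups, kv_channels = 2, 64  # hypothetical GQA settings
kv_size = num_query_groups * kv_channels

def low_rank(in_f, out_f, r):
    # linear_in / linear_out pair, analogous to a single ParallelLinearAdapter
    return nn.Sequential(nn.Linear(in_f, r, bias=False), nn.Linear(r, out_f, bias=False))

q_adapter = low_rank(hidden_size, hidden_size, rank)
k_adapter = low_rank(hidden_size, kv_size, rank)
v_adapter = low_rank(hidden_size, kv_size, rank)

x = torch.randn(4, 3, hidden_size)
lora_mixed_qkv = torch.cat([q_adapter(x), k_adapter(x), v_adapter(x)], dim=2)
assert lora_mixed_qkv.shape[-1] == hidden_size + 2 * kv_size  # width of the fused qkv projection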
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index a6f68f0666b5..0a030759fe9b 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -73,6 +73,7 @@
try:
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+
from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
HAVE_APEX = True
@@ -1057,6 +1058,31 @@ def should_process(key):
new_state_dict[key_] = state_dict[key_]
state_dict = new_state_dict
+ if conf.get('unet_config') and conf.get('unet_config').get('use_te_fp8') == False:
+ # Map a potential fp8 checkpoint onto the fp16 model:
+ # drop any fp8 _extra_state entries and rename the TE fused-module parameters.
+ new_state_dict = {}
+ for key in state_dict.keys():
+ if 'extra_state' in key:
+ continue
+
+ ### LayerNormLinear
+ # norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias}
+ # norm_to_q.weight -> to_q.weight
+ new_key = key.replace('norm_to_q.layer_norm_', 'norm.')
+ new_key = new_key.replace('norm_to_q.weight', 'to_q.weight')
+
+ ### LayerNormMLP
+ # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias}
+ # ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias}
+ # ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias}
+ new_key = new_key.replace('ff.net.layer_norm_', 'ff.net.0.')
+ new_key = new_key.replace('ff.net.fc1_', 'ff.net.1.proj.')
+ new_key = new_key.replace('ff.net.fc2_', 'ff.net.3.')
+
+ new_state_dict[new_key] = state_dict[key]
+ state_dict = new_state_dict
+
return state_dict
def _load_state_dict_from_disk(self, model_weights, map_location=None):
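The renaming above only rewrites Transformer Engine fused-module parameter names; the same mapping as a standalone sketch, applied to a couple of illustrative keys:

def remap_te_fp8_key(key: str) -> str:
    """Rename TE LayerNormLinear / LayerNormMLP parameters to the fp16 module layout."""
    key = key.replace('norm_to_q.layer_norm_', 'norm.')
    key = key.replace('norm_to_q.weight', 'to_q.weight')
    key = key.replace('ff.net.layer_norm_', 'ff.net.0.')
    key = key.replace('ff.net.fc1_', 'ff.net.1.proj.')
    key = key.replace('ff.net.fc2_', 'ff.net.3.')
    return key

assert remap_te_fp8_key('attn1.norm_to_q.layer_norm_weight') == 'attn1.norm.weight'
assert remap_te_fp8_key('ff.net.fc1_weight') == 'ff.net.1.proj.weight'
assert remap_te_fp8_key('ff.net.fc2_bias') == 'ff.net.3.bias'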
diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py
index 63caa409b218..47d5167d630e 100644
--- a/nemo/collections/nlp/parts/peft_config.py
+++ b/nemo/collections/nlp/parts/peft_config.py
@@ -36,6 +36,8 @@
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraKQVAdapterWeightTyingConfig,
+ LoraUnfusedHto4HAdapterConfig,
+ LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
ParallelLinearAdapterWeightTyingConfig,
@@ -132,11 +134,26 @@ def __init__(self, cfg):
for module in target_modules:
if module == PEFT_MODULE_MAP["qkv_module"]:
- adapter_cfg = self._create_lora_config(
- cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, LoraKQVAdapterConfig
- )
- name_key_to_cfg[AdapterName.LORA_KQV_ADAPTER] = adapter_cfg
- name_key_to_mcore_mixins[AdapterName.LORA_KQV_ADAPTER] = [("self_attention", MCoreSelfAttentionMixin)]
+ if lora_cfg.get("variant", "nemo") == "canonical":
+ _adapter_name = AdapterName.LORA_UNFUSED_KQV_ADAPTER
+ _adapter_cfg_cls = LoraUnfusedKQVAdapterConfig
+ adapter_cfg = self._create_lora_config(
+ cfg,
+ lora_cfg,
+ cfg.hidden_size,
+ qkv_projection_size,
+ _adapter_cfg_cls,
+ num_query_groups=num_query_groups,
+ kv_channels=kv_channels,
+ )
+ else:
+ _adapter_name = AdapterName.LORA_KQV_ADAPTER
+ _adapter_cfg_cls = LoraKQVAdapterConfig
+ adapter_cfg = self._create_lora_config(
+ cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, _adapter_cfg_cls
+ )
+ name_key_to_cfg[_adapter_name] = adapter_cfg
+ name_key_to_mcore_mixins[_adapter_name] = [("self_attention", MCoreSelfAttentionMixin)]
elif module == PEFT_MODULE_MAP["dense_module"]:
adapter_cfg = self._create_lora_config(
@@ -149,11 +166,18 @@ def __init__(self, cfg):
elif module == PEFT_MODULE_MAP["hto4h_module"]:
hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size
+ if lora_cfg.get("variant", "nemo") == "canonical":
+ _adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER
+ _adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig
+ else:
+ _adapter_name = AdapterName.LORA_Hto4H_ADAPTER
+ _adapter_cfg_cls = LoraHto4HAdapterConfig
+
adapter_cfg = self._create_lora_config(
- cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, LoraHto4HAdapterConfig
+ cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls
)
- name_key_to_cfg[AdapterName.LORA_Hto4H_ADAPTER] = adapter_cfg
- name_key_to_mcore_mixins[AdapterName.LORA_Hto4H_ADAPTER] = [("mlp", MCoreMLPMixin)]
+ name_key_to_cfg[_adapter_name] = adapter_cfg
+ name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]
elif module == PEFT_MODULE_MAP["4htoh_module"]:
adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig
@@ -170,7 +194,9 @@ def __init__(self, cfg):
self.name_key_to_mcore_mixins = name_key_to_mcore_mixins
super().__init__(lora_cfg, name_key_to_cfg)
- def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls):
+ def _create_lora_config(
+ self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls, num_query_groups=None, kv_channels=None
+ ):
config_args = {
"in_features": in_features,
"out_features": out_features,
@@ -187,6 +213,12 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_
"a2a_experimental": lora_cfg.get("a2a_experimental", False),
}
+ if adapter_cfg_cls == LoraUnfusedKQVAdapterConfig:
+ assert num_query_groups is not None, "num_query_groups must be provided for canonical Lora"
+ assert kv_channels is not None, "kv_channels must be provided for canonical Lora"
+ config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels})
+ config_args.pop("out_features")
+
if lora_cfg.weight_tying:
position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None)
if position_embedding_strategy is None:
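Selecting the unfused adapters is driven by the lora_tuning variant key (the LoRA converter below sets the same field); a minimal sketch of the relevant overrides, with illustrative values and hypothetical target-module names:

from omegaconf import OmegaConf

peft_overrides = OmegaConf.create(
    {
        "peft": {
            "peft_scheme": "lora",
            "lora_tuning": {
                "variant": "canonical",  # "nemo" (default, fused) vs "canonical" (unfused q/k/v and gate/up)
                "adapter_dim": 32,       # illustrative rank
                "adapter_dropout": 0.0,
                # hypothetical module names; the real values come from PEFT_MODULE_MAP
                "target_modules": ["attention_qkv", "mlp_fc1"],
            },
        }
    }
)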
diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py
index 2663f8fe9bac..783f47a08e79 100644
--- a/nemo/export/quantize/quantizer.py
+++ b/nemo/export/quantize/quantizer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
import tarfile
from contextlib import nullcontext
from typing import List, Optional
@@ -21,7 +20,6 @@
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.transformer.module import Float16Module
-from megatron.training.utils import unwrap_model
from omegaconf import OmegaConf
from omegaconf.omegaconf import DictConfig, open_dict
from pytorch_lightning.trainer.trainer import Trainer
@@ -31,7 +29,7 @@
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.utils import logging
from nemo.utils.distributed import temporary_directory
-from nemo.utils.model_utils import load_config, save_artifacts
+from nemo.utils.model_utils import load_config, save_artifacts, unwrap_model
try:
import ammo.torch.quantization as atq
diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py
index 95d1bc414625..f4eefd39a9ea 100644
--- a/nemo/utils/model_utils.py
+++ b/nemo/utils/model_utils.py
@@ -24,7 +24,7 @@
from enum import Enum
from functools import lru_cache
from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
import wrapt
@@ -92,6 +92,24 @@ def load_config(model_file: str) -> DictConfig:
return model_config
+def unwrap_model(model, module_instances: Union[Type, Tuple[Type]]):
+ """Unwrap model from wrapper classes like Float16Module, for example."""
+
+ # TODO: Import this from megatron.core once moved there from megatron.training.
+ return_list = True
+ if not isinstance(model, list):
+ model = [model]
+ return_list = False
+ unwrapped_model = []
+ for model_module in model:
+ while isinstance(model_module, module_instances):
+ model_module = model_module.module
+ unwrapped_model.append(model_module)
+ if not return_list:
+ return unwrapped_model[0]
+ return unwrapped_model
+
+
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
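A small usage sketch for the unwrap_model helper added above, using a hypothetical wrapper class in place of Float16Module and assuming a NeMo install that includes this change:

import torch.nn as nn
from nemo.utils.model_utils import unwrap_model


class Wrapper(nn.Module):
    """Hypothetical wrapper exposing .module, standing in for wrappers like Float16Module."""

    def __init__(self, module):
        super().__init__()
        self.module = module


inner = nn.Linear(4, 4)
wrapped = Wrapper(Wrapper(inner))

assert unwrap_model(wrapped, (Wrapper,)) is inner      # nested wrappers are peeled off
assert unwrap_model([wrapped], (Wrapper,)) == [inner]  # a list input returns a list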
diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt
index b7863714eb2d..30e839fd2ca8 100644
--- a/requirements/requirements_asr.txt
+++ b/requirements/requirements_asr.txt
@@ -1,5 +1,6 @@
braceexpand
editdistance
+einops
g2p_en
ipywidgets
jiwer
diff --git a/scripts/checkpoint_converters/convert_zarr_to_torch_dist.py b/scripts/checkpoint_converters/convert_zarr_to_torch_dist.py
new file mode 100644
index 000000000000..29b56aa706fa
--- /dev/null
+++ b/scripts/checkpoint_converters/convert_zarr_to_torch_dist.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Convert zarr-format distributed checkpoints into the torch_dist checkpoint format.
+ Example to run this conversion script:
+ python -m torch.distributed.launch --nproc_per_node=<gpus_per_node> \
+ convert_zarr_to_torch_dist.py \
+ --model_type \
+ --checkpoint_folder \
+ --checkpoint_name \
+ --path_to_save \
+ --tensor_model_parallel_size \
+ --pipeline_model_parallel_size \
+ --hparams_file \
+ --gpus_per_node
+"""
+
+import os
+from argparse import ArgumentParser
+
+import torch
+from megatron.core import parallel_state
+from omegaconf import OmegaConf, open_dict
+
+from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
+from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
+from nemo.utils import AppState, logging
+from nemo.utils.distributed import initialize_distributed
+
+
+def get_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--checkpoint_folder",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
+ )
+ parser.add_argument(
+ "--checkpoint_name",
+ type=str,
+ default=None,
+ required=True,
+ help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=0.14-step=20-consumed_samples=160.0-last",
+ )
+
+ parser.add_argument(
+ "--hparams_file",
+ type=str,
+ default=None,
+ required=True,
+ help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
+ )
+ parser.add_argument("--path_to_save", type=str, default=None, required=True, help="Path to output ckpt files.")
+ parser.add_argument(
+ "--save_to_nemo", action="store_true", help="If passed, output will be written as .nemo file.",
+ )
+ parser.add_argument("--gpus_per_node", type=int, required=True, default=None)
+ parser.add_argument("--tensor_model_parallel_size", type=int, required=True, default=None)
+ parser.add_argument("--pipeline_model_parallel_size", type=int, required=True, default=None)
+ parser.add_argument(
+ "--pipeline_model_parallel_split_rank",
+ type=int,
+ required=False,
+ default=None,
+ help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.",
+ )
+ parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1))
+ parser.add_argument("--cluster_type", required=False, default=None, help="Whether on BCP platform")
+ parser.add_argument(
+ "--precision",
+ type=str,
+ required=False,
+ default='bf16-mixed',
+ choices=['32-true', '16-mixed', 'bf16-mixed'],
+ help="Precision value for the trainer that matches the precision of the checkpoint",
+ )
+
+ parser.add_argument(
+ "--model_type", type=str, required=True, default="gpt", choices=["gpt", "sft", "bert"],
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+def convert(local_rank, rank, world_size, args):
+
+ app_state = AppState()
+ app_state.data_parallel_rank = 0
+ num_nodes = world_size // args.gpus_per_node
+
+ cfg = {
+ 'trainer': {
+ 'devices': args.gpus_per_node,
+ 'num_nodes': num_nodes,
+ 'accelerator': 'gpu',
+ 'precision': args.precision,
+ },
+ 'model': {
+ 'native_amp_init_scale': 2 ** 32,
+ 'native_amp_growth_interval': 1000,
+ 'hysteresis': 2,
+ 'gradient_as_bucket_view': True,
+ },
+ 'cluster_type': args.cluster_type,
+ }
+ cfg = OmegaConf.create(cfg)
+
+ # Set precision to None after the precision plugin is created, since PTL >= 2.1 does not allow
+ # both a precision plugin and an explicit precision value to be set.
+ cfg.trainer.precision = None
+
+ trainer = MegatronTrainerBuilder(cfg).create_trainer()
+
+ app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
+ app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
+ app_state.pipeline_model_parallel_split_rank = None
+
+ app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size
+
+ parallel_state.initialize_model_parallel(
+ tensor_model_parallel_size=app_state.tensor_model_parallel_size,
+ pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
+ )
+
+ app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+ app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
+
+ # check for distributed checkpoint
+ checkpoint_path = os.path.join(args.checkpoint_folder, args.checkpoint_name)
+
+ logging.info(
+ f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}'
+ )
+
+ if args.model_type == "gpt":
+ model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
+ elif args.model_type == "sft":
+ model = MegatronGPTSFTModel.load_from_checkpoint(
+ checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
+ )
+ # force the config target of the loaded model to point at MegatronGPTSFTModel,
+ # because hparams.yaml sometimes lists MegatronGPTModel as the target.
+ with open_dict(model.cfg):
+ model.cfg.target = f"{MegatronGPTSFTModel.__module__}.{MegatronGPTSFTModel.__name__}"
+ elif args.model_type == 'bert':
+ model = MegatronBertModel.load_from_checkpoint(
+ checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
+ )
+
+ with open_dict(model.cfg):
+ model.cfg.torch_distributed_checkpoint = True
+
+ model._save_restore_connector = NLPSaveRestoreConnector()
+ save_file_path = args.path_to_save
+ if not args.save_to_nemo:
+ # Without --save_to_nemo, path_to_save is expected to be a directory.
+ # Adding a dummy model filename here conforms with SaveRestoreConnector's convention.
+ model._save_restore_connector.pack_nemo_file = False
+ save_file_path = os.path.join(save_file_path, 'model.nemo')
+
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+ model.save_to(save_file_path)
+
+ logging.info(f'NeMo model saved to: {args.path_to_save}')
+
+
+if __name__ == '__main__':
+ args = get_args()
+
+ local_rank, rank, world_size = initialize_distributed(args)
+
+ convert(local_rank, rank, world_size, args)
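A hypothetical end-to-end invocation of the converter; all paths, checkpoint names, and parallel sizes below are made up for illustration:

import subprocess

cmd = [
    "python", "-m", "torch.distributed.launch", "--nproc_per_node=2",
    "scripts/checkpoint_converters/convert_zarr_to_torch_dist.py",
    "--model_type", "gpt",
    "--checkpoint_folder", "/results/megatron_gpt/checkpoints",       # hypothetical path
    "--checkpoint_name", "megatron_gpt--val_loss=0.14-step=20-last",  # hypothetical name
    "--hparams_file", "/results/megatron_gpt/hparams.yaml",           # hypothetical path
    "--path_to_save", "/results/megatron_gpt_torch_dist",             # hypothetical output dir
    "--tensor_model_parallel_size", "2",
    "--pipeline_model_parallel_size", "1",
    "--gpus_per_node", "2",
]
subprocess.run(cmd, check=True)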
diff --git a/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py b/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py
new file mode 100644
index 000000000000..f2974aca1642
--- /dev/null
+++ b/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py
@@ -0,0 +1,212 @@
+#!/usr/bin/env python
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Convert a NeMo-style (fused) LoRA checkpoint to a canonical (unfused) LoRA checkpoint.
+Currently supports TP=PP=1 only.
+
+Example usage:
+python scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py \
+ --lora_path nemo_style_lora_model.nemo \
+ --output_path ./canonical_style_lora_model.nemo
+
+"""
+import tempfile
+from argparse import ArgumentParser
+from typing import Dict
+
+import torch
+from omegaconf import OmegaConf, open_dict
+from scripts.nlp_language_modeling.merge_lora_weights.merge import replace_number_add_offset
+
+from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
+
+
+def rename_keys(key):
+ new_keys = []
+ if "lora_kqv_adapter" in key:
+ new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.q_adapter."))
+ new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.k_adapter."))
+ new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.v_adapter."))
+ elif "lora_hto4h_adapter" in key:
+ new_keys.append(key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.gate_adapter."))
+ new_keys.append(key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.up_adapter."))
+ return new_keys
+
+
+def reformat_module_names_to_hf(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ new_tensors = dict()
+ for module_name, module_weight in tensors.items():
+ # map linear_in and linear_out to lora_a/lora_b counterparts
+ new_module_name = "base_model." + module_name.replace("linear_in", "lora_A").replace("linear_out", "lora_B")
+
+ # map target modules to their vLLM/HF counterparts
+ new_module_name = new_module_name.replace("q_adapter", "q_proj")
+ new_module_name = new_module_name.replace("k_adapter", "k_proj")
+ new_module_name = new_module_name.replace("v_adapter", "v_proj")
+ new_module_name = new_module_name.replace("lora_dense_attention_adapter", "o_proj")
+ new_module_name = new_module_name.replace("lora_4htoh_adapter", "down_proj")
+ new_module_name = new_module_name.replace("gate_adapter", "gate_proj")
+ new_module_name = new_module_name.replace("up_adapter", "up_proj")
+
+ # map other parts of the module names to fit vLLM/huggingface
+ new_module_name = new_module_name.replace(".adapter_layer", "")
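+ # NOTE: the "v_adapter" -> "v_proj" replacement above also turns ".lora_unfused_kqv_adapter"
+ # into ".lora_unfused_kqv_proj", which is why that spelling is stripped below.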
+ new_module_name = new_module_name.replace(".lora_unfused_kqv_proj", "")
+ new_module_name = new_module_name.replace(".lora_unfused_hto4h_adapter", "")
+ new_module_name = new_module_name.replace("self_attention", "self_attn")
+ new_module_name = new_module_name.replace("decoder", "model")
+
+ new_tensors[new_module_name] = module_weight
+ return new_tensors
+
+
+def convert_hto4h(lora_weights, lora_config):
+ assert len(lora_weights) == 1, "Only single TP supported for now"
+ keys_to_update = []
+ for key in lora_weights[0].keys():
+ if "lora_hto4h_adapter" in key:
+ keys_to_update.append(key)
+
+ for key in keys_to_update:
+ if "linear_in" in key:
+ for new_key in rename_keys(key):
+ lora_weights[0][new_key] = lora_weights[0][key]
+ print(new_key, lora_weights[0][new_key].shape)
+ elif "linear_out" in key:
+ for idx, new_key in enumerate(rename_keys(key)):
+ original_shape = lora_weights[0][key].shape[0]
+ lora_weights[0][new_key] = lora_weights[0][key][
+ idx * (original_shape // 2) : (idx + 1) * (original_shape // 2)
+ ]
+ print(new_key, lora_weights[0][new_key].shape)
+
+ lora_weights[0].pop(key)
+ return lora_weights
+
+
+def convert_qkv(lora_weights, lora_model_cfg):
+ assert len(lora_weights) == 1, "Only single TP supported for now"
+ if (
+ lora_model_cfg.get("num_query_groups", lora_model_cfg.num_attention_heads)
+ != lora_model_cfg.num_attention_heads
+ ):
+ kv_channels = int(lora_model_cfg.hidden_size / lora_model_cfg.num_attention_heads)
+ kv_size = int(lora_model_cfg.num_query_groups * kv_channels)
+ else:
+ kv_size = int(lora_model_cfg.hidden_size)
+ q_size = lora_model_cfg.hidden_size
+ k_size, v_size = kv_size, kv_size
+
+ keys_to_update = []
+ for key in lora_weights[0].keys():
+ if "lora_kqv_adapter" in key:
+ keys_to_update.append(key)
+
+ for key in keys_to_update:
+ if "linear_in" in key:
+ for new_key in rename_keys(key):
+ lora_weights[0][new_key] = lora_weights[0][key]
+ print(new_key, lora_weights[0][new_key].shape)
+ elif "linear_out" in key:
+ srt = 0
+ for new_key, size in zip(rename_keys(key), [q_size, k_size, v_size]):
+ lora_weights[0][new_key] = lora_weights[0][key][srt : srt + size]
+ print(new_key, lora_weights[0][new_key].shape)
+ srt = srt + size
+
+ lora_weights[0].pop(key)
+ return lora_weights
+
+
+def convert_lora(lora_nemo, save_path, hf_format=False):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ NLPSaveRestoreConnector._unpack_nemo_file(lora_nemo, tmpdir)
+ config_file = f"{tmpdir}/model_config.yaml"
+ lora_config = OmegaConf.load(config_file)
+ tp_size = lora_config.tensor_model_parallel_size
+ pp_size = lora_config.pipeline_model_parallel_size
+
+ lora_state_dict = [{}] * tp_size
+
+ for pp in range(pp_size):
+ for tp in range(tp_size):
+ if tp_size == 1:
+ ckpt_file = f"{tmpdir}/model_weights.ckpt"
+ elif pp_size == 1:
+ ckpt_file = f"{tmpdir}/mp_rank_{tp:02d}/model_weights.ckpt"
+ else:
+ ckpt_file = f"{tmpdir}/tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt"
+
+ l = torch.load(ckpt_file, map_location=torch.device('cpu'))
+ if pp == 0:
+ lora_state_dict[tp] = l
+ else:
+ # calculate layer offset
+ layer_offset = lora_config.num_layers // pp_size * pp
+ for key, value in l.items():
+ new_key = replace_number_add_offset(key, layer_offset)
+ lora_state_dict[tp][new_key] = value
+
+ with open_dict(lora_config):
+ lora_config.peft.lora_tuning.variant = "canonical"
+ with open(f"{tmpdir}/model_config.yaml", "w") as f:
+ OmegaConf.save(lora_config, f)
+ lora_state_dict = convert_qkv(lora_state_dict, lora_config)
+ lora_state_dict = convert_hto4h(lora_state_dict, lora_config)
+ # TODO: currently supports tp=1 only
+ lora_state_dict = lora_state_dict[0]
+ if hf_format:
+ lora_state_dict = reformat_module_names_to_hf(lora_state_dict)
+ torch.save(lora_state_dict, f"{save_path}/model_weights_hf_formatted.pt")
+ else:
+ torch.save(lora_state_dict, f"{tmpdir}/model_weights.ckpt")
+ NLPSaveRestoreConnector._make_nemo_file_from_folder(save_path, tmpdir)
+
+ return lora_state_dict, lora_config
+
+
+def fix_for_O2(state_dict):
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if "model.module." not in k:
+ new_state_dict[k.replace('model.', 'model.module.')] = v
+ return new_state_dict
+
+
+def get_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--lora_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to NeMo style (fused) lora checkpoint in .nemo file format",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to save the canonical (unfused) lora .nemo file.",
+ )
+ parser.add_argument("--hf_format", action='store_true', help="saves tensors in huggingface naming format.")
+ parser.add_argument("--precision", type=str, default="16", help="Model precision")
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = get_args()
+ convert_lora(args.lora_path, args.output_path, args.hf_format)
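Typical usage of the converter, either via the CLI shown in its docstring or programmatically; the paths below are illustrative:

# CLI, from the repository root:
#   python scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py \
#       --lora_path my_fused_lora.nemo --output_path ./canonical_lora.nemo
from scripts.checkpoint_converters.lora_converters.convert_nemo_to_canonical import convert_lora

# Unpacks the .nemo file, unfuses the qkv and h-to-4h LoRA weights, and re-packs the result;
# with hf_format=True the tensors are instead saved under HF/vLLM naming in the output directory.
state_dict, lora_cfg = convert_lora("my_fused_lora.nemo", "./canonical_lora.nemo", hf_format=False)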
diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py
index 946acb614f11..a2e39628e4cb 100644
--- a/tests/collections/asr/test_asr_datasets.py
+++ b/tests/collections/asr/test_asr_datasets.py
@@ -809,6 +809,39 @@ def test_list_to_multichannel(self, num_channels, num_targets):
# Check the list is converted back to the original signal
assert (ASRAudioProcessor.list_to_multichannel(target_list) == golden_target).all()
+ @pytest.mark.unit
+ @pytest.mark.parametrize('num_channels', [1, 2])
+ def test_processor_process_audio(self, num_channels):
+ """Test signal normalization in process_audio.
+ """
+ num_samples = 1000
+ num_examples = 30
+
+ signals = ['input_signal', 'target_signal', 'reference_signal']
+
+ for normalization_signal in [None] + signals:
+ # Create processor
+ processor = ASRAudioProcessor(
+ sample_rate=16000, random_offset=False, normalization_signal=normalization_signal
+ )
+
+ # Generate random signals
+ for n in range(num_examples):
+ example = {signal: torch.randn(num_channels, num_samples) for signal in signals}
+ processed_example = processor.process_audio(example)
+
+ # Expected scale
+ if normalization_signal:
+ scale = 1.0 / (example[normalization_signal].abs().max() + processor.eps)
+ else:
+ scale = 1.0
+
+ # Make sure all signals are scaled as expected
+ for signal in signals:
+ assert torch.allclose(
+ processed_example[signal], example[signal] * scale
+ ), f'Failed example {n} signal {signal}'
+
@pytest.mark.unit
def test_audio_collate_fn(self):
"""Test `_audio_collate_fn`
diff --git a/tests/collections/asr/test_asr_losses.py b/tests/collections/asr/test_asr_losses.py
index e09fd71e0892..e050e7cc07c3 100644
--- a/tests/collections/asr/test_asr_losses.py
+++ b/tests/collections/asr/test_asr_losses.py
@@ -17,7 +17,9 @@
import torch
from nemo.collections.asr.losses.audio_losses import (
+ MSELoss,
SDRLoss,
+ calculate_mse_batch,
calculate_sdr_batch,
convolution_invariant_target,
scale_invariant_target,
@@ -271,7 +273,7 @@ def test_sdr_binary_mask(self, num_channels):
estimate = target + noise
# Limit calculation to masked samples
- mask = _rng.integers(low=0, high=2, size=(batch_size, max_num_samples))
+ mask = _rng.integers(low=0, high=2, size=(batch_size, num_channels, max_num_samples))
# Tensors for testing the loss
tensor_estimate = torch.tensor(estimate)
@@ -282,7 +284,9 @@ def test_sdr_binary_mask(self, num_channels):
golden_sdr = 0
for b in range(batch_size):
sdr = [
- calculate_sdr_numpy(estimate=estimate[b, m, mask[b, :] > 0], target=target[b, m, mask[b, :] > 0])
+ calculate_sdr_numpy(
+ estimate=estimate[b, m, mask[b, m, :] > 0], target=target[b, m, mask[b, m, :] > 0]
+ )
for m in range(num_channels)
]
sdr = np.mean(np.array(sdr))
@@ -467,3 +471,187 @@ def test_sdr_convolution_invariant(self, num_channels: int, filter_length: int):
assert np.allclose(
uut_sdr_loss.cpu().detach().numpy(), -golden_sdr, atol=atol
), f'SDRLoss not matching for example {n}'
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize('num_channels', [1, 4])
+ @pytest.mark.parametrize('ndim', [3, 4])
+ def test_mse(self, num_channels: int, ndim: int):
+ """Test MSE calculation
+ """
+ batch_size = 8
+ num_samples = 50
+ num_features = 123
+ num_batches = 10
+ random_seed = 42
+ atol = 1e-6
+
+ signal_shape = (
+ (batch_size, num_channels, num_features, num_samples)
+ if ndim == 4
+ else (batch_size, num_channels, num_samples)
+ )
+
+ reduction_dim = (-2, -1) if ndim == 4 else -1
+
+ mse_loss = MSELoss(ndim=ndim)
+
+ _rng = np.random.default_rng(seed=random_seed)
+
+ for n in range(num_batches):
+
+ # Generate random signal
+ target = _rng.normal(size=signal_shape)
+ # Random noise + scaling
+ noise = _rng.uniform(low=0.01, high=1) * _rng.normal(size=signal_shape)
+ # Estimate
+ estimate = target + noise
+
+ # DC bias for both
+ target += _rng.uniform(low=-1, high=1)
+ estimate += _rng.uniform(low=-1, high=1)
+
+ # Tensors for testing the loss
+ tensor_estimate = torch.tensor(estimate)
+ tensor_target = torch.tensor(target)
+
+ # Reference MSE
+ golden_mse = np.zeros((batch_size, num_channels))
+ for b in range(batch_size):
+ for m in range(num_channels):
+ err = estimate[b, m, :] - target[b, m, :]
+ golden_mse[b, m] = np.mean(np.abs(err) ** 2, axis=reduction_dim)
+
+ # Calculate MSE in torch
+ uut_mse = calculate_mse_batch(estimate=tensor_estimate, target=tensor_target)
+
+ # Calculate MSE loss
+ uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target)
+
+ # Compare torch MSE vs numpy reference
+ assert np.allclose(
+ uut_mse.cpu().detach().numpy(), golden_mse, atol=atol
+ ), f'MSE not matching for example {n}'
+
+ # Compare MSE loss vs average of per-example torch MSE
+ assert np.isclose(uut_mse_loss, uut_mse.mean(), atol=atol), f'MSELoss not matching for example {n}'
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize('num_channels', [1, 4])
+ @pytest.mark.parametrize('ndim', [3, 4])
+ def test_mse_weighted(self, num_channels: int, ndim: int):
+ """Test MSE calculation with per-channel weighting
+ """
+ batch_size = 8
+ num_samples = 50
+ num_features = 123
+ num_batches = 10
+ random_seed = 42
+ atol = 1e-6
+
+ signal_shape = (
+ (batch_size, num_channels, num_features, num_samples)
+ if ndim == 4
+ else (batch_size, num_channels, num_samples)
+ )
+
+ reduction_dim = (-2, -1) if ndim == 4 else -1
+
+ _rng = np.random.default_rng(seed=random_seed)
+
+ channel_weight = _rng.uniform(low=0.01, high=1.0, size=num_channels)
+ channel_weight = channel_weight / np.sum(channel_weight)
+ mse_loss = MSELoss(weight=channel_weight, ndim=ndim)
+
+ for n in range(num_batches):
+
+ # Generate random signal
+ target = _rng.normal(size=signal_shape)
+ # Random noise + scaling
+ noise = _rng.uniform(low=0.001, high=10) * _rng.normal(size=target.shape)
+ # Estimate
+ estimate = target + noise
+
+ # Tensors for testing the loss
+ tensor_estimate = torch.tensor(estimate)
+ tensor_target = torch.tensor(target)
+
+ # Reference MSE
+ golden_mse = 0
+ for b in range(batch_size):
+ mse = [
+ np.mean(np.abs(estimate[b, m, :] - target[b, m, :]) ** 2, axis=reduction_dim)
+ for m in range(num_channels)
+ ]
+ # weighted sum
+ mse = np.sum(np.array(mse) * channel_weight)
+ golden_mse += mse
+ golden_mse /= batch_size # average over batch
+
+ # Calculate MSE loss
+ uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target)
+
+ # Compare
+ assert np.allclose(
+ uut_mse_loss.cpu().detach().numpy(), golden_mse, atol=atol
+ ), f'MSELoss not matching for example {n}'
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize('num_channels', [1, 4])
+ @pytest.mark.parametrize('ndim', [3, 4])
+ def test_mse_input_length(self, num_channels: int, ndim: int):
+ """Test MSE calculation with input length.
+ """
+ batch_size = 8
+ max_num_samples = 50
+ num_features = 123
+ num_batches = 10
+ random_seed = 42
+ atol = 1e-6
+
+ signal_shape = (
+ (batch_size, num_channels, num_features, max_num_samples)
+ if ndim == 4
+ else (batch_size, num_channels, max_num_samples)
+ )
+
+ reduction_dim = (-2, -1) if ndim == 4 else -1
+
+ _rng = np.random.default_rng(seed=random_seed)
+
+ mse_loss = MSELoss(ndim=ndim)
+
+ for n in range(num_batches):
+
+ # Generate random signal
+ target = _rng.normal(size=signal_shape)
+ # Random noise + scaling
+ noise = _rng.uniform(low=0.001, high=10) * _rng.normal(size=target.shape)
+ # Estimate
+ estimate = target + noise
+
+ # Limit calculation to random input_length samples
+ input_length = _rng.integers(low=1, high=max_num_samples, size=batch_size)
+
+ # Tensors for testing the loss
+ tensor_estimate = torch.tensor(estimate)
+ tensor_target = torch.tensor(target)
+ tensor_input_length = torch.tensor(input_length)
+
+ # Reference MSE
+ golden_mse = 0
+ for b, b_len in enumerate(input_length):
+ mse = [
+ np.mean(np.abs(estimate[b, m, ..., :b_len] - target[b, m, ..., :b_len]) ** 2, axis=reduction_dim)
+ for m in range(num_channels)
+ ]
+ mse = np.mean(np.array(mse))
+ golden_mse += mse
+ golden_mse /= batch_size # average over batch
+
+ # Calculate MSE
+ uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target, input_length=tensor_input_length)
+
+ # Compare
+ assert np.allclose(
+ uut_mse_loss.cpu().detach().numpy(), golden_mse, atol=atol
+ ), f'MSELoss not matching for example {n}'
diff --git a/tests/collections/asr/test_audio_preprocessing.py b/tests/collections/asr/test_audio_preprocessing.py
index b0875936a7f7..600b9fed44fa 100644
--- a/tests/collections/asr/test_audio_preprocessing.py
+++ b/tests/collections/asr/test_audio_preprocessing.py
@@ -155,7 +155,11 @@ def test_spec_to_audio(self, fft_length: int, num_channels: int):
@pytest.mark.skipif(not HAVE_TORCHAUDIO, reason="Modules in this test require torchaudio")
@pytest.mark.parametrize('fft_length', [128, 1024])
@pytest.mark.parametrize('num_channels', [1, 4])
- def test_audio_to_spectrogram_reconstruction(self, fft_length: int, num_channels: int):
+ @pytest.mark.parametrize('magnitude_power', [0.5, 1, 2])
+ @pytest.mark.parametrize('scale', [0.1, 1.0])
+ def test_audio_to_spectrogram_reconstruction(
+ self, fft_length: int, num_channels: int, magnitude_power: float, scale: float
+ ):
"""Test analysis and synthesis transform result in a perfect reconstruction.
"""
batch_size = 4
@@ -169,8 +173,12 @@ def test_audio_to_spectrogram_reconstruction(self, fft_length: int, num_channels
hop_lengths = [fft_length // 2, fft_length // 4]
for hop_length in hop_lengths:
- audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)
- spec2audio = SpectrogramToAudio(fft_length=fft_length, hop_length=hop_length)
+ audio2spec = AudioToSpectrogram(
+ fft_length=fft_length, hop_length=hop_length, magnitude_power=magnitude_power, scale=scale
+ )
+ spec2audio = SpectrogramToAudio(
+ fft_length=fft_length, hop_length=hop_length, magnitude_power=magnitude_power, scale=scale
+ )
for n in range(num_examples):
x = _rng.normal(size=(batch_size, num_channels, num_samples))
diff --git a/tests/setup/__main__.py b/tests/setup/__main__.py
index 289a2537e2f2..a08ccdaa1634 100644
--- a/tests/setup/__main__.py
+++ b/tests/setup/__main__.py
@@ -34,8 +34,8 @@
)
create_hf_model(
- model_name_or_path="/home/TestData/nlp/meta-llama/Llama-2-7b-hf",
- output_dir=os.path.join(args.save_dir, "megatron_llama/llama-ci-hf"),
+ model_name_or_path="/home/TestData/nlp/megatron_llama/llama-ci-hf",
+ output_dir=os.path.join(args.save_dir, "megatron_llama/llama-ci-hf-tiny"),
config_updates={"hidden_size": 256, "num_attention_heads": 4, "num_hidden_layers": 2, "num_key_value_heads": 4},
overwrite=args.overwrite,
)