Score-based generative enhancement model (NVIDIA#8567)
* Score-based generative enhancement model in NeMo
* Addressed comments, added unit test

Signed-off-by: Ante Jukić <ajukic@nvidia.com>
anteju committed May 1, 2024
1 parent 3d87ed7 commit f658b6f
Showing 21 changed files with 2,985 additions and 349 deletions.
18 changes: 17 additions & 1 deletion examples/audio_tasks/audio_to_audio_eval.py
@@ -61,6 +61,7 @@
import json
import os
import tempfile
from collections import defaultdict
from dataclasses import dataclass, field, is_dataclass
from typing import List, Optional

@@ -101,6 +102,9 @@ class AudioEvaluationConfig(process_audio.ProcessConfig):
# Metrics to calculate
metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi'])

# Return metric values for each example
return_values_per_example: bool = False


def get_evaluation_dataloader(config):
"""Prepare a dataloader for evaluation.
@@ -174,6 +178,9 @@ def main(cfg: AudioEvaluationConfig):
# Setup metrics
metrics = get_metrics(cfg)

if cfg.return_values_per_example and cfg.batch_size > 1:
raise ValueError('return_values_per_example is only supported for batch_size=1.')

# Processing
if not cfg.only_score_manifest:
# Process audio using the configured model and save in the output directory
@@ -236,6 +243,10 @@ def main(cfg: AudioEvaluationConfig):

num_files += 1

if cfg.max_utts is not None and num_files >= cfg.max_utts:
logging.info('Reached max_utts: %s', cfg.max_utts)
break

# Prepare dataloader
config = {
'manifest_filepath': temporary_manifest_filepath,
@@ -249,6 +260,8 @@
}
temporary_dataloader = get_evaluation_dataloader(config)

metrics_value_per_example = defaultdict(list)

# Calculate metrics
for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'):
processed_signal, processed_length, target_signal, target_length = eval_batch
@@ -257,7 +270,9 @@
raise RuntimeError(f'Length mismatch.')

for name, metric in metrics.items():
metric.update(preds=processed_signal, target=target_signal, input_length=target_length)
value = metric(preds=processed_signal, target=target_signal, input_length=target_length)
if cfg.return_values_per_example:
metrics_value_per_example[name].append(value.item())

# Convert to a dictionary with name: value
metrics_value = {name: metric.compute().item() for name, metric in metrics.items()}
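
The change above swaps metric.update(...) for a direct call to the metric. With torchmetrics, the forward call both updates the accumulated state and returns the value for the current batch, while compute() returns the aggregate over all batches seen so far; that is what makes per-example collection possible when batch_size=1. A minimal sketch of this behaviour with ScaleInvariantSignalDistortionRatio (the random tensors below are illustrative stand-ins, not the NeMo dataloader output):

import torch
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio

metric = ScaleInvariantSignalDistortionRatio()
per_example = []

for _ in range(4):  # pretend each iteration is a batch of size 1
    target = torch.randn(1, 16000)
    preds = target + 0.1 * torch.randn(1, 16000)
    # Forward call: updates the metric state and returns the value for this batch.
    value = metric(preds=preds, target=target)
    per_example.append(value.item())

# compute() aggregates over everything seen so far, matching metrics_value above.
aggregate = metric.compute().item()
print(per_example, aggregate)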
@@ -277,6 +292,7 @@
# Inject the metric name and score into the config, and return the entire config
with open_dict(cfg):
cfg.metrics_value = metrics_value
cfg.metrics_value_per_example = dict(metrics_value_per_example)

return cfg
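
With return_values_per_example=true and batch_size=1, the returned config also carries metrics_value_per_example, a dict mapping each metric name to a list of per-utterance scores. A short, hypothetical sketch of how such a dict could be summarized after evaluation (the numbers are placeholders, not real results):

import statistics

# Placeholder structure mirroring cfg.metrics_value_per_example: metric name -> per-utterance scores.
metrics_value_per_example = {
    'sdr': [12.3, 8.7, 15.1, 4.2],
    'estoi': [0.91, 0.84, 0.95, 0.71],
}

for name, values in metrics_value_per_example.items():
    worst = min(range(len(values)), key=values.__getitem__)
    print(f'{name}: mean={statistics.mean(values):.2f} '
          f'median={statistics.median(values):.2f} '
          f'worst example index {worst} ({values[worst]:.2f})')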

1 change: 0 additions & 1 deletion examples/audio_tasks/conf/beamforming.yaml
@@ -44,7 +44,6 @@ model:
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
fft_length: 512 # Length of the window and FFT for calculating spectrogram
hop_length: 256 # Hop length for calculating spectrogram
power: null

decoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
3 changes: 0 additions & 3 deletions examples/audio_tasks/conf/masking.yaml
@@ -1,5 +1,3 @@
# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer.
#
name: "masking"

model:
@@ -44,7 +42,6 @@ model:
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
fft_length: 512 # Length of the window and FFT for calculating spectrogram
hop_length: 256 # Hop length for calculating spectrogram
power: null

decoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
130 changes: 130 additions & 0 deletions examples/audio_tasks/conf/predictive.yaml
@@ -0,0 +1,130 @@
name: "predictive_model"

model:
type: predictive
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1
normalize_input: true # normalize the input signal to 0dBFS

train_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration * sample_rate // encoder.hop_length = 256
random_offset: true
normalization_signal: input_signal
batch_size: 8 # batch size may be increased based on the available memory
shuffle: true
num_workers: 8
pin_memory: true

validation_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
batch_size: 8
shuffle: false
num_workers: 4
pin_memory: true

encoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
hop_length: 128
magnitude_power: 0.5
scale: 0.33

decoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
fft_length: ${model.encoder.fft_length}
hop_length: ${model.encoder.hop_length}
magnitude_power: ${model.encoder.magnitude_power}
scale: ${model.encoder.scale}

estimator:
_target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus
in_channels: 1 # single-channel noisy input
out_channels: 1 # single-channel estimate
num_res_blocks: 3 # increased number of res blocks
pad_time_to: 64 # pad to 64 frames for the time dimension
pad_dimension_to: 0 # no padding in the frequency dimension

loss:
_target_: nemo.collections.asr.losses.MSELoss # computed in the time domain

metrics:
val:
sisdr: # output SI-SDR
_target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio

optim:
name: adam
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.999]
weight_decay: 0.0

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: null
precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
log_every_n_steps: 25 # Interval of logging.
enable_progress_bar: true
num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting to 0 disables it
check_val_every_n_epoch: 1 # run validation every n epochs
sync_batchnorm: true
enable_checkpointing: false # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}

# use exponential moving average for model parameters
ema:
enable: true
decay: 0.999 # decay rate
cpu_offload: false # offload EMA parameters to CPU to save GPU memory
every_n_steps: 1 # how often to update EMA weights
validate_original_weights: false # use the original (non-EMA) weights for validation

# logging
create_tensorboard_logger: true

# checkpointing
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: val_sisdr
mode: max
save_top_k: 5
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

# early stopping
create_early_stopping_callback: true
early_stopping_callback_params:
monitor: val_sisdr
mode: max
min_delta: 0.0
patience: 20 # patience in terms of check_val_every_n_epoch
verbose: true
strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to true to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
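
The encoder comments above imply a square spectrogram: fft_length 510 gives 510 // 2 + 1 = 256 subbands, and 2.04 s of 16 kHz audio with hop_length 128 gives 1 + (2.04 * 16000) // 128 = 256 frames, already a multiple of pad_time_to: 64. The shape arithmetic can be sanity-checked with torch.stft; the compression line at the end mirrors the scale * |X| ** magnitude_power convention suggested by the config and is an assumption here, not the AudioToSpectrogram implementation:

import torch

sample_rate, audio_duration = 16000, 2.04
fft_length, hop_length = 510, 128
magnitude_power, scale = 0.5, 0.33

num_samples = int(audio_duration * sample_rate)  # 32640
x = torch.randn(num_samples)

# With center=True (the default), torch.stft yields 1 + num_samples // hop_length frames.
spec = torch.stft(
    x,
    n_fft=fft_length,
    hop_length=hop_length,
    window=torch.hann_window(fft_length),
    return_complex=True,
)
print(spec.shape)  # torch.Size([256, 256]) -> (subbands, frames)

# Assumed magnitude compression analogous to magnitude_power / scale: compress the magnitude, keep the phase.
compressed = torch.polar(scale * spec.abs() ** magnitude_power, spec.angle())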
149 changes: 149 additions & 0 deletions examples/audio_tasks/conf/score_based_generative.yaml
@@ -0,0 +1,149 @@
name: score_based_generative_model

model:
type: score_based
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1
normalize_input: true
max_utts_evaluation_metrics: 50 # metric calculation requires full inference and is slow, so it is limited to the first 50 utterances

train_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration * sample_rate // encoder.hop_length = 256
random_offset: true
normalization_signal: input_signal
batch_size: 8 # batch size may be increased based on the available memory
shuffle: true
num_workers: 8
pin_memory: true

validation_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
normalize_input: false # load data as is for validation, the model will normalize it for inference
batch_size: 4
shuffle: false
num_workers: 4
pin_memory: true

encoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
hop_length: 128
magnitude_power: 0.5
scale: 0.33

decoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
fft_length: ${model.encoder.fft_length}
hop_length: ${model.encoder.hop_length}
magnitude_power: ${model.encoder.magnitude_power}
scale: ${model.encoder.scale}

estimator:
_target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus
in_channels: 2 # concatenation of single-channel perturbed and noisy
out_channels: 1 # single-channel score estimate
conditioned_on_time: true
num_res_blocks: 3 # increased number of res blocks
pad_time_to: 64 # pad to 64 frames for the time dimension
pad_dimension_to: 0 # no padding in the frequency dimension

sde:
_target_: nemo.collections.asr.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE
stiffness: 1.5
std_min: 0.05
std_max: 0.5
num_steps: 1000

sampler:
_target_: nemo.collections.asr.parts.submodules.diffusion.PredictorCorrectorSampler
predictor: reverse_diffusion
corrector: annealed_langevin_dynamics
num_steps: 50
num_corrector_steps: 1
snr: 0.5

loss:
_target_: nemo.collections.asr.losses.MSELoss
ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)

metrics:
val:
sisdr: # output SI-SDR
_target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio

optim:
name: adam
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.999]
weight_decay: 0.0

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: null
precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
log_every_n_steps: 25 # Interval of logging.
enable_progress_bar: true
num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting to 0 disables it
check_val_every_n_epoch: 1 # run validation every n epochs
sync_batchnorm: true
enable_checkpointing: false # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}

# use exponential moving average for model parameters
ema:
enable: true
decay: 0.999 # decay rate
cpu_offload: false # offload EMA parameters to CPU to save GPU memory
every_n_steps: 1 # how often to update EMA weights
validate_original_weights: false # use the original (non-EMA) weights for validation

# logging
create_tensorboard_logger: true

# checkpointing
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: val_sisdr
mode: max
save_top_k: 5
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

# early stopping
create_early_stopping_callback: true
early_stopping_callback_params:
monitor: val_sisdr
mode: max
min_delta: 0.0
patience: 20 # patience in terms of check_val_every_n_epoch
verbose: true
strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to true to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
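
For context, the sampler section above describes a standard predictor-corrector scheme from score-based generative modelling: a reverse-diffusion predictor step followed by num_corrector_steps of annealed Langevin dynamics whose step size is derived from the configured snr. The sketch below is a generic illustration of that loop under those standard definitions, not the NeMo PredictorCorrectorSampler API; score_fn, drift_fn and diffusion_fn are placeholders for the trained score estimator and the SDE coefficients:

import torch

def predictor_corrector_sample(score_fn, drift_fn, diffusion_fn, x_T,
                               num_steps=50, num_corrector_steps=1, snr=0.5):
    """Generic reverse-diffusion + annealed-Langevin sampler (illustrative only)."""
    x = x_T
    dt = 1.0 / num_steps
    for i in range(num_steps, 0, -1):
        t = torch.full((x.shape[0],), i * dt)
        g = diffusion_fn(t).view(-1, *[1] * (x.dim() - 1))

        # Predictor: Euler-Maruyama step of the reverse SDE
        #   dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw, integrated backwards in time.
        score = score_fn(x, t)
        x = x - (drift_fn(x, t) - g ** 2 * score) * dt + g * dt ** 0.5 * torch.randn_like(x)

        # Corrector: annealed Langevin dynamics with step size set by the target SNR.
        for _ in range(num_corrector_steps):
            z = torch.randn_like(x)
            score = score_fn(x, t)
            eps = 2 * (snr * z.norm() / score.norm().clamp_min(1e-12)) ** 2
            x = x + eps * score + (2 * eps) ** 0.5 * z
    return x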