forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Score-based generative enhancement model (NVIDIA#8567)
* Score-based generative enhancement model in NeMo * Addressed comments, added unit test Signed-off-by: Ante Jukić <ajukic@nvidia.com>
- Loading branch information
Showing
21 changed files
with
2,985 additions
and
349 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
name: "predictive_model" | ||
|
||
model: | ||
type: predictive | ||
sample_rate: 16000 | ||
skip_nan_grad: false | ||
num_outputs: 1 | ||
normalize_input: true # normalize the input signal to 0dBFS | ||
|
||
train_ds: | ||
manifest_filepath: ??? | ||
input_key: noisy_filepath | ||
target_key: clean_filepath | ||
audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256 | ||
random_offset: true | ||
normalization_signal: input_signal | ||
batch_size: 8 # batch size may be increased based on the available memory | ||
shuffle: true | ||
num_workers: 8 | ||
pin_memory: true | ||
|
||
validation_ds: | ||
manifest_filepath: ??? | ||
input_key: noisy_filepath | ||
target_key: clean_filepath | ||
batch_size: 8 | ||
shuffle: false | ||
num_workers: 4 | ||
pin_memory: true | ||
|
||
encoder: | ||
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram | ||
fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 | ||
hop_length: 128 | ||
magnitude_power: 0.5 | ||
scale: 0.33 | ||
|
||
decoder: | ||
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio | ||
fft_length: ${model.encoder.fft_length} | ||
hop_length: ${model.encoder.hop_length} | ||
magnitude_power: ${model.encoder.magnitude_power} | ||
scale: ${model.encoder.scale} | ||
|
||
estimator: | ||
_target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus | ||
in_channels: 1 # single-channel noisy input | ||
out_channels: 1 # single-channel estimate | ||
num_res_blocks: 3 # increased number of res blocks | ||
pad_time_to: 64 # pad to 64 frames for the time dimension | ||
pad_dimension_to: 0 # no padding in the frequency dimension | ||
|
||
loss: | ||
_target_: nemo.collections.asr.losses.MSELoss # computed in the time domain | ||
|
||
metrics: | ||
val: | ||
sisdr: # output SI-SDR | ||
_target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio | ||
|
||
optim: | ||
name: adam | ||
lr: 1e-4 | ||
# optimizer arguments | ||
betas: [0.9, 0.999] | ||
weight_decay: 0.0 | ||
|
||
trainer: | ||
devices: -1 # number of GPUs, -1 would use all available GPUs | ||
num_nodes: 1 | ||
max_epochs: -1 | ||
max_steps: -1 # computed at runtime if not set | ||
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations | ||
accelerator: auto | ||
strategy: ddp | ||
accumulate_grad_batches: 1 | ||
gradient_clip_val: null | ||
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. | ||
log_every_n_steps: 25 # Interval of logging. | ||
enable_progress_bar: true | ||
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it | ||
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs | ||
sync_batchnorm: true | ||
enable_checkpointing: false # Provided by exp_manager | ||
logger: false # Provided by exp_manager | ||
|
||
exp_manager: | ||
exp_dir: null | ||
name: ${name} | ||
|
||
# use exponential moving average for model parameters | ||
ema: | ||
enable: true | ||
decay: 0.999 # decay rate | ||
cpu_offload: false # offload EMA parameters to CPU to save GPU memory | ||
every_n_steps: 1 # how often to update EMA weights | ||
validate_original_weights: False # use original weights for validation calculation? | ||
|
||
# logging | ||
create_tensorboard_logger: true | ||
|
||
# checkpointing | ||
create_checkpoint_callback: true | ||
checkpoint_callback_params: | ||
# in case of multiple validation sets, first one is used | ||
monitor: val_sisdr | ||
mode: max | ||
save_top_k: 5 | ||
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints | ||
|
||
# early stopping | ||
create_early_stopping_callback: true | ||
early_stopping_callback_params: | ||
monitor: val_sisdr | ||
mode: max | ||
min_delta: 0.0 | ||
patience: 20 # patience in terms of check_val_every_n_epoch | ||
verbose: true | ||
strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. | ||
|
||
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. | ||
# you need to set these two to true to continue the training | ||
resume_if_exists: false | ||
resume_ignore_no_checkpoint: false | ||
|
||
# You may use this section to create a W&B logger | ||
create_wandb_logger: false | ||
wandb_logger_kwargs: | ||
name: null | ||
project: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
name: score_based_generative_model | ||
|
||
model: | ||
type: score_based | ||
sample_rate: 16000 | ||
skip_nan_grad: false | ||
num_outputs: 1 | ||
normalize_input: true | ||
max_utts_evaluation_metrics: 50 # metric calculation needs full inference and is slow, so we limit to first few files | ||
|
||
train_ds: | ||
manifest_filepath: ??? | ||
input_key: noisy_filepath | ||
target_key: clean_filepath | ||
audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256 | ||
random_offset: true | ||
normalization_signal: input_signal | ||
batch_size: 8 # batch size may be increased based on the available memory | ||
shuffle: true | ||
num_workers: 8 | ||
pin_memory: true | ||
|
||
validation_ds: | ||
manifest_filepath: ??? | ||
input_key: noisy_filepath | ||
target_key: clean_filepath | ||
normalize_input: false # load data as is for validation, the model will normalize it for inference | ||
batch_size: 4 | ||
shuffle: false | ||
num_workers: 4 | ||
pin_memory: true | ||
|
||
encoder: | ||
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram | ||
fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 | ||
hop_length: 128 | ||
magnitude_power: 0.5 | ||
scale: 0.33 | ||
|
||
decoder: | ||
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio | ||
fft_length: ${model.encoder.fft_length} | ||
hop_length: ${model.encoder.hop_length} | ||
magnitude_power: ${model.encoder.magnitude_power} | ||
scale: ${model.encoder.scale} | ||
|
||
estimator: | ||
_target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus | ||
in_channels: 2 # concatenation of single-channel perturbed and noisy | ||
out_channels: 1 # single-channel score estimate | ||
conditioned_on_time: true | ||
num_res_blocks: 3 # increased number of res blocks | ||
pad_time_to: 64 # pad to 64 frames for the time dimension | ||
pad_dimension_to: 0 # no padding in the frequency dimension | ||
|
||
sde: | ||
_target_: nemo.collections.asr.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE | ||
stiffness: 1.5 | ||
std_min: 0.05 | ||
std_max: 0.5 | ||
num_steps: 1000 | ||
|
||
sampler: | ||
_target_: nemo.collections.asr.parts.submodules.diffusion.PredictorCorrectorSampler | ||
predictor: reverse_diffusion | ||
corrector: annealed_langevin_dynamics | ||
num_steps: 50 | ||
num_corrector_steps: 1 | ||
snr: 0.5 | ||
|
||
loss: | ||
_target_: nemo.collections.asr.losses.MSELoss | ||
ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) | ||
|
||
metrics: | ||
val: | ||
sisdr: # output SI-SDR | ||
_target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio | ||
|
||
optim: | ||
name: adam | ||
lr: 1e-4 | ||
# optimizer arguments | ||
betas: [0.9, 0.999] | ||
weight_decay: 0.0 | ||
|
||
trainer: | ||
devices: -1 # number of GPUs, -1 would use all available GPUs | ||
num_nodes: 1 | ||
max_epochs: -1 | ||
max_steps: -1 # computed at runtime if not set | ||
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations | ||
accelerator: auto | ||
strategy: ddp | ||
accumulate_grad_batches: 1 | ||
gradient_clip_val: null | ||
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. | ||
log_every_n_steps: 25 # Interval of logging. | ||
enable_progress_bar: true | ||
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it | ||
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs | ||
sync_batchnorm: true | ||
enable_checkpointing: false # Provided by exp_manager | ||
logger: false # Provided by exp_manager | ||
|
||
exp_manager: | ||
exp_dir: null | ||
name: ${name} | ||
|
||
# use exponential moving average for model parameters | ||
ema: | ||
enable: true | ||
decay: 0.999 # decay rate | ||
cpu_offload: false # offload EMA parameters to CPU to save GPU memory | ||
every_n_steps: 1 # how often to update EMA weights | ||
validate_original_weights: false # use original weights for validation calculation? | ||
|
||
# logging | ||
create_tensorboard_logger: true | ||
|
||
# checkpointing | ||
create_checkpoint_callback: true | ||
checkpoint_callback_params: | ||
# in case of multiple validation sets, first one is used | ||
monitor: val_sisdr | ||
mode: max | ||
save_top_k: 5 | ||
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints | ||
|
||
# early stopping | ||
create_early_stopping_callback: true | ||
early_stopping_callback_params: | ||
monitor: val_sisdr | ||
mode: max | ||
min_delta: 0.0 | ||
patience: 20 # patience in terms of check_val_every_n_epoch | ||
verbose: true | ||
strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. | ||
|
||
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. | ||
# you need to set these two to true to continue the training | ||
resume_if_exists: false | ||
resume_ignore_no_checkpoint: false | ||
|
||
# You may use this section to create a W&B logger | ||
create_wandb_logger: false | ||
wandb_logger_kwargs: | ||
name: null | ||
project: null |
Oops, something went wrong.