Skip to content

Commit

Permalink
[TTS] created the finetuning Hifigan 44100Hz recipe on HUI-Audio-Corp…
Browse files Browse the repository at this point in the history
…us-German. (#4478)

* implemented a script of generating mel-spectrograms for finetuning Hifigan using multiprocessing.
* fixed some typos.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
  • Loading branch information
XuesongYang authored Jun 30, 2022
1 parent fac634c commit 23f6a95
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 15 deletions.
8 changes: 4 additions & 4 deletions examples/tts/conf/hifigan/hifigan_44100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ highfreq: null
window: hann

train_n_segments: 16384
train_max_duration: null
train_max_duration: null # change to null to include longer audios.
train_min_duration: 0.75

val_n_segments: 131072
Expand Down Expand Up @@ -70,13 +70,13 @@ model:

trainer:
num_nodes: 1
devices: 1
devices: -1
accelerator: gpu
strategy: ddp
precision: 32
precision: 16
max_steps: ${model.max_steps}
accumulate_grad_batches: 1
enable_checkpointing: False # Provided by exp_manager
enable_checkpointing: false # Provided by exp_manager
logger: false # Provided by exp_manager
log_every_n_steps: 100
check_val_every_n_epoch: 10
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/preprocessing/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def segment_from_file(cls, audio_file, target_sr=None, n_segments=0, trim=False,
try:
with sf.SoundFile(audio_file, 'r') as f:
sample_rate = f.samplerate
if n_segments > 0 and len(f) > n_segments:
if 0 < n_segments < len(f):
max_audio_start = len(f) - n_segments
audio_start = random.randint(0, max_audio_start)
f.seek(audio_start)
Expand Down
11 changes: 6 additions & 5 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,15 +818,15 @@ def __init__(
):
"""Dataset which can be used for training and fine-tuning vocoder with pre-computed mel-spectrograms.
Args:
manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing information on the
dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
json. Each line should contain the following:
manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing
information on the dataset. Each line in the .json file should be valid json. Note: the .json file itself
is not valid json. Each line should contain the following:
"audio_filepath": <PATH_TO_WAV>,
"duration": <Duration of audio clip in seconds> (Optional),
"mel_filepath": <PATH_TO_LOG_MEL> (Optional, can be in .npy (numpy.save) or .pt (torch.save) format)
sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
n_segments (int): The length of audio in samples to load. For example, given a sample rate of 16kHz, and
n_segments=16000, a random 1 second section of audio from the clip will be loaded. The section will
n_segments=16000, a random 1-second section of audio from the clip will be loaded. The section will
be randomly sampled everytime the audio is batched. Can be set to None to load the entire audio.
Must be specified if load_precomputed_mel is True.
max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
Expand All @@ -838,7 +838,8 @@ def __init__(
ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
that will be pruned prior to training. Defaults to None which does not prune.
trim (bool): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
load_precomputed_mel (bool): Whether to load precomputed mel (useful for fine-tuning). Note: Requires "mel_filepath" to be set in the manifest file.
load_precomputed_mel (bool): Whether to load precomputed mel (useful for fine-tuning).
Note: Requires "mel_filepath" to be set in the manifest file.
hop_length (Optional[int]): The hope length between fft computations. Must be specified if load_precomputed_mel is True.
"""
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/tts/torch/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class BetaBinomialInterpolator:
"""
This module calculates alignment prior matrices (based on beta-binomial distribution) using cached popular sizes and image interpolation.
The implementation is taken from https://github.com/NVIDIA/DeepLearningExamples.
The implementation is taken from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py
"""

def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500):
Expand Down
6 changes: 3 additions & 3 deletions nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def save_to(self, save_path: str):
.nemo file is an archive (tar.gz) with the following:
model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor
model_wights.chpt - model checkpoint
model_wights.ckpt - model checkpoint
Args:
save_path: Path to .nemo file where model instance should be saved
Expand Down Expand Up @@ -285,7 +285,7 @@ def restore_from(
model as an OmegaConf DictConfig object without instantiating the model.
trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the
instantiated model's constructor.
save_restore_connector (SaveRestoreConnector): Can be overrided to add custom save and restore logic.
save_restore_connector (SaveRestoreConnector): Can be overridden to add custom save and restore logic.
Example:
```
Expand Down Expand Up @@ -331,7 +331,7 @@ def load_from_checkpoint(
):
"""
Loads ModelPT from checkpoint, with some maintenance of restoration.
For documentation, please refer to LightningModule.load_from_checkpoin() documentation.
For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
"""
checkpoint = None
try:
Expand Down
2 changes: 1 addition & 1 deletion nemo/core/optim/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def compute_max_steps(
logging.warning(
"Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released"
)
# TODO: Master verion, not in pytorch 1.6.0
# TODO: Master version, not in pytorch 1.6.0
# sampler_num_samples = math.ceil((num_samples - num_workers)/ num_workers)

steps_per_epoch = _round(sampler_num_samples / batch_size)
Expand Down
163 changes: 163 additions & 0 deletions scripts/dataset_processing/tts/generate_mels_for_finetune_hifigan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is to generate mel spectrograms from a Fastpitch model checkpoint. Please see general usage below. It runs
on GPUs by default, but you can add `--num-workers 5 --cpu` as an option to run on CPUs.
$ python scripts/dataset_processing/tts/generate_mels_for_finetune_hifigan.py \
--fastpitch-model-ckpt ./models/fastpitch/multi_spk/FastPitch--v_loss\=1.4473-epoch\=209.ckpt \
--input-json-manifests /home/xueyang/HUI-Audio-Corpus-German-clean/test_manifest_text_normed_phonemes.json
--output-json-manifest-root /home/xueyang/experiments/multi_spk_tts_de
"""

import argparse
import json
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from joblib import Parallel, delayed
from tqdm import tqdm

from nemo.collections.tts.models import FastPitchModel
from nemo.collections.tts.torch.helpers import BetaBinomialInterpolator
from nemo.utils import logging


def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="Generate mel spectrograms with pretrained FastPitch model, and create manifests for finetuning Hifigan.",
)
parser.add_argument(
"--fastpitch-model-ckpt", required=True, type=Path, help="Specify a full path of a fastpitch model checkpoint."
)
parser.add_argument(
"--input-json-manifests",
nargs="+",
required=True,
type=Path,
help="Specify a full path of a JSON manifest. You could add multiple manifests.",
)
parser.add_argument(
"--output-json-manifest-root",
required=True,
type=Path,
help="Specify a full path of output root that would contain new manifests.",
)
parser.add_argument(
"--num-workers",
default=-1,
type=int,
help="Specify the max number of concurrently Python workers processes. "
"If -1 all CPUs are used. If 1 no parallel computing is used.",
)
parser.add_argument("--cpu", action='store_true', default=False, help="Generate mel spectrograms using CPUs.")
args = parser.parse_args()
return args


def __load_wav(audio_file):
with sf.SoundFile(audio_file, 'r') as f:
samples = f.read(dtype='float32')
return samples.transpose()


def __generate_mels(entry, spec_model, device, beta_binomial_interpolator, mel_root):
# Generate a spectrograms (we need to use ground truth alignment for correct matching between audio and mels)
audio = __load_wav(entry["audio_filepath"])
audio = torch.from_numpy(audio).unsqueeze(0).to(device)
audio_len = torch.tensor(audio.shape[1], dtype=torch.long, device=device).unsqueeze(0)

if spec_model.fastpitch.speaker_emb is not None and "speaker" in entry:
speaker = torch.tensor([entry['speaker']]).to(device)
else:
speaker = None

with torch.no_grad():
if "normalized_text" in entry:
text = spec_model.parse(entry["normalized_text"], normalize=False)
else:
text = spec_model.parse(entry['text'])

text_len = torch.tensor(text.shape[-1], dtype=torch.long, device=device).unsqueeze(0)
spect, spect_len = spec_model.preprocessor(input_signal=audio, length=audio_len)

# Generate attention prior and spectrogram inputs for HiFi-GAN
attn_prior = (
torch.from_numpy(beta_binomial_interpolator(spect_len.item(), text_len.item()))
.unsqueeze(0)
.to(text.device)
)

spectrogram = spec_model.forward(
text=text, input_lens=text_len, spec=spect, mel_lens=spect_len, attn_prior=attn_prior, speaker=speaker,
)[0]

save_path = mel_root / f"{Path(entry['audio_filepath']).stem}.npy"
np.save(save_path, spectrogram[0].to('cpu').numpy())
entry["mel_filepath"] = str(save_path)

return entry


def main():
args = get_args()
ckpt_path = args.fastpitch_model_ckpt
input_manifest_filepaths = args.input_json_manifests
output_json_manifest_root = args.output_json_manifest_root

mel_root = output_json_manifest_root / "mels"
mel_root.mkdir(exist_ok=True, parents=True)

# load pretrained FastPitch model checkpoint
spec_model = FastPitchModel.load_from_checkpoint(ckpt_path)
spec_model.eval()
if args.cpu:
spec_model.eval()
else:
spec_model.eval().cuda()
device = spec_model.device

beta_binomial_interpolator = BetaBinomialInterpolator()

for manifest in input_manifest_filepaths:
logging.info(f"Processing {manifest}.")
entries = []
with open(manifest, "r") as fjson:
for line in fjson:
entries.append(json.loads(line.strip()))

if device == "cpu":
new_entries = Parallel(n_jobs=args.num_workers)(
delayed(__generate_mels)(entry, spec_model, device, beta_binomial_interpolator, mel_root)
for entry in entries
)
else:
new_entries = []
for entry in tqdm(entries):
new_entry = __generate_mels(entry, spec_model, device, beta_binomial_interpolator, mel_root)
new_entries.append(new_entry)

mel_manifest_path = output_json_manifest_root / f"{manifest.stem}_mel{manifest.suffix}"
with open(mel_manifest_path, "w") as fmel:
for entry in new_entries:
fmel.write(json.dumps(entry) + "\n")
logging.info(f"Processing {manifest} is complete --> {mel_manifest_path}")


if __name__ == "__main__":
main()

0 comments on commit 23f6a95

Please sign in to comment.