v0.3.1: add fine-tune training for hifigan
ntt123 committed May 28, 2021
1 parent 47c5ebb commit 1006776
Showing 10 changed files with 135 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
+gta/
train_data/
test_data/
assets/infore/
18 changes: 17 additions & 1 deletion README.md
@@ -60,7 +60,7 @@ git clone https://github.com/jik876/hifi-gan.git
# create dataset in hifi-gan format
ln -sf `pwd`/train_data hifi-gan/data
cd hifi-gan/data
-ls -1 *.wav | sed -e 's/\.wav$//' > files.txt
+ls -1 *.TextGrid | sed -e 's/\.TextGrid$//' > files.txt
cd ..
head -n 100 data/files.txt > val_files.txt
tail -n +101 data/files.txt > train_files.txt
@@ -74,6 +74,22 @@ python3 train.py \
--input_validation_file=val_files.txt
```

+Fine-tune on ground-truth-aligned (GTA) mel-spectrograms:
+```sh
+cd /path/to/vietTTS  # go to the vietTTS directory
+python3 -m vietTTS.nat.zero_silence_segments -o train_data  # zero out all [sil, sp, spn] segments
+python3 -m vietTTS.nat.gta -o /path/to/hifi-gan/ft_dataset  # write GTA mel-spectrograms to the hifi-gan/ft_dataset directory
+
+# turn on fine-tuning
+cd /path/to/hifi-gan
+python3 train.py \
+--fine_tuning True \
+--config ../assets/hifigan/config.json \
+--input_wavs_dir=data \
+--input_training_file=train_files.txt \
+--input_validation_file=val_files.txt
+```

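For context: with `--fine_tuning True`, hifi-gan reads the mel-spectrogram for each `data/<name>.wav` from `ft_dataset/<name>.npy` (its default `--input_mels_dir`) instead of recomputing it from the waveform, which is why the GTA mels are written there. A minimal sanity-check sketch — the `hop` value of 256 is an assumption taken from `n_fft // 4` in the trainer code, and the paths reuse the README's placeholders:

```python
from pathlib import Path

import numpy as np
from scipy.io import wavfile

hop = 256  # assumed: FLAGS.n_fft // 4 in vietTTS
root = Path('/path/to/hifi-gan')
for npy in sorted((root / 'ft_dataset').glob('*.npy')):
  mel = np.load(npy)  # saved as (mel_dim, n_frames)
  sr, y = wavfile.read(root / 'data' / f'{npy.stem}.wav')
  # the GTA mel is cropped to the utterance's true length,
  # so its frame count should match samples // hop
  assert abs(mel.shape[1] - len(y) // hop) <= 1, npy.stem
```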
Then, use the following command to convert the PyTorch model to Haiku format:
```sh
cd ..
Binary file modified assets/infore/clip.wav
Binary file not shown.
2 changes: 1 addition & 1 deletion scripts/download_aligned_infore_dataset.sh
@@ -3,5 +3,5 @@ pushd .
mkdir -p $data_root
cd $data_root
gdown --id 1Pe-5lKT_lZsliv2WxQDai2mjhI9ZMFlj -O infore.zip
-unzip infore.zip
+unzip -q infore.zip
popd
6 changes: 3 additions & 3 deletions scripts/quick_start.sh
@@ -1,11 +1,11 @@
-if [ ! -f assets/infore/hifigan/g_00500000 ]; then
+if [ ! -f assets/infore/hifigan/g_00800000 ]; then
pip3 install gdown
echo "Downloading models..."
mkdir -p -p assets/infore/{nat,hifigan}
gdown --id 16UhN8QBxG1YYwUh8smdEeVnKo9qZhvZj -O assets/infore/nat/duration_ckpt_latest.pickle
gdown --id 1-8Ig65S3irNHSzcskT37SLgeyuUhjKdj -O assets/infore/nat/acoustic_ckpt_latest.pickle
-gdown --id 10jFFokGGD9hQG4pzPB443pf8keEt7Pgx -O assets/infore/hifigan/g_00500000
-python3 -m vietTTS.hifigan.convert_torch_model_to_haiku --config-file=assets/hifigan/config.json --checkpoint-file=assets/infore/hifigan/g_00500000
+gdown --id 10SFOlAduG20TdjGC5e1Jod_vJIpxkD6u -O assets/infore/hifigan/g_00800000
+python3 -m vietTTS.hifigan.convert_torch_model_to_haiku --config-file=assets/hifigan/config.json --checkpoint-file=assets/infore/hifigan/g_00800000
fi

echo "Generate audio clip"
2 changes: 1 addition & 1 deletion setup.py
@@ -1,6 +1,6 @@
from setuptools import setup

-__version__ = '0.3.0'
+__version__ = '0.3.1'
url = 'https://github.com/ntt123/vietTTS'

install_requires = ['tabulate', 'optax', 'jax', 'jaxlib', 'einops', 'librosa',
15 changes: 1 addition & 14 deletions vietTTS/nat/acoustic_trainer.py
@@ -25,19 +25,6 @@ def net(x): return AcousticModel(is_training=True)(x)
def val_net(x): return AcousticModel(is_training=False)(x)


-@jax.jit
-def val_forward(params, aux, rng, inputs: AcousticInput):
-  melfilter = MelFilter(FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax)
-  mels = melfilter(inputs.wavs.astype(jnp.float32) / (2**15))
-  B, L, D = mels.shape
-  inp_mels = jnp.concatenate((jnp.zeros((B, 1, D), dtype=jnp.float32), mels[:, :-1, :]), axis=1)
-
-  n_frames = inputs.durations * FLAGS.sample_rate / (FLAGS.n_fft//4)
-  inputs = inputs._replace(mels=inp_mels, durations=n_frames)
-  (mel1_hat, mel2_hat), new_aux = val_net.apply(params, aux, rng, inputs)
-  return mel1_hat, mel2_hat


def loss_fn(params, aux, rng, inputs: AcousticInput, is_training=True):
  melfilter = MelFilter(FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax)
  mels = melfilter(inputs.wavs.astype(jnp.float32) / (2**15))
@@ -49,7 +36,7 @@ def loss_fn(params, aux, rng, inputs: AcousticInput, is_training=True):
  loss1 = (jnp.square(mel1_hat - mels) + jnp.square(mel2_hat - mels)) / 2
  loss2 = (jnp.abs(mel1_hat - mels) + jnp.abs(mel2_hat - mels)) / 2
  loss = jnp.mean((loss1 + loss2)/2, axis=-1)
-  mask = (jnp.arange(0, L)[None, :] - 10) < (inputs.wav_lengths // (FLAGS.n_fft // 4))[:, None]
+  mask = jnp.arange(0, L)[None, :] < (inputs.wav_lengths // (FLAGS.n_fft // 4))[:, None]
  loss = jnp.sum(loss * mask) / jnp.sum(mask)
  return (loss, new_aux) if is_training else (loss, new_aux, mel2_hat, mels)

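The updated mask drops the 10-frame slack, so the loss now averages over exactly the frames that fall inside each utterance's true waveform length. A toy check of the broadcasting, assuming a hop of `n_fft // 4 = 256` samples:

```python
import jax.numpy as jnp

hop = 256                              # assumed: FLAGS.n_fft // 4
L = 5                                  # padded number of mel frames
wav_lengths = jnp.array([1024, 2048])  # true sample counts per utterance
mask = jnp.arange(0, L)[None, :] < (wav_lengths // hop)[:, None]
# mask -> [[ True,  True,  True,  True, False],   # 1024 // 256 = 4 valid frames
#          [ True,  True,  True,  True,  True]]   # 2048 // 256 = 8 >= 5 frames
```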
24 changes: 16 additions & 8 deletions vietTTS/nat/data_loader.py
@@ -86,11 +86,13 @@ def load_textgrid_wav(data_dir: Path, token_seq_len: int, batch_size, pad_wav_len, mode):
  tg_files = sorted(data_dir.glob('*.TextGrid'))
  random.Random(42).shuffle(tg_files)
  L = len(tg_files) * 95 // 100
-  assert mode in ['train', 'val']
+  assert mode in ['train', 'val', 'gta']
  phonemes = load_phonemes_set_from_lexicon_file(data_dir / 'lexicon.txt')
-  if mode == 'train':
+  if mode == 'gta':
+    tg_files = tg_files  # all files
+  elif mode == 'train':
    tg_files = tg_files[:L]
-  if mode == 'val':
+  elif mode == 'val':
    tg_files = tg_files[L:]

  data = []
@@ -119,19 +121,25 @@ def load_textgrid_wav(data_dir: Path, token_seq_len: int, batch_size, pad_wav_len, mode):
    y = y[:pad_wav_len]
    wav_length = len(y)
    y = np.pad(y, (0, pad_wav_len - len(y)))
-    data.append((ps, ds, l, y, wav_length))
+    data.append((fn.stem, ps, ds, l, y, wav_length))

  batch = []
  while True:
    random.shuffle(data)
-    for e in data:
+    for idx, e in enumerate(data):
      batch.append(e)
-      if len(batch) == batch_size:
-        ps, ds, lengths, wavs, wav_lengths = zip(*batch)
+      if len(batch) == batch_size or (mode == 'gta' and idx == len(data) - 1):
+        names, ps, ds, lengths, wavs, wav_lengths = zip(*batch)
        ps = np.array(ps, dtype=np.int32)
        ds = np.array(ds, dtype=np.float32)
        lengths = np.array(lengths, dtype=np.int32)
        wavs = np.array(wavs)
        wav_lengths = np.array(wav_lengths, dtype=np.int32)
-        yield AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
+        if mode == 'gta':
+          yield names, AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
+        else:
+          yield AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
        batch = []
+    if mode == 'gta':
+      assert len(batch) == 0
+      break
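In 'gta' mode the loader thus yields `(names, batch)` pairs over the whole dataset — including a final partial batch — and stops after a single pass, instead of looping forever as it does for training. A minimal consumption sketch; the argument values are placeholders, not the project's FLAGS defaults:

```python
from pathlib import Path

from vietTTS.nat.data_loader import load_textgrid_wav

batch_size = 8  # placeholder value
data_iter = load_textgrid_wav(Path('train_data'), 256, batch_size, 131072, 'gta')
for names, batch in data_iter:     # exactly one pass over all files
  assert len(names) <= batch_size  # only the final batch may be smaller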
67 changes: 67 additions & 0 deletions vietTTS/nat/gta.py
@@ -0,0 +1,67 @@
import os
import pickle
from argparse import ArgumentParser
from functools import partial
from pathlib import Path

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm

from .config import FLAGS, AcousticInput
from .data_loader import load_textgrid_wav
from .dsp import MelFilter
from .model import AcousticModel


@hk.transform_with_state
def net(x): return AcousticModel(is_training=True)(x)


@hk.transform_with_state
def val_net(x): return AcousticModel(is_training=False)(x)


def forward_fn_(params, aux, rng, inputs: AcousticInput):
  melfilter = MelFilter(FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax)
  mels = melfilter(inputs.wavs.astype(jnp.float32) / (2**15))
  B, L, D = mels.shape
  # teacher forcing: feed the ground-truth mels shifted right by one frame
  inp_mels = jnp.concatenate((jnp.zeros((B, 1, D), dtype=jnp.float32), mels[:, :-1, :]), axis=1)
  n_frames = inputs.durations * FLAGS.sample_rate / (FLAGS.n_fft//4)  # durations in seconds -> frames
  inputs = inputs._replace(mels=inp_mels, durations=n_frames)
  (mel1_hat, mel2_hat), new_aux = val_net.apply(params, aux, rng, inputs)
  return mel2_hat


forward_fn = jax.jit(forward_fn_)


def generate_gta(out_dir: Path):
  out_dir.mkdir(parents=True, exist_ok=True)
  data_iter = load_textgrid_wav(FLAGS.data_dir, FLAGS.max_phoneme_seq_len,
                                FLAGS.batch_size, FLAGS.max_wave_len, 'gta')
  ckpt_fn = FLAGS.ckpt_dir / 'acoustic_ckpt_latest.pickle'
  print('Resuming from latest checkpoint at', ckpt_fn)
  with open(ckpt_fn, 'rb') as f:
    dic = pickle.load(f)
  _, params, aux, rng, _ = dic['step'], dic['params'], dic['aux'], dic['rng'], dic['optim_state']

  tr = tqdm(data_iter)
  for names, batch in tr:
    lengths = batch.wav_lengths
    predicted_mel = forward_fn(params, aux, rng, batch)
    mel = jax.device_get(predicted_mel)
    for idx, fn in enumerate(names):
      file = out_dir / f'{fn}.npy'
      tr.write(f'saving to file {file}')
      l = lengths[idx] // (FLAGS.n_fft//4)  # number of valid frames for this utterance
      np.save(file, mel[idx, :l].T)  # crop padding, save as (mel_dim, n_frames)


if __name__ == '__main__':
  parser = ArgumentParser()
  parser.add_argument('-o', '--output-dir', type=Path, default='gta')
  generate_gta(parser.parse_args().output_dir)
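Each saved array is the predicted mel cropped to the utterance's true frame count and transposed to `(mel_dim, n_frames)`, which is the layout hifi-gan's fine-tuning loader expects. A quick inspection sketch; the utterance name is hypothetical:

```python
import numpy as np

mel = np.load('gta/000001.npy')  # '000001' is a hypothetical utterance name
mel_dim, n_frames = mel.shape    # e.g. (80, 431); 80 is an assumed mel_dim
```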
28 changes: 28 additions & 0 deletions vietTTS/nat/zero_silence_segments.py
@@ -0,0 +1,28 @@
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
from scipy.io import wavfile
from textgrid import TextGrid
from tqdm.auto import tqdm

from .config import FLAGS

parser = ArgumentParser()
parser.add_argument('-o', '--output-dir', type=Path, required=True)
args = parser.parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)  # ensure the output directory exists

files = sorted(FLAGS.data_dir.glob('*.TextGrid'))
for fn in tqdm(files):
  tg = TextGrid.fromFile(str(fn.resolve()))
  wav_fn = FLAGS.data_dir / f'{fn.stem}.wav'
  sr, y = wavfile.read(wav_fn)
  y = np.copy(y)  # defensive copy before editing the samples in place
  for phone in tg[1]:  # the phone tier (index 1 in MFA-style TextGrids)
    if phone.mark in FLAGS.special_phonemes:
      l = int(phone.minTime * sr)
      r = int(phone.maxTime * sr)
      y[l:r] = 0  # silence this [sil, sp, spn] span
  out_file = args.output_dir / f'{fn.stem}.wav'
  wavfile.write(out_file, sr, y)
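The interval-to-sample mapping is a multiply-and-truncate against the sample rate. A toy check — the 16 kHz rate and the interval times are assumptions for illustration, not values from the dataset:

```python
sr = 16000                       # assumed sample rate
min_time, max_time = 0.25, 0.40  # a hypothetical 'sp' interval from the TextGrid
l, r = int(min_time * sr), int(max_time * sr)
# samples y[4000:6400] are zeroed, i.e. exactly 150 ms of enforced silence
```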
