This repository has been archived by the owner on Sep 11, 2022. It is now read-only.

Commit

Merge pull request #185 from yt605155624/new_stft
update stft loss and pwgan synthesize
yt605155624 authored Sep 29, 2021
2 parents 78648cc + d547e4c commit bc36816
Showing 11 changed files with 111 additions and 64 deletions.
33 changes: 17 additions & 16 deletions examples/fastspeech2/aishell3/README.md
Expand Up @@ -99,16 +99,16 @@ fastspeech2_nosil_aishell3_ckpt_0.4
```
## Synthesize
We use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder.
Download pretrained parallel wavegan model (Trained with baker) from [fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_aishell3_ckpt_0.4.zip) and unzip it.
Download pretrained parallel wavegan model from [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip) and unzip it.
```bash
unzip parallel_wavegan_baker_ckpt_0.4.zip
unzip pwg_baker_ckpt_0.4.zip
```
Parallel WaveGAN checkpoint contains files listed below.
```text
parallel_wavegan_baker_ckpt_0.4
├── pwg_default.yaml # default config used to train parallel wavegan
├── pwg_generator.pdparams # model parameters of parallel wavegan
└── pwg_stats.npy # statistics used to normalize spectrogram when training parallel wavegan
pwg_baker_ckpt_0.4
├── pwg_default.yaml # default config used to train parallel wavegan
├── pwg_snapshot_iter_400000.pdz # model parameters of parallel wavegan
└── pwg_stats.npy # statistics used to normalize spectrogram when training parallel wavegan
```
`synthesize.sh` calls `synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
Expand All @@ -118,9 +118,9 @@ parallel_wavegan_baker_ckpt_0.4
usage: synthesize.py [-h] [--fastspeech2-config FASTSPEECH2_CONFIG]
[--fastspeech2-checkpoint FASTSPEECH2_CHECKPOINT]
[--fastspeech2-stat FASTSPEECH2_STAT]
[--pwg-config PWG_CONFIG] [--pwg-params PWG_PARAMS]
[--pwg-stat PWG_STAT] [--phones-dict PHONES_DICT]
[--speaker-dict SPEAKER_DICT]
[--pwg-config PWG_CONFIG]
[--pwg-checkpoint PWG_CHECKPOINT] [--pwg-stat PWG_STAT]
[--phones-dict PHONES_DICT] [--speaker-dict SPEAKER_DICT]
[--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR]
[--device DEVICE] [--verbose VERBOSE]
Expand All @@ -137,7 +137,7 @@ optional arguments:
spectrogram when training fastspeech2.
--pwg-config PWG_CONFIG
parallel wavegan config file.
--pwg-params PWG_PARAMS
--pwg-checkpoint PWG_CHECKPOINT
parallel wavegan generator parameters to load.
--pwg-stat PWG_STAT mean and standard deviation used to normalize
spectrogram when training parallel wavegan.
Expand All @@ -162,7 +162,8 @@ optional arguments:
usage: synthesize_e2e.py [-h] [--fastspeech2-config FASTSPEECH2_CONFIG]
[--fastspeech2-checkpoint FASTSPEECH2_CHECKPOINT]
[--fastspeech2-stat FASTSPEECH2_STAT]
[--pwg-config PWG_CONFIG] [--pwg-params PWG_PARAMS]
[--pwg-config PWG_CONFIG]
[--pwg-checkpoint PWG_CHECKPOINT]
[--pwg-stat PWG_STAT] [--phones-dict PHONES_DICT]
[--speaker-dict SPEAKER_DICT] [--text TEXT]
[--output-dir OUTPUT_DIR] [--device DEVICE]
Expand All @@ -181,7 +182,7 @@ optional arguments:
spectrogram when training fastspeech2.
--pwg-config PWG_CONFIG
parallel wavegan config file.
--pwg-params PWG_PARAMS
--pwg-checkpoint PWG_CHECKPOINT
parallel wavegan generator parameters to load.
--pwg-stat PWG_STAT mean and standard deviation used to normalize
spectrogram when training parallel wavegan.
Expand All @@ -196,7 +197,7 @@ optional arguments:
--verbose VERBOSE verbose.
```
1. `--fastspeech2-config`, `--fastspeech2-checkpoint`, `--fastspeech2-stat`, `--phones-dict` and `--speaker-dict` are arguments for fastspeech2, which correspond to the 5 files in the fastspeech2 pretrained model.
2. `--pwg-config`, `--pwg-params`, `--pwg-stat` are arguments for parallel wavegan, which correspond to the 3 files in the parallel wavegan pretrained model.
2. `--pwg-config`, `--pwg-checkpoint`, `--pwg-stat` are arguments for parallel wavegan, which correspond to the 3 files in the parallel wavegan pretrained model.
3. `--test-metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
4. `--text` is the text file, which contains sentences to synthesize.
5. `--output-dir` is the directory to save synthesized audio files.
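The vocoder-related flags described in the list above can be sketched with `argparse`. This is only an illustration, assuming the three flag names shown in the usage text. The `build_parser` helper and the sample paths passed to `parse_args` are hypothetical and are not taken from `synthesize.py`.

```python
import argparse


def build_parser():
    # Hypothetical helper: declares only the three parallel wavegan flags.
    parser = argparse.ArgumentParser(
        description="Sketch of the parallel wavegan flags.")
    parser.add_argument(
        "--pwg-config", type=str, help="parallel wavegan config file.")
    parser.add_argument(
        "--pwg-checkpoint",
        type=str,
        help="parallel wavegan checkpoint to load.")
    parser.add_argument(
        "--pwg-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram.")
    return parser


# argparse turns dashes into underscores, so --pwg-checkpoint is read
# back as args.pwg_checkpoint.
args = build_parser().parse_args([
    "--pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml",
    "--pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz",
    "--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy",
])
print(args.pwg_checkpoint)
```

A script that still passes `--pwg-params` would be rejected by a parser like this one, which is why the shell scripts and README commands are updated together with the Python files.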
Expand All @@ -208,9 +209,9 @@ python3 synthesize_e2e.py \
--fastspeech2-config=fastspeech2_nosil_aishell3_ckpt_0.4/default.yaml \
--fastspeech2-checkpoint=fastspeech2_nosil_aishell3_ckpt_0.4/snapshot_iter_96400.pdz \
--fastspeech2-stat=fastspeech2_nosil_aishell3_ckpt_0.4/speech_stats.npy \
--pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
--pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
--pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
--pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=../sentences.txt \
--output-dir=exp/default/test_e2e \
--device="gpu" \
4 changes: 2 additions & 2 deletions examples/fastspeech2/aishell3/synthesize.py
Expand Up @@ -62,7 +62,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
model.eval()

vocoder = PWGGenerator(**pwg_config["generator_params"])
vocoder.set_state_dict(paddle.load(args.pwg_params))
vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"])
vocoder.remove_weight_norm()
vocoder.eval()
print("model done!")
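The change in this hunk implies that the new `.pdz` snapshot is a dictionary holding the generator weights under the `"generator_params"` key, whereas the old `.pdparams` file was passed to `set_state_dict` directly. A minimal sketch of that difference follows. Plain Python dicts stand in for whatever `paddle.load` returns, and the `generator_state` helper and the `"conv.weight"` key are hypothetical.

```python
def generator_state(checkpoint):
    """Return the generator weights from either checkpoint layout.

    Hypothetical helper: `checkpoint` is assumed to be a dict-like object.
    """
    if isinstance(checkpoint, dict) and "generator_params" in checkpoint:
        # Snapshot layout: generator weights are nested under one key.
        return checkpoint["generator_params"]
    # Bare layout: the object is already the generator state dict.
    return checkpoint


# Stand-ins for loaded files; the real values would be tensors.
snapshot = {"generator_params": {"conv.weight": [0.1, 0.2]}}
bare = {"conv.weight": [0.1, 0.2]}

print(generator_state(snapshot) == generator_state(bare))
```

The scripts in this commit index `"generator_params"` unconditionally, so they expect the snapshot layout only.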
Expand Down Expand Up @@ -117,7 +117,7 @@ def main():
parser.add_argument(
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
"--pwg-checkpoint",
type=str,
help="parallel wavegan generator parameters to load.")
parser.add_argument(
6 changes: 3 additions & 3 deletions examples/fastspeech2/aishell3/synthesize.sh
Expand Up @@ -4,9 +4,9 @@ python3 synthesize.py \
--fastspeech2-config=conf/default.yaml \
--fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_96400.pdz \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
--pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
--pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
--pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=exp/default/test \
--device="gpu" \
4 changes: 2 additions & 2 deletions examples/fastspeech2/aishell3/synthesize_e2e.py
Expand Up @@ -62,7 +62,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
model.eval()

vocoder = PWGGenerator(**pwg_config["generator_params"])
vocoder.set_state_dict(paddle.load(args.pwg_params))
vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"])
vocoder.remove_weight_norm()
vocoder.eval()
print("model done!")
Expand Down Expand Up @@ -128,7 +128,7 @@ def main():
parser.add_argument(
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
"--pwg-checkpoint",
type=str,
help="parallel wavegan generator parameters to load.")
parser.add_argument(
7 changes: 3 additions & 4 deletions examples/fastspeech2/aishell3/synthesize_e2e.sh
@@ -1,13 +1,12 @@

#!/bin/bash

python3 synthesize_e2e.py \
--fastspeech2-config=conf/default.yaml \
--fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_96400.pdz \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
--pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
--pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
--pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=../sentences.txt \
--output-dir=exp/default/test_e2e \
--device="gpu" \
32 changes: 17 additions & 15 deletions examples/fastspeech2/baker/README.md
Expand Up @@ -88,16 +88,16 @@ fastspeech2_nosil_baker_ckpt_0.4
```
## Synthesize
We use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder.
Download pretrained parallel wavegan model from [parallel_wavegan_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/parallel_wavegan_baker_ckpt_0.4.zip) and unzip it.
Download pretrained parallel wavegan model from [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip) and unzip it.
```bash
unzip parallel_wavegan_baker_ckpt_0.4.zip
unzip pwg_baker_ckpt_0.4.zip
```
Parallel WaveGAN checkpoint contains files listed below.
```text
parallel_wavegan_baker_ckpt_0.4
├── pwg_default.yaml # default config used to train parallel wavegan
├── pwg_generator.pdparams # model parameters of parallel wavegan
└── pwg_stats.npy # statistics used to normalize spectrogram when training parallel wavegan
pwg_baker_ckpt_0.4
├── pwg_default.yaml # default config used to train parallel wavegan
├── pwg_snapshot_iter_400000.pdz # model parameters of parallel wavegan
└── pwg_stats.npy # statistics used to normalize spectrogram when training parallel wavegan
```
`synthesize.sh` calls `synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
Expand All @@ -107,8 +107,9 @@ parallel_wavegan_baker_ckpt_0.4
usage: synthesize.py [-h] [--fastspeech2-config FASTSPEECH2_CONFIG]
[--fastspeech2-checkpoint FASTSPEECH2_CHECKPOINT]
[--fastspeech2-stat FASTSPEECH2_STAT]
[--pwg-config PWG_CONFIG] [--pwg-params PWG_PARAMS]
[--pwg-stat PWG_STAT] [--phones-dict PHONES_DICT]
[--pwg-config PWG_CONFIG]
[--pwg-checkpoint PWG_CHECKPOINT] [--pwg-stat PWG_STAT]
[--phones-dict PHONES_DICT]
[--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR]
[--device DEVICE] [--verbose VERBOSE]
Expand All @@ -125,7 +126,7 @@ optional arguments:
spectrogram when training fastspeech2.
--pwg-config PWG_CONFIG
parallel wavegan config file.
--pwg-params PWG_PARAMS
--pwg-checkpoint PWG_CHECKPOINT
parallel wavegan generator parameters to load.
--pwg-stat PWG_STAT mean and standard deviation used to normalize
spectrogram when training parallel wavegan.
Expand All @@ -146,7 +147,8 @@ optional arguments:
usage: synthesize_e2e.py [-h] [--fastspeech2-config FASTSPEECH2_CONFIG]
[--fastspeech2-checkpoint FASTSPEECH2_CHECKPOINT]
[--fastspeech2-stat FASTSPEECH2_STAT]
[--pwg-config PWG_CONFIG] [--pwg-params PWG_PARAMS]
[--pwg-config PWG_CONFIG]
[--pwg-checkpoint PWG_CHECKPOINT]
[--pwg-stat PWG_STAT] [--phones-dict PHONES_DICT]
[--text TEXT] [--output-dir OUTPUT_DIR]
[--device DEVICE] [--verbose VERBOSE]
Expand All @@ -164,7 +166,7 @@ optional arguments:
spectrogram when training fastspeech2.
--pwg-config PWG_CONFIG
parallel wavegan config file.
--pwg-params PWG_PARAMS
--pwg-checkpoint PWG_CHECKPOINT
parallel wavegan generator parameters to load.
--pwg-stat PWG_STAT mean and standard deviation used to normalize
spectrogram when training parallel wavegan.
Expand All @@ -178,7 +180,7 @@ optional arguments:
```

1. `--fastspeech2-config`, `--fastspeech2-checkpoint`, `--fastspeech2-stat` and `--phones-dict` are arguments for fastspeech2, which correspond to the 4 files in the fastspeech2 pretrained model.
2. `--pwg-config`, `--pwg-params`, `--pwg-stat` are arguments for parallel wavegan, which correspond to the 3 files in the parallel wavegan pretrained model.
2. `--pwg-config`, `--pwg-checkpoint`, `--pwg-stat` are arguments for parallel wavegan, which correspond to the 3 files in the parallel wavegan pretrained model.
3. `--test-metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
4. `--text` is the text file, which contains sentences to synthesize.
5. `--output-dir` is the directory to save synthesized audio files.
Expand All @@ -190,9 +192,9 @@ python3 synthesize_e2e.py \
--fastspeech2-config=fastspeech2_nosil_baker_ckpt_0.4/default.yaml \
--fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz \
--fastspeech2-stat=fastspeech2_nosil_baker_ckpt_0.4/speech_stats.npy \
--pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
--pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
--pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
--pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=../sentences.txt \
--output-dir=exp/default/test_e2e \
--device="gpu" \
4 changes: 2 additions & 2 deletions examples/fastspeech2/baker/synthesize.py
Expand Up @@ -52,7 +52,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
model.eval()

vocoder = PWGGenerator(**pwg_config["generator_params"])
vocoder.set_state_dict(paddle.load(args.pwg_params))
vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"])
vocoder.remove_weight_norm()
vocoder.eval()
print("model done!")
Expand Down Expand Up @@ -106,7 +106,7 @@ def main():
parser.add_argument(
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
"--pwg-checkpoint",
type=str,
help="parallel wavegan generator parameters to load.")
parser.add_argument(
6 changes: 3 additions & 3 deletions examples/fastspeech2/baker/synthesize.sh
Expand Up @@ -5,9 +5,9 @@ python3 synthesize.py \
--fastspeech2-config=conf/default.yaml \
--fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
--pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
--pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
--pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=exp/default/test \
--device="gpu" \
4 changes: 2 additions & 2 deletions examples/fastspeech2/baker/synthesize_e2e.py
Expand Up @@ -54,7 +54,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
model.eval()

vocoder = PWGGenerator(**pwg_config["generator_params"])
vocoder.set_state_dict(paddle.load(args.pwg_params))
vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"])
vocoder.remove_weight_norm()
vocoder.eval()
print("model done!")
Expand Down Expand Up @@ -118,7 +118,7 @@ def main():
parser.add_argument(
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
"--pwg-checkpoint",
type=str,
help="parallel wavegan generator parameters to load.")
parser.add_argument(
6 changes: 3 additions & 3 deletions examples/fastspeech2/baker/synthesize_e2e.sh
Expand Up @@ -5,9 +5,9 @@ python3 synthesize_e2e.py \
--fastspeech2-config=conf/default.yaml \
--fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
--pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
--pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
--pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=../sentences.txt \
--output-dir=exp/default/test_e2e \
--device="gpu" \
69 changes: 57 additions & 12 deletions parakeet/modules/stft_loss.py
Expand Up @@ -14,8 +14,57 @@
import paddle
from paddle import nn
from paddle.nn import functional as F

from parakeet.modules.audio import STFT
from scipy import signal


def stft(x,
         fft_size,
         hop_length=None,
         win_length=None,
         window='hann',
         center=True,
         pad_mode='reflect'):
    """Perform STFT and convert to magnitude spectrogram.

    Parameters
    ----------
    x : Tensor
        Input signal tensor (B, T).
    fft_size : int
        FFT size.
    hop_length : int, optional
        Hop size.
    win_length : int
        Window length. The window is built with `scipy.signal.get_window`,
        which needs an integer length.
    window : str, optional
        Name of window function, see `scipy.signal.get_window` for more
        details. Defaults to "hann".
    center : bool, optional
        Whether to pad `x` so that the `t`-th frame is centered at sample
        `t * hop_length`. Defaults to True.
    pad_mode : str, optional
        Padding pattern used when `center` is True. Defaults to "reflect".

    Returns
    ----------
    Tensor
        Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # calculate window
    window = signal.get_window(window, win_length, fftbins=True)
    window = paddle.to_tensor(window)
    x_stft = paddle.tensor.signal.stft(
        x,
        fft_size,
        hop_length,
        win_length,
        window=window,
        center=center,
        pad_mode=pad_mode)

    real = x_stft.real()
    imag = x_stft.imag()

    # clip the power spectrum before the square root so the magnitude
    # never reaches zero, then move the frame axis ahead of the bin axis
    return paddle.sqrt(paddle.clip(real**2 + imag**2, min=1e-7)).transpose(
        [0, 2, 1])
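For readers without Paddle installed, the same magnitude computation can be sketched in NumPy. This is an approximation under stated assumptions, not the library's implementation: the function name `magnitude_stft` is hypothetical, the periodic Hann window is written out by hand instead of calling `scipy.signal.get_window`, and the window is centered inside the FFT frame when `win_length < fft_size`. The `1e-7` floor and the `(B, #frames, fft_size // 2 + 1)` output layout follow the `stft` function above.

```python
import numpy as np


def magnitude_stft(x, fft_size, hop_length, win_length, floor=1e-7):
    """NumPy sketch: (B, T) signal -> (B, n_frames, fft_size // 2 + 1)."""
    # Periodic Hann window (the fftbins=True convention).
    n = np.arange(win_length)
    window = 0.5 - 0.5 * np.cos(2.0 * np.pi * n / win_length)
    # Assumption: a shorter window is zero-padded symmetrically to fft_size.
    left = (fft_size - win_length) // 2
    window = np.pad(window, (left, fft_size - win_length - left))

    # center=True behaviour: reflect-pad half a frame on both sides.
    half = fft_size // 2
    x = np.pad(x, ((0, 0), (half, half)), mode="reflect")

    # Gather overlapping frames with an index matrix.
    n_frames = 1 + (x.shape[1] - fft_size) // hop_length
    idx = (np.arange(fft_size)[None, :] +
           hop_length * np.arange(n_frames)[:, None])
    frames = x[:, idx] * window  # (B, n_frames, fft_size)

    spec = np.fft.rfft(frames, axis=-1)
    power = spec.real**2 + spec.imag**2
    # Floor the power before the square root, as the Paddle version does.
    return np.sqrt(np.clip(power, floor, None))


signal_batch = np.zeros((2, 1024))
mag = magnitude_stft(signal_batch, fft_size=512, hop_length=128, win_length=512)
print(mag.shape)
```

Flooring the power at `1e-7` keeps the magnitude strictly positive, so a later logarithm stays finite even for silent input.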


class SpectralConvergenceLoss(nn.Layer):
Expand Down Expand Up @@ -46,7 +95,7 @@ def forward(self, x_mag, y_mag):
class LogSTFTMagnitudeLoss(nn.Layer):
"""Log STFT magnitude loss module."""

def __init__(self, epsilon=1e-10):
def __init__(self, epsilon=1e-7):
"""Initialize log STFT magnitude loss module."""
super().__init__()
self.epsilon = epsilon
Expand Down Expand Up @@ -82,11 +131,7 @@ def __init__(self,
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.stft = STFT(
n_fft=fft_size,
hop_length=shift_size,
win_length=win_length,
window=window)
self.window = window
self.spectral_convergence_loss = SpectralConvergenceLoss()
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

Expand All @@ -105,10 +150,10 @@ def forward(self, x, y):
Tensor
Log STFT magnitude loss value.
"""
x_mag = self.stft.magnitude(x)
y_mag = self.stft.magnitude(y)
x_mag = x_mag.transpose([0, 2, 1])
y_mag = y_mag.transpose([0, 2, 1])
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length,
self.window)
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length,
self.window)
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

