Skip to content
This repository has been archived by the owner on Sep 11, 2022. It is now read-only.

Commit

Permalink
add with no grad when inference
Browse files Browse the repository at this point in the history
  • Loading branch information
yt605155624 committed Oct 12, 2021
1 parent 4b94af9 commit aacf81c
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def evaluate(args, config):
# extract mel feats
mel = mel_extractor.get_log_mel_fbank(wav)
mel = paddle.to_tensor(mel)
gen_wav = pwg_inference(mel)
with paddle.no_grad():
gen_wav = pwg_inference(mel)
sf.write(
str(output_dir / ("gen_" + utt_name)),
gen_wav.numpy(),
Expand Down
3 changes: 2 additions & 1 deletion examples/GANVocoder/parallelwave_gan/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def main():
mel = example['feats']
mel = paddle.to_tensor(mel) # (T, C)
with timer() as t:
wav = generator.inference(c=mel)
with paddle.no_grad():
wav = generator.inference(c=mel)
wav = wav.numpy()
N += wav.size
T += t.elapse
Expand Down
10 changes: 5 additions & 5 deletions examples/fastspeech2/aishell3/synthesize_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ def evaluate(args, fastspeech2_config, pwg_config):
mel = fastspeech2_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
temp_wav = pwg_inference(mel)
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
sf.write(
str(output_dir / (str(spk_id) + "_" + utt_id + ".wav")),
wav.numpy(),
Expand Down
10 changes: 5 additions & 5 deletions examples/fastspeech2/baker/synthesize_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ def evaluate(args, fastspeech2_config, pwg_config):
with paddle.no_grad():
mel = fastspeech2_inference(part_phone_ids)
temp_wav = pwg_inference(mel)
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
Expand Down
10 changes: 5 additions & 5 deletions examples/speedyspeech/baker/synthesize_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ def evaluate(args, speedyspeech_config, pwg_config):
with paddle.no_grad():
mel = speedyspeech_inference(part_phone_ids, part_tone_ids)
temp_wav = pwg_inference(mel)
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
sf.write(
output_dir / (utt_id + ".wav"),
wav.numpy(),
Expand Down
11 changes: 6 additions & 5 deletions examples/transformer_tts/ljspeech/synthesize_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ def evaluate(args, acoustic_model_config, vocoder_config):
phones = [phn for phn in phones if not phn.isspace()]
phones = [phn if phn in phone_id_map else "," for phn in phones]
phone_ids = [phone_id_map[phn] for phn in phones]
mel = transformer_tts_inference(paddle.to_tensor(phone_ids))
# mel shape is (T, feats) and waveflow's input shape is (batch, feats, T)
mel = mel.unsqueeze(0).transpose([0, 2, 1])
# wavflow's output shape is (B, T)
wav = vocoder.infer(mel)[0]
with paddle.no_grad():
mel = transformer_tts_inference(paddle.to_tensor(phone_ids))
# mel shape is (T, feats) and waveflow's input shape is (batch, feats, T)
mel = mel.unsqueeze(0).transpose([0, 2, 1])
# wavflow's output shape is (B, T)
wav = vocoder.infer(mel)[0]

sf.write(
str(output_dir / (utt_id + ".wav")),
Expand Down
1 change: 0 additions & 1 deletion examples/transformer_tts/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def evaluate(args, acoustic_model_config, vocoder_config):
mel = mel.unsqueeze(0).transpose([0, 2, 1])
# wavflow's output shape is (B, T)
wav = vocoder.infer(mel)[0]
print("wav:", wav)

sf.write(
str(output_dir / (utt_id + ".wav")),
Expand Down

0 comments on commit aacf81c

Please sign in to comment.