This repository has been archived by the owner on Aug 6, 2024. It is now read-only.
forked from Rubiksman78/MonikA.I
-
Notifications
You must be signed in to change notification settings - Fork 0
/
new_tts_infer.py
56 lines (44 loc) · 1.94 KB
/
new_tts_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from pathlib import Path
def infer(spec_gen_model, vocoder_model, str_input, speaker=None):
    """Synthesize a spectrogram and a waveform for ``str_input``.

    The text is tokenized and turned into a spectrogram by ``spec_gen_model``,
    which ``vocoder_model`` then converts to audio. All model calls run under
    ``torch.no_grad()``.

    Args:
        spec_gen_model: Spectrogram generator (FastPitch in this project). Used
            through its ``parse``, ``generate_spectrogram`` and ``device``
            members.
        vocoder_model: Vocoder (HiFiGAN in this project). Used through its
            ``convert_spectrogram_to_audio`` method.
        str_input: Text to synthesize.
        speaker: Optional speaker ID. When given, it is wrapped in a
            one-element int64 tensor on the generator's device.

    Returns:
        A ``(spectrogram, audio)`` tuple. Every torch tensor is moved to the
        CPU and converted to a NumPy array. A 3-D spectrogram keeps only its
        first entry along axis 0 (presumably the batch axis). The audio keeps
        whatever shape the vocoder produced.
    """

    def _as_numpy(value):
        # Tensors become CPU NumPy arrays; any other value passes through.
        if isinstance(value, torch.Tensor):
            return value.to('cpu').numpy()
        return value

    speaker_arg = None
    with torch.no_grad():
        tokens = spec_gen_model.parse(str_input)
        if speaker is not None:
            speaker_arg = torch.tensor([speaker]).long().to(device=spec_gen_model.device)
        spec = spec_gen_model.generate_spectrogram(tokens=tokens, speaker=speaker_arg)
        wav = vocoder_model.convert_spectrogram_to_audio(spec=spec)

    if spec is not None:
        spec = _as_numpy(spec)
        if len(spec.shape) == 3:
            # Keep only the first item along the leading axis.
            spec = spec[0]
    return spec, _as_numpy(wav)
def get_best_ckpt_from_last_run(
    base_dir,
    new_speaker_id,
    duration_mins,
    mixing_enabled,
    original_speaker_id,
    model_name="FastPitch",
):
    """Return the path of the ``*-last.ckpt`` checkpoint from the latest run.

    Despite the function name, this does not pick a checkpoint by any metric.
    It returns the file whose name ends in ``-last.ckpt`` inside the
    ``checkpoints`` folder of the most recent experiment run.

    Runs are looked up under::

        <base_dir>/<original_speaker_id>_to_<new_speaker_id>_<mixing|no_mixing>_<duration_mins>_mins/<model_name>/

    Each subdirectory there is treated as one run. The lexicographically
    greatest name counts as the latest run, which assumes the run directories
    are timestamp-named (TODO confirm against the training script).

    Args:
        base_dir: Root directory of the experiments.
        new_speaker_id: ID of the speaker being fine-tuned to.
        duration_mins: Amount of fine-tuning data in minutes, as it appears in
            the directory name.
        mixing_enabled: Whether speaker mixing was used; selects ``mixing`` or
            ``no_mixing`` in the directory name.
        original_speaker_id: ID of the source speaker.
        model_name: Model sub-directory name (default ``"FastPitch"``).

    Returns:
        The checkpoint path as a ``str``.

    Raises:
        ValueError: If the model directory contains no run subdirectories, or
            the latest run has no ``*-last.ckpt`` file.
        OSError: Propagated from ``Path.iterdir`` if the model directory is
            missing or unreadable.
    """
    mixing = "mixing" if mixing_enabled else "no_mixing"
    run_name = f"{original_speaker_id}_to_{new_speaker_id}_{mixing}_{duration_mins}_mins"
    model_dir = Path(base_dir) / run_name / model_name

    exp_dirs = [p for p in model_dir.iterdir() if p.is_dir()]
    if not exp_dirs:
        # Previously this surfaced as a context-free IndexError from sorted(...)[-1].
        raise ValueError(f"There are no experiment directories in {model_dir}.")
    last_checkpoint_dir = max(exp_dirs) / "checkpoints"

    # glob() order depends on the filesystem; sort so the result is deterministic.
    last_ckpts = sorted(last_checkpoint_dir.glob('*-last.ckpt'))
    if not last_ckpts:
        raise ValueError(f"There is no last checkpoint in {last_checkpoint_dir}.")
    return str(last_ckpts[0])