# test.py
# NOTE: the upstream repository was archived by its owner on May 3, 2023
# and is now read-only.
import argparse
from pathlib import Path

import torch
from scipy.io import wavfile

from tts_hw.datasets import ControlDataset
from tts_hw.model import FastSpeech
from tts_hw.utils import read_json, prepare_device
from tts_hw.utils.parse_config import ConfigParser
def main(config_path, checkpoint_path, output_dir="."):
    """Synthesize audio for every item of ``ControlDataset`` and save it as WAV.

    Builds a ``FastSpeech`` model from the ``arch.args`` section of the config,
    restores its weights from the checkpoint's ``"state_dict"`` entry, runs
    ``model.infer`` on each item's tokens, passes the result through the
    vocoder and writes one ``test{i}_{transcript[:32]}.wav`` file per item.

    Args:
        config_path: Path to the JSON config file.
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``"state_dict"`` key.
        output_dir: Directory the WAV files are written to; created if it
            does not exist. Defaults to the current directory, which matches
            the behaviour of calling ``main`` with two arguments.
    """
    sample_rate = 22050  # rate passed to wavfile.write for every file

    config = ConfigParser(read_json(config_path))

    # NOTE(review): `Vocoder` is not imported in the visible part of this
    # file; the import has to be added before this script can run.
    vocoder = Vocoder("./data/waveglow_256channels_universal_v5.pt").eval()
    model = FastSpeech(**config["arch"]["args"])

    device, device_ids = prepare_device(config["n_gpu"])
    vocoder = vocoder.to(device)
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # host instead of failing while deserializing CUDA tensors.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    dataset = ControlDataset()
    # Pure inference: skip autograd bookkeeping so no graph is kept around
    # and the result can always be converted to a NumPy array.
    with torch.no_grad():
        # Indexed access is kept on purpose: only __len__ and __getitem__
        # are known to be supported by the dataset from this file.
        for i in range(len(dataset)):
            transcript, tokens, _ = dataset[i]
            tokens = tokens.to(device)
            mel, _ = model.infer(tokens)
            # The last two axes of the model output are swapped before it is
            # handed to the vocoder.
            result = vocoder.inference(mel.transpose(-2, -1))
            wav_path = out_dir / f"test{i}_{transcript[:32]}.wav"
            wavfile.write(
                str(wav_path), sample_rate, result.squeeze(0).cpu().numpy()
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test script")
    parser.add_argument(
        "config",
        type=str,
        help="config file path",
    )
    parser.add_argument(
        "checkpoint",
        type=str,
        help="checkpoint file path",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="./audio_samples",
        type=str,
        help="path to saved test audios (default: ./audio_samples)",
    )
    args = parser.parse_args()
    # parse_args() returns a Namespace: values are attributes, not dict keys.
    main(args.config, args.checkpoint, args.output)