Skip to content

Commit

Permalink
[vits] remove weight norm
Browse files Browse the repository at this point in the history
  • Loading branch information
pengzhendong committed Sep 12, 2023
1 parent 74c0a01 commit 5192370
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 113 deletions.
23 changes: 5 additions & 18 deletions wetts/vits/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,8 @@ def main():
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

hps = utils.get_hparams_from_file(args.cfg)
with open(args.phone_table) as p_f:
phone_num = len(p_f.readlines()) + 1
num_speakers = 1
if args.speaker_table is not None:
num_speakers = len(open(args.speaker_table).readlines()) + 1
phone_num = len(open(args.phone_table).readlines())
num_speakers = len(open(args.speaker_table).readlines())

net_g = SynthesizerTrn(
phone_num,
Expand All @@ -67,6 +64,8 @@ def main():
**hps.model
)
utils.load_checkpoint(args.checkpoint, net_g, None)
net_g.flow.remove_weight_norm()
net_g.dec.remove_weight_norm()
net_g.forward = net_g.export_forward
net_g.eval()

Expand All @@ -75,7 +74,7 @@ def main():
scales = torch.FloatTensor([0.667, 1.0, 0.8])
# make triton dynamic shape happy
scales = scales.unsqueeze(0)
sid = torch.IntTensor([1]).long()
sid = torch.IntTensor([0]).long()

dummy_input = (seq, seq_len, scales, sid)
torch.onnx.export(
Expand All @@ -95,18 +94,6 @@ def main():
verbose=False,
)

# Verify onnx precision
torch_output = net_g(seq, seq_len, scales, sid)
providers = [args.providers]
ort_sess = ort.InferenceSession(args.onnx_model, providers=providers)
ort_inputs = {
"input": to_numpy(seq),
"input_lengths": to_numpy(seq_len),
"scales": to_numpy(scales),
"sid": to_numpy(sid),
}
onnx_output = ort_sess.run(None, ort_inputs)


if __name__ == "__main__":
main()
100 changes: 45 additions & 55 deletions wetts/vits/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_args():
parser.add_argument("--cfg", required=True, help="config file")
parser.add_argument("--outdir", required=True, help="ouput directory")
parser.add_argument("--phone_table", required=True, help="input phone dict")
parser.add_argument("--speaker_table", default=None, help="speaker table")
parser.add_argument("--speaker_table", default=True, help="speaker table")
parser.add_argument("--test_file", required=True, help="test file")
parser.add_argument(
"--gpu", type=int, default=-1, help="gpu id for this local rank, -1 for cpu"
Expand All @@ -50,21 +50,19 @@ def main():
device = torch.device("cuda" if use_cuda else "cpu")

phone_dict = {}
with open(args.phone_table) as p_f:
for line in p_f:
phone_id = line.strip().split()
phone_dict[phone_id[0]] = int(phone_id[1])
for line in open(args.phone_table):
phone_id = line.strip().split()
phone_dict[phone_id[0]] = int(phone_id[1])

speaker_dict = {}
if args.speaker_table is not None:
with open(args.speaker_table) as p_f:
for line in p_f:
arr = line.strip().split()
assert len(arr) == 2
speaker_dict[arr[0]] = int(arr[1])
for line in open(args.speaker_table):
arr = line.strip().split()
assert len(arr) == 2
speaker_dict[arr[0]] = int(arr[1])
hps = utils.get_hparams_from_file(args.cfg)

net_g = SynthesizerTrn(
len(phone_dict) + 1,
len(phone_dict),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=len(speaker_dict),
Expand All @@ -75,50 +73,42 @@ def main():
net_g.eval()
utils.load_checkpoint(args.checkpoint, net_g, None)

with open(args.test_file) as fin:
for line in fin:
arr = line.strip().split("|")
audio_path = arr[0]
if len(arr) == 2:
sid = 0
text = arr[1]
else:
sid = speaker_dict[arr[1]]
text = arr[2]
seq = [phone_dict[symbol] for symbol in text.split()]
seq = torch.LongTensor(seq)
print(audio_path)
with torch.no_grad():
x = seq.to(device).unsqueeze(0)
x_length = torch.LongTensor([seq.size(0)]).to(device)
sid = torch.LongTensor([sid]).to(device)
st = time.time()
audio = (
net_g.infer(
x,
x_length,
sid=sid,
noise_scale=0.667,
noise_scale_w=0.8,
length_scale=1,
)[0][0, 0]
.data.cpu()
.float()
.numpy()
)
audio *= 32767 / max(0.01, np.max(np.abs(audio))) * 0.6
print(
"RTF {}".format(
(time.time() - st) / (audio.shape[0] / hps.data.sampling_rate)
)
)
sys.stdout.flush()
audio = np.clip(audio, -32767.0, 32767.0)
wavfile.write(
args.outdir + "/" + audio_path.split("/")[-1],
hps.data.sampling_rate,
audio.astype(np.int16),
for line in open(args.test_file):
audio_path, sid, text = line.strip().split("|")
seq = [phone_dict[symbol] for symbol in text.split()]
seq = torch.LongTensor(seq)
print(audio_path)
with torch.no_grad():
x = seq.to(device).unsqueeze(0)
x_length = torch.LongTensor([seq.size(0)]).to(device)
sid = torch.LongTensor([sid]).to(device)
st = time.time()
audio = (
net_g.infer(
x,
x_length,
sid=sid,
noise_scale=0.667,
noise_scale_w=0.8,
length_scale=1,
)[0][0, 0]
.data.cpu()
.float()
.numpy()
)
audio *= 32767 / max(0.01, np.max(np.abs(audio))) * 0.6
print(
"RTF {}".format(
(time.time() - st) / (audio.shape[0] / hps.data.sampling_rate)
)
)
sys.stdout.flush()
audio = np.clip(audio, -32767.0, 32767.0)
wavfile.write(
args.outdir + "/" + audio_path.split("/")[-1],
hps.data.sampling_rate,
audio.astype(np.int16),
)


if __name__ == "__main__":
Expand Down
69 changes: 29 additions & 40 deletions wetts/vits/inference_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,63 +36,52 @@ def get_args():
parser.add_argument("--cfg", required=True, help="config file")
parser.add_argument("--outdir", required=True, help="ouput directory")
parser.add_argument("--phone_table", required=True, help="input phone dict")
parser.add_argument("--speaker_table", default=None, help="speaker table")
parser.add_argument("--speaker_table", default=True, help="speaker table")
parser.add_argument("--test_file", required=True, help="test file")
args = parser.parse_args()
return args


def main():
args = get_args()
print(args)
phone_dict = {}
with open(args.phone_table) as p_f:
for line in p_f:
phone_id = line.strip().split()
phone_dict[phone_id[0]] = int(phone_id[1])
for line in open(args.phone_table):
phone_id = line.strip().split()
phone_dict[phone_id[0]] = int(phone_id[1])

speaker_dict = {}
if args.speaker_table is not None:
with open(args.speaker_table) as p_f:
for line in p_f:
arr = line.strip().split()
assert len(arr) == 2
speaker_dict[arr[0]] = int(arr[1])
for line in open(args.speaker_table):
arr = line.strip().split()
assert len(arr) == 2
speaker_dict[arr[0]] = int(arr[1])
hps = utils.get_hparams_from_file(args.cfg)

ort_sess = ort.InferenceSession(args.onnx_model)
scales = torch.FloatTensor([0.667, 1.0, 0.8])
# make triton dynamic shape happy
scales = scales.unsqueeze(0)

with open(args.test_file) as fin:
for line in fin:
arr = line.strip().split("|")
audio_path = arr[0]
if len(arr) == 2:
sid = 0
text = arr[1]
else:
sid = speaker_dict[arr[1]]
text = arr[2]
seq = [phone_dict[symbol] for symbol in text.split()]
for line in open(args.test_file):
audio_path, sid, text = line.strip().split("|")
seq = [phone_dict[symbol] for symbol in text.split()]

x = torch.LongTensor([seq])
x_len = torch.IntTensor([x.size(1)]).long()
sid = torch.LongTensor([sid]).long()
ort_inputs = {
"input": to_numpy(x),
"input_lengths": to_numpy(x_len),
"scales": to_numpy(scales),
"sid": to_numpy(sid),
}
audio = np.squeeze(ort_sess.run(None, ort_inputs))
audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
audio = np.clip(audio, -32767.0, 32767.0)
wavfile.write(
args.outdir + "/" + audio_path.split("/")[-1],
hps.data.sampling_rate,
audio.astype(np.int16),
)
x = torch.LongTensor([seq])
x_len = torch.IntTensor([x.size(1)]).long()
sid = torch.LongTensor([sid]).long()
ort_inputs = {
"input": to_numpy(x),
"input_lengths": to_numpy(x_len),
"scales": to_numpy(scales),
"sid": to_numpy(sid),
}
audio = np.squeeze(ort_sess.run(None, ort_inputs))
audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
audio = np.clip(audio, -32767.0, 32767.0)
wavfile.write(
args.outdir + "/" + audio_path.split("/")[-1],
hps.data.sampling_rate,
audio.astype(np.int16),
)


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions wetts/vits/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,11 @@ def forward(self, x, x_mask, g=None, reverse=False):
x = flow(x, x_mask, g=g, reverse=reverse)
return x

def remove_weight_norm(self):
    """Remove weight normalization from every even-indexed entry of ``self.flows``.

    Entries at odd positions are skipped and never touched.
    NOTE(review): presumably the odd-indexed flows are layers without
    weight-normalized parameters -- confirm against the constructor.
    """
    for position, flow in enumerate(self.flows):
        # Only positions 0, 2, 4, ... are asked to drop their weight norm.
        if position % 2 != 0:
            continue
        flow.remove_weight_norm()


class PosteriorEncoder(nn.Module):
def __init__(
Expand Down
3 changes: 3 additions & 0 deletions wetts/vits/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,9 @@ def forward(self, x, x_mask, g=None, reverse=False):
x = torch.cat([x0, x1], 1)
return x

def remove_weight_norm(self):
    """Remove weight normalization by delegating to ``self.enc``."""
    encoder = self.enc
    encoder.remove_weight_norm()


class ConvFlow(nn.Module):
def __init__(
Expand Down

0 comments on commit 5192370

Please sign in to comment.