
Commit

[fix] export onnx error
Jackiexiao committed Nov 21, 2023
1 parent 01a9acb commit 274242c
Showing 3 changed files with 23 additions and 7 deletions.
wetts/vits/export_onnx.py (11 changes: 10 additions & 1 deletion)
@@ -51,8 +51,17 @@ def main():
     phone_num = len(open(args.phone_table).readlines())
     num_speakers = len(open(args.speaker_table).readlines())
 
+    if ("use_mel_posterior_encoder" in hps.model.keys()
+            and hps.model.use_mel_posterior_encoder):
+        print("Using mel posterior encoder for VITS2")
+        posterior_channels = 80  # vits2
+        hps.data.use_mel_posterior_encoder = True
+    else:
+        print("Using lin posterior encoder for VITS1")
+        posterior_channels = hps.data.filter_length // 2 + 1
+
     net_g = SynthesizerTrn(phone_num,
-                           hps.data.filter_length // 2 + 1,
+                           posterior_channels,
                            hps.train.segment_size // hps.data.hop_length,
                            n_speakers=num_speakers,
                            **hps.model)
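
For context on what the new branch fixes, here is a minimal standalone sketch of the channel-count logic (the filter_length value is an assumed example, not read from the repository's config): a model trained with the mel posterior encoder (VITS2) expects an 80-channel posterior input, while the original linear-spectrogram posterior (VITS1) expects filter_length // 2 + 1 channels. Building SynthesizerTrn with the wrong count typically makes the checkpoint fail to load with a shape mismatch before the ONNX export even starts.

```python
# Minimal sketch of the posterior-channel selection, mirroring the hunk above.
# The config values are assumed examples; the real ones come from hps.data.
config = {"filter_length": 1024, "use_mel_posterior_encoder": True}

if config.get("use_mel_posterior_encoder"):
    posterior_channels = 80                                 # mel posterior encoder (VITS2)
else:
    posterior_channels = config["filter_length"] // 2 + 1   # linear posterior (VITS1)

print(posterior_channels)  # 80 here; 1024 // 2 + 1 == 513 in the VITS1 branch
```
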
wetts/vits/model/flows.py (6 changes: 6 additions & 0 deletions)
@@ -82,6 +82,9 @@ def forward(self, x, x_mask, g=None, reverse=False):
             x = torch.cat([x0, x1], 1)
             return x
 
+    def remove_weight_norm(self):
+        self.enc.remove_weight_norm()
+
 
 class ResidualCouplingTransformersLayer(nn.Module):  # vits2
 
@@ -169,6 +172,9 @@ def forward(self, x, x_mask, g=None, reverse=False):
             x = torch.cat([x0, x1], 1)
             return x
 
+    def remove_weight_norm(self):
+        self.enc.remove_weight_norm()
+
 
 class FFTransformerCouplingLayer(nn.Module):  # vits2
 
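
To show why each coupling layer now needs its own remove_weight_norm, here is a minimal sketch with stand-in classes (DummyCouplingLayer and DummyFlow are hypothetical, not the repository's classes; a PyTorch version that ships the parametrization-based weight_norm is assumed). The enclosing flow simply delegates to every sub-layer, so a layer without the method raises AttributeError when the exporter strips weight norm before tracing the model for ONNX.

```python
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations


class DummyCouplingLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-in for the weight-normalized encoder held in `self.enc`
        self.enc = weight_norm(nn.Conv1d(4, 4, 1))

    def remove_weight_norm(self):
        remove_parametrizations(self.enc, "weight")


class DummyFlow(nn.Module):
    def __init__(self):
        super().__init__()
        self.flows = nn.ModuleList([DummyCouplingLayer(), DummyCouplingLayer()])

    def remove_weight_norm(self):
        # the flow only delegates; a sub-layer missing the method would
        # raise AttributeError here, which is what broke the ONNX export
        for layer in self.flows:
            layer.remove_weight_norm()


DummyFlow().remove_weight_norm()  # succeeds once every layer defines the method
```
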
wetts/vits/model/modules.py (13 changes: 7 additions & 6 deletions)
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
 
 from utils import commons
 
@@ -32,7 +33,7 @@ def __init__(
         self.drop = nn.Dropout(p_dropout)
 
         cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
-        self.cond_layer = weight_norm(cond_layer, name="weight")
+        self.cond_layer = weight_norm(cond_layer)
 
         for i in range(n_layers):
             dilation = dilation_rate**i
@@ -44,7 +45,7 @@ def __init__(
                 dilation=dilation,
                 padding=padding,
             )
-            in_layer = weight_norm(in_layer, name="weight")
+            in_layer = weight_norm(in_layer)
             self.in_layers.append(in_layer)
 
             # last one is not necessary
@@ -54,7 +55,7 @@ def __init__(
                 res_skip_channels = hidden_channels
 
             res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = weight_norm(res_skip_layer, name="weight")
+            res_skip_layer = weight_norm(res_skip_layer)
             self.res_skip_layers.append(res_skip_layer)
 
     def forward(self, x, x_mask, g=None, **kwargs):
@@ -88,11 +89,11 @@ def forward(self, x, x_mask, g=None, **kwargs):
 
     def remove_weight_norm(self):
         if self.gin_channels != 0:
-            nn.utils.remove_weight_norm(self.cond_layer)
+            remove_parametrizations(self.cond_layer, "weight")
         for l in self.in_layers:
-            nn.utils.remove_weight_norm(l)
+            remove_parametrizations(l, "weight")
         for l in self.res_skip_layers:
-            nn.utils.remove_weight_norm(l)
+            remove_parametrizations(l, "weight")
 
 
 class Flip(nn.Module):
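
The modules.py change follows PyTorch's migration from the hook-based nn.utils.weight_norm / nn.utils.remove_weight_norm pair to the parametrization API: a layer wrapped with torch.nn.utils.parametrizations.weight_norm has to be unwrapped with remove_parametrizations rather than the legacy remover. A minimal standalone sketch (the layer sizes are arbitrary example values; a PyTorch release that ships the parametrization-based weight_norm is assumed):

```python
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

# Wrap a Conv1d the same way WN's layers are wrapped; "weight" is the
# default parametrized tensor name.
conv = weight_norm(nn.Conv1d(192, 192, 1))
print(parametrize.is_parametrized(conv))   # True

# The legacy nn.utils.remove_weight_norm(conv) would raise here, because the
# module carries a parametrization rather than the old forward-pre hook.
remove_parametrizations(conv, "weight")
print(parametrize.is_parametrized(conv))   # False

x = torch.randn(1, 192, 10)
y = conv(x)          # plain Conv1d again, traceable for ONNX export
print(y.shape)       # torch.Size([1, 192, 10])
```
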
