Skip to content

Commit

Permalink
more bwe stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
janpbuethe committed Sep 19, 2024
1 parent fc9871a commit ee29215
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 31 deletions.
34 changes: 25 additions & 9 deletions dnn/torch/osce/adv_train_bwe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def preemph(x, gamma):
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'x_48')
newloss = setup['training'].get('newloss', False)

# load training dataset
data_config = setup['data']
Expand All @@ -174,11 +175,13 @@ def preemph(x, gamma):
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# create discriminator
print(setup['discriminator']['name'],setup['discriminator']['kwargs'])
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
*setup['discriminator']['args'], **setup['discriminator']['kwargs']
)


# set compute device
if type(args.device) == type(None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -272,13 +275,28 @@ def td_l1(y_true, y_pred, pow=0):

return torch.mean(tmp)

tdlp = TDLowpass(15, 4000/24000).to(device)
if newloss:
tdlp = TDLowpass(31, 4000/24000).to(device)
else:
tdlp = TDLowpass(15, 4000/24000).to(device)

def criterion(x, y, x_up):
if newloss:
def criterion(x, y, x_up):
# FD-losses are calculated on preemphasized signals
xp = preemph(x, preemph_gamma)
yp = preemph(y, preemph_gamma)

return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum
return (w_l1 * td_l1(x, y, pow=1) + stftloss(xp, yp) + w_logmel * logmelloss(xp, yp)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum
else:
def criterion(x, y, x_up):
# all losses are calculated on preemphasized signals
x = preemph(x, preemph_gamma)
y = preemph(y, preemph_gamma)
x_up = preemph(x_up, preemph_gamma)

return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum

# model checkpoint
checkpoint = {
Expand Down Expand Up @@ -364,12 +382,10 @@ def optimizer_to(optim, device):

# pre-emphasize
disc_target = preemph(target, preemph_gamma)
target = preemph(target, preemph_gamma)
x_up = preemph(x_up, preemph_gamma)
output = preemph(output, preemph_gamma)
output_preemph = preemph(output, preemph_gamma)

# discriminator update
scores_gen = disc(output.detach())
scores_gen = disc(output_preemph.detach())
scores_real = disc(disc_target.unsqueeze(1))

disc_loss = 0
Expand All @@ -394,7 +410,7 @@ def optimizer_to(optim, device):
optimizer_disc.step()

# generator update
scores_gen = disc(output)
scores_gen = disc(output_preemph)

# calculate loss
loss_reg = criterion(target, output.squeeze(1), x_up)
Expand Down
2 changes: 2 additions & 0 deletions dnn/torch/osce/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .lavoce import LaVoce
from .lavoce_400 import LaVoce400
from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
from .td_discriminator import TDMultiResolutionDiscriminator as TDMResDisc
from .bwe_net import BWENet
from .bbwe_net import BBWENet

Expand All @@ -41,6 +42,7 @@
'lavoce': LaVoce,
'lavoce400': LaVoce400,
'fdmresdisc': FDMResDisc,
'tdmresdisc': TDMResDisc,
'bwenet' : BWENet,
'bbwenet': BBWENet
}
36 changes: 18 additions & 18 deletions dnn/torch/osce/models/bbwe_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,25 @@ def forward(self, features, state=None):
class Folder(torch.nn.Module):
    """Scale a signal with a small set of learnable, strictly positive gains.

    The module owns ``num_taps`` learnable values stored with shape
    ``(1, 1, num_taps)``. In ``forward`` each value is exponentiated and held
    for ``length // num_taps`` consecutive samples along the last axis, and the
    resulting gain curve is multiplied onto the input.

    Args:
        num_taps: Number of learnable gain values.
        frame_size: Frame size; must be a multiple of ``num_taps``. It is
            stored on the module but not read by ``forward``.
    """

    def __init__(self, num_taps, frame_size):
        super().__init__()

        self.num_taps = num_taps
        self.frame_size = frame_size
        assert frame_size % num_taps == 0

        # Taps live in the log domain; forward() applies exp() to them.
        initial_taps = torch.randn(num_taps).view(1, 1, -1)
        self.taps = torch.nn.Parameter(initial_taps, requires_grad=True)

    def flop_count(self, rate):
        """Return the operation count for ``rate`` samples (one multiply each)."""
        return rate

    def forward(self, x, *args):
        """Multiply ``x`` by the expanded gain curve.

        Args:
            x: 3-D tensor whose last dimension is a multiple of ``num_taps``.
            *args: Accepted and ignored.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Unpacking three values also enforces that x is 3-dimensional.
        _, _, length = x.shape
        assert length % self.num_taps == 0

        hold = length // self.num_taps
        gains = torch.exp(self.taps).repeat_interleave(hold, dim=-1)

        return x * gains

class BBWENet(torch.nn.Module):
Expand All @@ -123,7 +123,7 @@ def __init__(self,
interpolate_k48=1,
shape_extension=True,
func_extension=True,
shaper='TDShape',
shaper='TDShaper',
bias=False,
):

Expand All @@ -140,7 +140,7 @@ def __init__(self,
self.shape_extension = shape_extension
self.func_extension = func_extension
self.shaper = shaper

assert (shape_extension or func_extension) and "Require at least one of shape_extension and func_extension to be true"


Expand All @@ -165,7 +165,7 @@ def __init__(self,
self.tdshape2 = Folder(12, frame_size=self.frame_size48)
else:
raise ValueError(f"unknown shaper {self.shaper}")

if activation == 'ImPowI':
self.nlfunc = lambda x : x * torch.sin(torch.log(torch.abs(x) + 1e-6))
elif activation == "ReLU":
Expand Down Expand Up @@ -209,7 +209,7 @@ def forward(self, x, features, debug=False):

# split into latent_channels channels
y16 = self.af1(x, cf, debug=debug)

# first 2x upsampling step
y32 = self.upsampler.hq_2x_up(y16)
y32_out = y32[:, 0:1, :] # first channel is bypass channel
Expand All @@ -220,14 +220,14 @@ def forward(self, x, features, debug=False):
y32_shape = self.tdshape1(y32[:, idx:idx+1, :], cf)
y32_out = torch.cat((y32_out, y32_shape), dim=1)
idx += 1

if self.func_extension:
y32_func = self.nlfunc(y32[:, idx:idx+1, :])
y32_out = torch.cat((y32_out, y32_func), dim=1)

# mix-select
y32_out = self.af2(y32_out, cf)

# 1.5x upsampling
y48 = self.upsampler.interpolate_3_2(y32_out)
y48_out = y48[:, 0:1, :] # first channel is bypass channel
Expand All @@ -238,12 +238,12 @@ def forward(self, x, features, debug=False):
y48_shape = self.tdshape2(y48[:, idx:idx+1, :], cf)
y48_out = torch.cat((y48_out, y48_shape), dim=1)
idx += 1

if self.func_extension:
y48_func = self.nlfunc(y48[:, idx:idx+1, :])
y48_out = torch.cat((y48_out, y48_func), dim=1)

# 2nd mixing
y48_out = self.af3(y48_out, cf)

return y48_out
150 changes: 150 additions & 0 deletions dnn/torch/osce/models/td_discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
MIT License
Copyright (c) 2020 Jungil Kong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

# This is an adaptation of the HiFi-Gan discriminators derived from https://github.com/jik876/hifi-gan

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

def get_padding(kernel_size, dilation=1):
    """Return the per-side padding for a dilated convolution kernel.

    The value is half of the span ``dilation * (kernel_size - 1)`` covered by
    the kernel beyond its first tap, truncated to an integer.
    """
    span = dilation * (kernel_size - 1)
    return int(span / 2)

# Negative slope passed to F.leaky_relu after every convolution in the
# convs stacks of DiscriminatorP and DiscriminatorS.
LRELU_SLOPE = 0.1

class DiscriminatorP(torch.nn.Module):
    """Period sub-discriminator: fold a 1-D signal into 2-D and run a Conv2d stack.

    The input of shape ``(batch, channels, time)`` is right-padded by
    reflection to a multiple of ``period`` and viewed as
    ``(batch, channels, time // period, period)``. Every kernel has width 1
    along the last axis. The first convolution expects one input channel.

    Args:
        period: Fold width applied to the time axis.
        kernel_size: Kernel height of the five main convolutions.
        stride: Stride (along the folded time axis) of the first four
            convolutions; the fifth uses stride 1.
        use_spectral_norm: If truthy, every convolution is wrapped in spectral
            normalisation, otherwise in weight normalisation.
        max_channels: Upper bound applied to the width of every hidden layer.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, max_channels=1024):
        super().__init__()
        self.max_channels = max_channels
        self.period = period

        # Truthiness test so that None / 0 select weight norm like False does.
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # Derive the padding from the actual kernel size instead of assuming 5;
        # for the default kernel_size=5 this is the same (2, 0) as before.
        padding = (get_padding(kernel_size, 1), 0)

        # Clamp every hidden width so that the output of one layer always
        # matches the input of the next, also for max_channels < 128.
        widths = [1] + [min(max_channels, c) for c in (32, 128, 512, 1024, 1024)]
        # Four strided convolutions followed by one with stride 1.
        strides = [(stride, 1)] * 4 + [1]

        self.convs = nn.ModuleList([
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), s, padding=padding))
            for c_in, c_out, s in zip(widths[:-1], widths[1:], strides)
        ])
        self.conv_post = norm_f(Conv2d(widths[-1], 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return the feature maps of all layers for input ``x``.

        Args:
            x: Tensor of shape ``(batch, channels, time)``.

        Returns:
            List of six tensors: the leaky-ReLU activations of the five main
            convolutions followed by the raw output of ``conv_post``.
        """
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        output = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            output.append(x)
        x = self.conv_post(x)
        output.append(x)

        return output


class MultiPeriodDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorP`` instances with periods 2, 3, 5, 7 and 11.

    Args:
        max_channels: Forwarded to every ``DiscriminatorP``.
    """

    def __init__(self, max_channels=1024):
        super().__init__()
        periods = (2, 3, 5, 7, 11)
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period, max_channels=max_channels) for period in periods]
        )

    def forward(self, y):
        """Apply every period discriminator to ``y``.

        Returns:
            List with one entry per discriminator, in period order; each entry
            is whatever that discriminator's ``forward`` returns.
        """
        return [discriminator(y) for discriminator in self.discriminators]


class DiscriminatorS(torch.nn.Module):
    """Scale sub-discriminator: a stack of (grouped) 1-D convolutions.

    The first convolution expects one input channel.

    Args:
        use_spectral_norm: If truthy, every convolution is wrapped in spectral
            normalisation, otherwise in weight normalisation.
        max_channels: Upper bound applied to the width of every hidden layer.
    """

    def __init__(self, use_spectral_norm=False, max_channels=1024):
        super().__init__()
        self.max_channels = max_channels

        # Truthiness test so that None / 0 select weight norm like False does
        # (an `== False` comparison would send None down the spectral-norm path).
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        def cap(channels):
            # Limit a nominal layer width to max_channels.
            return min(max_channels, channels)

        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, cap(128), 15, 1, padding=7)),
            norm_f(Conv1d(cap(128), cap(128), 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(cap(128), cap(256), 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(cap(256), cap(512), 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(cap(512), cap(1024), 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(cap(1024), cap(1024), 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(cap(1024), cap(1024), 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(cap(1024), 1, 3, 1, padding=1))

    def forward(self, x):
        """Return the feature maps of all layers for input ``x``.

        Returns:
            List of eight tensors: the leaky-ReLU activations of the seven
            main convolutions followed by the raw output of ``conv_post``.
        """
        output = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            output.append(x)
        x = self.conv_post(x)
        output.append(x)

        return output


class MultiScaleDiscriminator(torch.nn.Module):
    """Bank of three ``DiscriminatorS`` instances applied to the same input.

    The first instance uses spectral normalisation, the other two use the
    ``DiscriminatorS`` default.

    Args:
        max_channels: Forwarded to every ``DiscriminatorS``.
    """

    def __init__(self, max_channels=1024):
        super().__init__()
        first = DiscriminatorS(use_spectral_norm=True, max_channels=max_channels)
        second = DiscriminatorS(max_channels=max_channels)
        third = DiscriminatorS(max_channels=max_channels)
        self.discriminators = nn.ModuleList([first, second, third])

        # NOTE(review): these pooling layers are registered here but forward()
        # never applies them, so all three discriminators receive the
        # un-pooled input - confirm whether that is intended.
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y):
        """Apply every scale discriminator to ``y``.

        Returns:
            List with one entry per discriminator, in registration order; each
            entry is whatever that discriminator's ``forward`` returns.
        """
        return [discriminator(y) for discriminator in self.discriminators]


class TDMultiResolutionDiscriminator(torch.nn.Module):
    """Time-domain discriminator combining multi-scale and multi-period banks.

    Args:
        max_channels: Forwarded to both sub-discriminator banks.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, max_channels=1024, **kwargs):
        super().__init__()
        # Echo the effective channel limit when the discriminator is built.
        print(f"{max_channels=}")
        self.msd = MultiScaleDiscriminator(max_channels=max_channels)
        self.mpd = MultiPeriodDiscriminator(max_channels=max_channels)

    def forward(self, y):
        """Run both banks on ``y``.

        Returns:
            The multi-scale outputs followed by the multi-period outputs,
            joined into a single list.
        """
        scale_outputs = self.msd(y)
        period_outputs = self.mpd(y)
        return scale_outputs + period_outputs
2 changes: 1 addition & 1 deletion dnn/torch/osce/pre_to_adv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

setup['training'] = adv_setup['training']
setup['discriminator'] = adv_setup['discriminator']
setup['data']['frames_per_sample'] = 60
setup['data']['frames_per_sample'] = 90

with open(args.adv_setup_yaml, 'w') as f:
yaml.dump(setup, f)
6 changes: 3 additions & 3 deletions dnn/torch/osce/utils/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@
'w_sc': 0,
'w_wsc': 0,
'w_xcorr': 0,
'w_sxcorr': 1,
'w_l2': 0,
'w_slm': 2,
'w_sxcorr': 2,
'w_l2': 10,
'w_slm': 1,
'w_tdlp': 1
},
'preemph': 0.85
Expand Down

0 comments on commit ee29215

Please sign in to comment.