diff --git a/dnn/torch/osce/adv_train_bwe_model.py b/dnn/torch/osce/adv_train_bwe_model.py index c230cd710..0d1681e2d 100644 --- a/dnn/torch/osce/adv_train_bwe_model.py +++ b/dnn/torch/osce/adv_train_bwe_model.py @@ -155,6 +155,7 @@ def preemph(x, gamma): lambda_feat = setup['training']['lambda_feat'] lambda_reg = setup['training']['lambda_reg'] adv_target = setup['training'].get('adv_target', 'x_48') +newloss = setup['training'].get('newloss', False) # load training dataset data_config = setup['data'] @@ -174,11 +175,13 @@ def preemph(x, gamma): model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs']) # create discriminator +print(setup['discriminator']['name'],setup['discriminator']['kwargs']) disc_name = setup['discriminator']['name'] disc = model_dict[disc_name]( *setup['discriminator']['args'], **setup['discriminator']['kwargs'] ) + # set compute device if type(args.device) == type(None): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -272,13 +275,28 @@ def td_l1(y_true, y_pred, pow=0): return torch.mean(tmp) -tdlp = TDLowpass(15, 4000/24000).to(device) +if newloss: + tdlp = TDLowpass(31, 4000/24000).to(device) +else: + tdlp = TDLowpass(15, 4000/24000).to(device) -def criterion(x, y, x_up): +if newloss: + def criterion(x, y, x_up): + # FD-losses are calculated on preemphasized signals + xp = preemph(x, preemph_gamma) + yp = preemph(y, preemph_gamma) - return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y) - + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum + return (w_l1 * td_l1(x, y, pow=1) + stftloss(xp, yp) + w_logmel * logmelloss(xp, yp) + + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum +else: + def criterion(x, y, x_up): + # all losses are calculated on preemphasized signals + x = preemph(x, preemph_gamma) + y = preemph(y, preemph_gamma) + x_up = preemph(x_up, preemph_gamma) + return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y) + + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum # model checkpoint checkpoint = { @@ -364,12 +382,10 @@ def optimizer_to(optim, device): # pre-emphasize disc_target = preemph(target, preemph_gamma) - target = preemph(target, preemph_gamma) - x_up = preemph(x_up, preemph_gamma) - output = preemph(output, preemph_gamma) + output_preemph = preemph(output, preemph_gamma) # discriminator update - scores_gen = disc(output.detach()) + scores_gen = disc(output_preemph.detach()) scores_real = disc(disc_target.unsqueeze(1)) disc_loss = 0 @@ -394,7 +410,7 @@ def optimizer_to(optim, device): optimizer_disc.step() # generator update - scores_gen = disc(output) + scores_gen = disc(output_preemph) # calculate loss loss_reg = criterion(target, output.squeeze(1), x_up) diff --git a/dnn/torch/osce/models/__init__.py b/dnn/torch/osce/models/__init__.py index b7ac4b3d9..cbc0d0be1 100644 --- a/dnn/torch/osce/models/__init__.py +++ b/dnn/torch/osce/models/__init__.py @@ -32,6 +32,7 @@ from .lavoce import LaVoce from .lavoce_400 import LaVoce400 from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc +from .td_discriminator import TDMultiResolutionDiscriminator as TDMResDisc from .bwe_net import BWENet from .bbwe_net import BBWENet @@ -41,6 +42,7 @@ 'lavoce': LaVoce, 'lavoce400': LaVoce400, 'fdmresdisc': FDMResDisc, + 'tdmresdisc': TDMResDisc, 'bwenet' : BWENet, 'bbwenet': BBWENet } diff --git a/dnn/torch/osce/models/bbwe_net.py b/dnn/torch/osce/models/bbwe_net.py index 39db9f76a..fb9b1f349 100644 --- a/dnn/torch/osce/models/bbwe_net.py +++ b/dnn/torch/osce/models/bbwe_net.py @@ -85,25 +85,25 @@ def forward(self, features, state=None): class Folder(torch.nn.Module): def __init__(self, num_taps, frame_size): super().__init__() - + self.num_taps = num_taps self.frame_size = frame_size assert frame_size % num_taps == 0 self.taps = torch.nn.Parameter(torch.randn(num_taps).view(1, 1, -1), requires_grad=True) - - + + def flop_count(self, rate): - + # single multiplication per sample return rate - + def forward(self, x, *args): - + batch_size, num_channels, length = x.shape assert length % self.num_taps == 0 - + y = x * torch.repeat_interleave(torch.exp(self.taps), length // self.num_taps, dim=-1) - + return y class BBWENet(torch.nn.Module): @@ -123,7 +123,7 @@ def __init__(self, interpolate_k48=1, shape_extension=True, func_extension=True, - shaper='TDShape', + shaper='TDShaper', bias=False, ): @@ -140,7 +140,7 @@ def __init__(self, self.shape_extension = shape_extension self.func_extension = func_extension self.shaper = shaper - + assert (shape_extension or func_extension) and "Require at least one of shape_extension and func_extension to be true" @@ -165,7 +165,7 @@ def __init__(self, self.tdshape2 = Folder(12, frame_size=self.frame_size48) else: raise ValueError(f"unknown shaper {self.shaper}") - + if activation == 'ImPowI': self.nlfunc = lambda x : x * torch.sin(torch.log(torch.abs(x) + 1e-6)) elif activation == "ReLU": @@ -209,7 +209,7 @@ def forward(self, x, features, debug=False): # split into latent_channels channels y16 = self.af1(x, cf, debug=debug) - + # first 2x upsampling step y32 = self.upsampler.hq_2x_up(y16) y32_out = y32[:, 0:1, :] # first channel is bypass channel @@ -220,14 +220,14 @@ def forward(self, x, features, debug=False): y32_shape = self.tdshape1(y32[:, idx:idx+1, :], cf) y32_out = torch.cat((y32_out, y32_shape), dim=1) idx += 1 - + if self.func_extension: y32_func = self.nlfunc(y32[:, idx:idx+1, :]) y32_out = torch.cat((y32_out, y32_func), dim=1) - + # mix-select y32_out = self.af2(y32_out, cf) - + # 1.5x upsampling y48 = self.upsampler.interpolate_3_2(y32_out) y48_out = y48[:, 0:1, :] # first channel is bypass channel @@ -238,12 +238,12 @@ def forward(self, x, features, debug=False): y48_shape = self.tdshape2(y48[:, idx:idx+1, :], cf) y48_out = torch.cat((y48_out, y48_shape), dim=1) idx += 1 - + if self.func_extension: y48_func = self.nlfunc(y48[:, idx:idx+1, :]) y48_out = torch.cat((y48_out, y48_func), dim=1) - + # 2nd mixing y48_out = self.af3(y48_out, cf) - + return y48_out \ No newline at end of file diff --git a/dnn/torch/osce/models/td_discriminator.py b/dnn/torch/osce/models/td_discriminator.py new file mode 100644 index 000000000..2136cbd3c --- /dev/null +++ b/dnn/torch/osce/models/td_discriminator.py @@ -0,0 +1,150 @@ +""" +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +# This is an adaptation of the HiFi-Gan discriminators derived from https://github.com/jik876/hifi-gan + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + +LRELU_SLOPE = 0.1 + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, max_channels=1024): + super(DiscriminatorP, self).__init__() + self.max_channels = max_channels + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(min(self.max_channels, 128), min(self.max_channels, 512), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(min(self.max_channels, 512), min(self.max_channels, 1024), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(min(self.max_channels, 1024), min(self.max_channels, 1024), (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(min(self.max_channels, 1024), 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + output = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + output.append(x) + x = self.conv_post(x) + output.append(x) + + return output + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, max_channels=1024): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2, max_channels=max_channels), + DiscriminatorP(3, max_channels=max_channels), + DiscriminatorP(5, max_channels=max_channels), + DiscriminatorP(7, max_channels=max_channels), + DiscriminatorP(11, max_channels=max_channels), + ]) + + def forward(self, y): + outputs = [] + for disc in self.discriminators: + outputs.append(disc(y)) + + return outputs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False, max_channels=1024): + super(DiscriminatorS, self).__init__() + self.max_channels = max_channels + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, min(self.max_channels, 128), 15, 1, padding=7)), + norm_f(Conv1d(min(self.max_channels, 128), min(self.max_channels, 128), 41, 2, groups=4, padding=20)), + norm_f(Conv1d(min(self.max_channels, 128), min(self.max_channels, 256), 41, 2, groups=16, padding=20)), + norm_f(Conv1d(min(self.max_channels, 256), min(self.max_channels, 512), 41, 4, groups=16, padding=20)), + norm_f(Conv1d(min(self.max_channels, 512), min(self.max_channels, 1024), 41, 4, groups=16, padding=20)), + norm_f(Conv1d(min(self.max_channels, 1024), min(self.max_channels, 1024), 41, 1, groups=16, padding=20)), + norm_f(Conv1d(min(self.max_channels, 1024), min(self.max_channels, 1024), 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(min(self.max_channels, 1024), 1, 3, 1, padding=1)) + + def forward(self, x): + output = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + output.append(x) + x = self.conv_post(x) + output.append(x) + + return output + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self, max_channels=1024): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True, max_channels=max_channels), + DiscriminatorS(max_channels=max_channels), + DiscriminatorS(max_channels=max_channels), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y): + outputs = [] + for disc in self.discriminators: + outputs.append(disc(y)) + + return outputs + + +class TDMultiResolutionDiscriminator(torch.nn.Module): + def __init__(self, max_channels=1024, **kwargs): + super().__init__() + print(f"{max_channels=}") + self.msd = MultiScaleDiscriminator(max_channels=max_channels) + self.mpd = MultiPeriodDiscriminator(max_channels=max_channels) + + def forward(self, y): + return self.msd(y) + self.mpd(y) \ No newline at end of file diff --git a/dnn/torch/osce/pre_to_adv.py b/dnn/torch/osce/pre_to_adv.py index e46488a6a..0437f35a6 100644 --- a/dnn/torch/osce/pre_to_adv.py +++ b/dnn/torch/osce/pre_to_adv.py @@ -24,7 +24,7 @@ setup['training'] = adv_setup['training'] setup['discriminator'] = adv_setup['discriminator'] - setup['data']['frames_per_sample'] = 60 + setup['data']['frames_per_sample'] = 90 with open(args.adv_setup_yaml, 'w') as f: yaml.dump(setup, f) diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py index 001aec8e5..8cf25e8f0 100644 --- a/dnn/torch/osce/utils/templates.py +++ b/dnn/torch/osce/utils/templates.py @@ -236,9 +236,9 @@ 'w_sc': 0, 'w_wsc': 0, 'w_xcorr': 0, - 'w_sxcorr': 1, - 'w_l2': 0, - 'w_slm': 2, + 'w_sxcorr': 2, + 'w_l2': 10, + 'w_slm': 1, 'w_tdlp': 1 }, 'preemph': 0.85