From 4fe9cde6df667b5315eed5db9ab19fc25e89b484 Mon Sep 17 00:00:00 2001 From: Jan Buethe Date: Sun, 15 Sep 2024 11:48:22 +0200 Subject: [PATCH] fixed some things that got broken --- .../osce/data/lpcnet_vocoding_dataset.py | 30 +++++++++++++++++++ dnn/torch/osce/losses/td_lowpass.py | 25 ++++++++-------- dnn/torch/osce/models/lavoce.py | 6 ++-- dnn/torch/osce/silk_16_to_48.py | 10 +++---- dnn/torch/osce/train_vocoder.py | 2 +- dnn/torch/osce/utils/layers/fir.py | 14 ++++----- dnn/torch/osce/utils/layers/td_shaper.py | 16 ++++++---- dnn/torch/osce/utils/lpcnet_features.py | 9 ++++-- 8 files changed, 76 insertions(+), 36 deletions(-) diff --git a/dnn/torch/osce/data/lpcnet_vocoding_dataset.py b/dnn/torch/osce/data/lpcnet_vocoding_dataset.py index 36c8c724e..f2f1a6db8 100644 --- a/dnn/torch/osce/data/lpcnet_vocoding_dataset.py +++ b/dnn/torch/osce/data/lpcnet_vocoding_dataset.py @@ -86,6 +86,8 @@ def __init__(self, self.getitem = self.getitem_v1 elif self.version == 2: self.getitem = self.getitem_v2 + elif self.version == 3: + self.getitem = self.getitem_v3 else: raise ValueError(f"dataset version {self.version} unknown") @@ -125,6 +127,34 @@ def __init__(self, def __getitem__(self, index): return self.getitem(index) + def getitem_v3(self, index): + sample = dict() + + # extract features + frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history + frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead + + for feature in self.input_features: + feature_start, feature_stop = self.feature_frame_layout[feature] + sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop] + + # convert periods + if 'periods' in self.input_features: + sample['periods'] = np.round(np.clip(256./2**(sample['periods']+1.5), 32, 255)).astype('int16') + + signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length + signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length + + sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']] + + # concatenate features + feature_keys = [key for key in self.input_features if not key.startswith("periods")] + features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1) + target = torch.FloatTensor(sample[self.target]) / 2**15 + periods = torch.LongTensor(sample['periods']) + + return {'features' : features, 'periods' : periods, 'target' : target} + def getitem_v2(self, index): sample = dict() diff --git a/dnn/torch/osce/losses/td_lowpass.py b/dnn/torch/osce/losses/td_lowpass.py index af422fb55..f890df2e4 100644 --- a/dnn/torch/osce/losses/td_lowpass.py +++ b/dnn/torch/osce/losses/td_lowpass.py @@ -7,28 +7,27 @@ class TDLowpass(torch.nn.Module): def __init__(self, numtaps, cutoff, power=2): super().__init__() - + self.b = scipy.signal.firwin(numtaps, cutoff) self.weight = torch.from_numpy(self.b).float().view(1, 1, -1) self.power = power - + def forward(self, y_true, y_pred): - + assert len(y_true.shape) == 3 and len(y_pred.shape) == 3 - + diff = y_true - y_pred diff_lp = torch.nn.functional.conv1d(diff, self.weight) - + loss = torch.mean(torch.abs(diff_lp ** self.power)) - + return loss, diff_lp - + def get_freqz(self): freq, response = scipy.signal.freqz(self.b) - + return freq, response - - - - - \ No newline at end of file + + + + diff --git a/dnn/torch/osce/models/lavoce.py b/dnn/torch/osce/models/lavoce.py index fcfdc8bfa..e34db9efa 100644 --- a/dnn/torch/osce/models/lavoce.py +++ b/dnn/torch/osce/models/lavoce.py @@ -171,19 +171,21 @@ def flop_count(self, rate=16000, verbose=False): frame_rate = rate / self.FRAME_SIZE # feature net - feature_net_flops = self.feature_net.flop_count(frame_rate) + feature_net_flops = self.feature_net.flop_count(frame_rate / self.upsamp_factor) comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate) af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate) feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate) + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate)) + shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate) if verbose: print(f"feature net: {feature_net_flops / 1e6} MFLOPS") print(f"comb filters: {comb_flops / 1e6} MFLOPS") print(f"adaptive conv: {af_flops / 1e6} MFLOPS") print(f"feature transforms: {feature_flops / 1e6} MFLOPS") + print(f"adashape: {shape_flops / 1e6} MFLOPS") - return feature_net_flops + comb_flops + af_flops + feature_flops + return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops def feature_transform(self, f, layer): f = f.permute(0, 2, 1) diff --git a/dnn/torch/osce/silk_16_to_48.py b/dnn/torch/osce/silk_16_to_48.py index e59b6cc84..6ca6893ec 100644 --- a/dnn/torch/osce/silk_16_to_48.py +++ b/dnn/torch/osce/silk_16_to_48.py @@ -12,17 +12,17 @@ if __name__ == "__main__": args = parser.parse_args() - + fs, x = wavfile.read(args.input) # being lazy for now assert fs == 16000 and x.dtype == np.int16 - + x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1) - + upsampler = SilkUpsampler() y = upsampler(x) - + y = y.squeeze().numpy().astype(np.int16) - + wavfile.write(args.output, 48000, y[13:]) \ No newline at end of file diff --git a/dnn/torch/osce/train_vocoder.py b/dnn/torch/osce/train_vocoder.py index 590e6d1a6..c82119ade 100644 --- a/dnn/torch/osce/train_vocoder.py +++ b/dnn/torch/osce/train_vocoder.py @@ -244,7 +244,7 @@ def criterion(x, y): print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters") if hasattr(model, 'flop_count'): - print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS") + print(f"{model.flop_count(16000, verbose=True) / 1e6:5.3f} MFLOPS") if ref is not None: pass diff --git a/dnn/torch/osce/utils/layers/fir.py b/dnn/torch/osce/utils/layers/fir.py index 7eeb3e4e5..8c50624b7 100644 --- a/dnn/torch/osce/utils/layers/fir.py +++ b/dnn/torch/osce/utils/layers/fir.py @@ -8,20 +8,20 @@ class FIR(nn.Module): def __init__(self, numtaps, bands, desired, fs=2): super().__init__() - + if numtaps % 2 == 0: print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}") numtaps += 1 - + a = scipy.signal.firls(numtaps, bands, desired, fs=fs) - + self.weight = torch.from_numpy(a.astype(np.float32)) - + def forward(self, x): num_channels = x.size(1) - + weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0) - + y = F.conv1d(x, weight, groups=num_channels) - + return y \ No newline at end of file diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py index fa7bf3483..788dd9f27 100644 --- a/dnn/torch/osce/utils/layers/td_shaper.py +++ b/dnn/torch/osce/utils/layers/td_shaper.py @@ -59,12 +59,18 @@ def __init__(self, self.feature_alpha1_f = soft_quant(self.feature_alpha1_f) if self.innovate: - self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)) - self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)) + self.feature_alpha1b_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2)) + self.feature_alpha1b_t = norm(nn.Conv1d(self.env_dim, frame_size, 2)) + self.feature_alpha1c_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2)) + self.feature_alpha1c_t = norm(nn.Conv1d(self.env_dim, frame_size, 2)) self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2)) self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2)) + if softquant: + self.feature_alpha1b_f = soft_quant(self.feature_alpha1b_f) + self.feature_alpha1c_f = soft_quant(self.feature_alpha1c_f) + def flop_count(self, rate): @@ -73,7 +79,7 @@ def flop_count(self, rate): shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size if self.innovate: - inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size + inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b_f, self.feature_alpha1b_t, self.feature_alpha2b, self.feature_alpha1c_f, self.feature_alpha1c_t, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size else: inno_flops = 0 @@ -127,11 +133,11 @@ def forward(self, x, features, debug=False): alpha = alpha.permute(0, 2, 1) if self.innovate: - inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2) + inno_alpha = F.leaky_relu(self.feature_alpha1b_f(f) + self.feature_alpha1b_t(t), 0.2) inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0]))) inno_alpha = inno_alpha.permute(0, 2, 1) - inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2) + inno_x = F.leaky_relu(self.feature_alpha1c_f(f) + self.feature_alpha1c_t(t), 0.2) inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0]))) inno_x = inno_x.permute(0, 2, 1) diff --git a/dnn/torch/osce/utils/lpcnet_features.py b/dnn/torch/osce/utils/lpcnet_features.py index 3d109fd3c..22f7a437e 100644 --- a/dnn/torch/osce/utils/lpcnet_features.py +++ b/dnn/torch/osce/utils/lpcnet_features.py @@ -3,8 +3,8 @@ import torch import numpy as np -def load_lpcnet_features(feature_file, version=2): - if version == 2: +def load_lpcnet_features(feature_file, version=3): + if version == 2 or version == 3: layout = { 'cepstrum': [0,18], 'periods': [18, 19], @@ -37,7 +37,10 @@ def load_lpcnet_features(feature_file, version=2): ) lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]] - periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long() + if version <= 2: + periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long() + else: + periods = torch.round(torch.clip(256./2**(raw_features[:, layout['periods'][0] : layout['periods'][1]]+1.5), 32, 255)).long() return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}