diff --git a/dnn/torch/osce/data/simple_bwe_dataset.py b/dnn/torch/osce/data/simple_bwe_dataset.py index 45b5e5b91..27dd2d382 100644 --- a/dnn/torch/osce/data/simple_bwe_dataset.py +++ b/dnn/torch/osce/data/simple_bwe_dataset.py @@ -41,10 +41,13 @@ def __init__(self, path, frames_per_sample=100, spec_num_bands=32, - max_instafreq_bin=40 + max_instafreq_bin=40, + upsampling_delay48=13, ): self.frames_per_sample = frames_per_sample + self.upsampling_delay48 = upsampling_delay48 + self.signal_16k = np.fromfile(os.path.join(path, 'signal_16kHz.s16'), dtype=np.int16) self.signal_48k = np.fromfile(os.path.join(path, 'signal_48kHz.s16'), dtype=np.int16) @@ -53,7 +56,7 @@ def __init__(self, self.create_features = bwe_feature_factory(spec_num_bands=spec_num_bands, max_instafreq_bin=max_instafreq_bin) - self.frame_offset = 4 + self.frame_offset = 6 self.len = (num_frames - self.frame_offset) // frames_per_sample @@ -71,7 +74,8 @@ def __getitem__(self, index): x_16 = self.signal_16k[signal_start16 : signal_stop16].astype(np.float32) / 2**15 history_16 = self.signal_16k[signal_start16 - 320 : signal_start16].astype(np.float32) / 2**15 - x_48 = self.signal_48k[3 * signal_start16 : 3 * signal_stop16].astype(np.float32) / 2**15 + x_48 = self.signal_48k[3 * signal_start16 - self.upsampling_delay48 + : 3 * signal_stop16 - self.upsampling_delay48].astype(np.float32) / 2**15 features = self.create_features( x_16, diff --git a/dnn/torch/osce/engine/bwe_engine.py b/dnn/torch/osce/engine/bwe_engine.py index 3d9185958..64fbacd84 100644 --- a/dnn/torch/osce/engine/bwe_engine.py +++ b/dnn/torch/osce/engine/bwe_engine.py @@ -23,12 +23,14 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, batch[key] = batch[key].to(device) target = batch['x_48'] + x16 = batch['x_16'] + x_up = model.upsampler(x16.unsqueeze(1)) # calculate model output output = model(batch['x_16'].unsqueeze(1), batch['features']) # calculate loss - loss = criterion(target, output.squeeze(1)) + loss = criterion(target, output.squeeze(1), x_up) # calculate gradients loss.backward() @@ -79,7 +81,7 @@ def evaluate(model, criterion, dataloader, device, log_interval=10): output = model(batch['x_16'].unsqueeze(1), batch['features']) # calculate loss - loss = criterion(target, output.squeeze(1)) + loss = criterion(target, output.squeeze(1), model.upsampler(batch['x_16'].unsqueeze(1))) # update running loss running_loss += float(loss.cpu()) diff --git a/dnn/torch/osce/losses/td_lowpass.py b/dnn/torch/osce/losses/td_lowpass.py index c3e2b11fc..b9fb2d496 100644 --- a/dnn/torch/osce/losses/td_lowpass.py +++ b/dnn/torch/osce/losses/td_lowpass.py @@ -13,15 +13,16 @@ def __init__(self, numtaps, cutoff, power=2): self.power = power def forward(self, y_true, y_pred): - - assert len(y_true.shape) == 3 and len(y_pred.shape) == 3 + + if len(y_true.shape) < 3: y_true = y_true.unsqueeze(1) + if len(y_pred.shape) < 3: y_pred = y_pred.unsqueeze(1) diff = y_true - y_pred diff_lp = torch.nn.functional.conv1d(diff, self.weight) loss = torch.mean(torch.abs(diff_lp ** self.power)) - return loss, diff_lp + return loss def get_freqz(self): freq, response = scipy.signal.freqz(self.b) diff --git a/dnn/torch/osce/train_bwe_model.py b/dnn/torch/osce/train_bwe_model.py index 491f587d1..3803b4d0c 100644 --- a/dnn/torch/osce/train_bwe_model.py +++ b/dnn/torch/osce/train_bwe_model.py @@ -63,6 +63,7 @@ from utils.misc import count_parameters, count_nonzero_parameters from losses.stft_loss import MRSTFTLoss, MRLogMelLoss +from losses.td_lowpass import TDLowpass parser = argparse.ArgumentParser() @@ -204,8 +205,9 @@ w_xcorr = setup['training']['loss']['w_xcorr'] w_sxcorr = setup['training']['loss']['w_sxcorr'] w_l2 = setup['training']['loss']['w_l2'] +w_tdlp = setup['training']['loss'].get('w_tdlp', 0) -w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 +w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + w_tdlp fft_sizes_16k = [2048, 1024, 512, 256, 128, 64] @@ -233,10 +235,12 @@ def td_l1(y_true, y_pred, pow=0): return torch.mean(tmp) -def criterion(x, y): +tdlp = TDLowpass(15, 4000/24000).to(device) + +def criterion(x, y, x_up): return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y) - + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum + + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + tdlp(x_up, y)) / w_sum diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py index f1d2c0148..34c06b968 100644 --- a/dnn/torch/osce/utils/templates.py +++ b/dnn/torch/osce/utils/templates.py @@ -90,7 +90,7 @@ bwenet_setup = { - 'dataset': '/local2/bwe0_dataset//training', + 'dataset': '/local2/bwe0_dataset/training', 'validation_dataset': '/local2/bwe0_dataset/validation', 'model': { 'name': 'bwenet', @@ -107,7 +107,8 @@ 'data': { 'frames_per_sample': 100, 'spec_num_bands' : 32, - 'max_instafreq_bin' : 40 + 'max_instafreq_bin' : 40, + 'upsampling_delay48' : 13 }, 'training': { 'batch_size': 128, @@ -122,8 +123,9 @@ 'w_wsc': 0, 'w_xcorr': 0, 'w_sxcorr': 1, - 'w_l2': 10, - 'w_slm': 2 + 'w_l2': 0, + 'w_slm': 2, + 'w_tdlp': 1 } } }