Skip to content

Commit

Permalink
fixed alignment issue and added time-domain lowpass loss
Browse files: browse the repository at this point in the history
  • Loading branch information
Jan Buethe committed Apr 26, 2024
1 parent 479795e commit 6b3061e
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 15 deletions.
10 changes: 7 additions & 3 deletions dnn/torch/osce/data/simple_bwe_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,13 @@ def __init__(self,
path,
frames_per_sample=100,
spec_num_bands=32,
max_instafreq_bin=40
max_instafreq_bin=40,
upsampling_delay48=13,
):

self.frames_per_sample = frames_per_sample
self.upsampling_delay48 = upsampling_delay48

self.signal_16k = np.fromfile(os.path.join(path, 'signal_16kHz.s16'), dtype=np.int16)
self.signal_48k = np.fromfile(os.path.join(path, 'signal_48kHz.s16'), dtype=np.int16)

Expand All @@ -53,7 +56,7 @@ def __init__(self,

self.create_features = bwe_feature_factory(spec_num_bands=spec_num_bands, max_instafreq_bin=max_instafreq_bin)

self.frame_offset = 4
self.frame_offset = 6

self.len = (num_frames - self.frame_offset) // frames_per_sample

Expand All @@ -71,7 +74,8 @@ def __getitem__(self, index):
x_16 = self.signal_16k[signal_start16 : signal_stop16].astype(np.float32) / 2**15
history_16 = self.signal_16k[signal_start16 - 320 : signal_start16].astype(np.float32) / 2**15

x_48 = self.signal_48k[3 * signal_start16 : 3 * signal_stop16].astype(np.float32) / 2**15
x_48 = self.signal_48k[3 * signal_start16 - self.upsampling_delay48
: 3 * signal_stop16 - self.upsampling_delay48].astype(np.float32) / 2**15

features = self.create_features(
x_16,
Expand Down
6 changes: 4 additions & 2 deletions dnn/torch/osce/engine/bwe_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
batch[key] = batch[key].to(device)

target = batch['x_48']
x16 = batch['x_16']
x_up = model.upsampler(x16.unsqueeze(1))

# calculate model output
output = model(batch['x_16'].unsqueeze(1), batch['features'])

# calculate loss
loss = criterion(target, output.squeeze(1))
loss = criterion(target, output.squeeze(1), x_up)

# calculate gradients
loss.backward()
Expand Down Expand Up @@ -79,7 +81,7 @@ def evaluate(model, criterion, dataloader, device, log_interval=10):
output = model(batch['x_16'].unsqueeze(1), batch['features'])

# calculate loss
loss = criterion(target, output.squeeze(1))
loss = criterion(target, output.squeeze(1), model.upsampler(batch['x_16'].unsqueeze(1)))

# update running loss
running_loss += float(loss.cpu())
Expand Down
7 changes: 4 additions & 3 deletions dnn/torch/osce/losses/td_lowpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ def __init__(self, numtaps, cutoff, power=2):
self.power = power

def forward(self, y_true, y_pred):

assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

if len(y_true.shape) < 3: y_true = y_true.unsqueeze(1)
if len(y_pred.shape) < 3: y_pred = y_pred.unsqueeze(1)

diff = y_true - y_pred
diff_lp = torch.nn.functional.conv1d(diff, self.weight)

loss = torch.mean(torch.abs(diff_lp ** self.power))

return loss, diff_lp
return loss

def get_freqz(self):
freq, response = scipy.signal.freqz(self.b)
Expand Down
10 changes: 7 additions & 3 deletions dnn/torch/osce/train_bwe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from utils.misc import count_parameters, count_nonzero_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
from losses.td_lowpass import TDLowpass


parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -204,8 +205,9 @@
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']
w_tdlp = setup['training']['loss'].get('w_tdlp', 0)

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + w_tdlp


fft_sizes_16k = [2048, 1024, 512, 256, 128, 64]
Expand Down Expand Up @@ -233,10 +235,12 @@ def td_l1(y_true, y_pred, pow=0):

return torch.mean(tmp)

def criterion(x, y):
tdlp = TDLowpass(15, 4000/24000).to(device)

def criterion(x, y, x_up):

return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + tdlp(x_up, y)) / w_sum



Expand Down
10 changes: 6 additions & 4 deletions dnn/torch/osce/utils/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@


bwenet_setup = {
'dataset': '/local2/bwe0_dataset//training',
'dataset': '/local2/bwe0_dataset/training',
'validation_dataset': '/local2/bwe0_dataset/validation',
'model': {
'name': 'bwenet',
Expand All @@ -107,7 +107,8 @@
'data': {
'frames_per_sample': 100,
'spec_num_bands' : 32,
'max_instafreq_bin' : 40
'max_instafreq_bin' : 40,
'upsampling_delay48' : 13
},
'training': {
'batch_size': 128,
Expand All @@ -122,8 +123,9 @@
'w_wsc': 0,
'w_xcorr': 0,
'w_sxcorr': 1,
'w_l2': 10,
'w_slm': 2
'w_l2': 0,
'w_slm': 2,
'w_tdlp': 1
}
}
}
Expand Down

0 comments on commit 6b3061e

Please sign in to comment.