Commit
fixed some things that got broken
janpbuethe committed Sep 15, 2024
1 parent ff6dea5 commit 4fe9cde
Showing 8 changed files with 76 additions and 36 deletions.
30 changes: 30 additions & 0 deletions dnn/torch/osce/data/lpcnet_vocoding_dataset.py
@@ -86,6 +86,8 @@ def __init__(self,
self.getitem = self.getitem_v1
elif self.version == 2:
self.getitem = self.getitem_v2
elif self.version == 3:
self.getitem = self.getitem_v3
else:
raise ValueError(f"dataset version {self.version} unknown")

@@ -125,6 +127,34 @@ def __init__(self,
def __getitem__(self, index):
return self.getitem(index)

def getitem_v3(self, index):
sample = dict()

# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

# convert periods
if 'periods' in self.input_features:
sample['periods'] = np.round(np.clip(256./2**(sample['periods']+1.5), 32, 255)).astype('int16')

signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
target = torch.FloatTensor(sample[self.target]) / 2**15
periods = torch.LongTensor(sample['periods'])

return {'features' : features, 'periods' : periods, 'target' : target}

def getitem_v2(self, index):
sample = dict()

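As a quick illustration of the version-3 period conversion used in getitem_v3 above, here is a minimal sketch with made-up raw pitch values (not part of the commit): the raw feature p is mapped to an integer lag 256 / 2**(p + 1.5), clipped to [32, 255].

import numpy as np

# Illustrative values only.
p = np.array([-1.5, -0.5, 1.0])
periods = np.round(np.clip(256. / 2 ** (p + 1.5), 32, 255)).astype('int16')
print(periods)  # [255 128  45]  (256 is clipped down to 255)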
25 changes: 12 additions & 13 deletions dnn/torch/osce/losses/td_lowpass.py
@@ -7,28 +7,27 @@
class TDLowpass(torch.nn.Module):
def __init__(self, numtaps, cutoff, power=2):
super().__init__()

self.b = scipy.signal.firwin(numtaps, cutoff)
self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
self.power = power

def forward(self, y_true, y_pred):

assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

diff = y_true - y_pred
diff_lp = torch.nn.functional.conv1d(diff, self.weight)

loss = torch.mean(torch.abs(diff_lp ** self.power))

return loss, diff_lp

def get_freqz(self):
freq, response = scipy.signal.freqz(self.b)

return freq, response
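Below is a minimal usage sketch of the TDLowpass loss; the constructor arguments and the import path are illustrative assumptions based on the file location above.

import torch
from losses.td_lowpass import TDLowpass  # assumed module path

lowpass_loss = TDLowpass(numtaps=15, cutoff=0.2)   # cutoff normalized to Nyquist (scipy firwin convention)

y_true = torch.randn(4, 1, 16000)   # (batch, channels, samples), as the assert in forward() requires
y_pred = torch.randn(4, 1, 16000)

loss, filtered_error = lowpass_loss(y_true, y_pred)
# loss is a scalar; filtered_error is the lowpass-filtered prediction error,
# shortened by numtaps - 1 samples by the valid convolution.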









6 changes: 4 additions & 2 deletions dnn/torch/osce/models/lavoce.py
@@ -171,19 +171,21 @@ def flop_count(self, rate=16000, verbose=False):
frame_rate = rate / self.FRAME_SIZE

# feature net
- feature_net_flops = self.feature_net.flop_count(frame_rate)
+ feature_net_flops = self.feature_net.flop_count(frame_rate / self.upsamp_factor)
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
+ _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate)

if verbose:
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
print(f"adashape: {shape_flops / 1e6} MFLOPS")

- return feature_net_flops + comb_flops + af_flops + feature_flops
+ return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops

def feature_transform(self, f, layer):
f = f.permute(0, 2, 1)
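For orientation, a rough sketch of how a Conv1d can be billed in per-second FLOP counts like the ones above; this is an illustrative estimate only and not necessarily what _conv1d_flop_count in this repository does.

import torch.nn as nn

def conv1d_flops_per_second(conv: nn.Conv1d, rate: float) -> float:
    # One multiply and one add per multiply-accumulate, `rate` outputs per second.
    macs_per_output = conv.out_channels * conv.in_channels * conv.kernel_size[0] / conv.groups
    return 2.0 * macs_per_output * rate

# The feature net consumes one feature vector per input frame, which is presumably
# why it is now counted at frame_rate / self.upsamp_factor rather than at the
# upsampled frame_rate.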
10 changes: 5 additions & 5 deletions dnn/torch/osce/silk_16_to_48.py
@@ -12,17 +12,17 @@

if __name__ == "__main__":
args = parser.parse_args()

fs, x = wavfile.read(args.input)

# being lazy for now
assert fs == 16000 and x.dtype == np.int16

x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)

upsampler = SilkUpsampler()
y = upsampler(x)

y = y.squeeze().numpy().astype(np.int16)

wavfile.write(args.output, 48000, y[13:])
2 changes: 1 addition & 1 deletion dnn/torch/osce/train_vocoder.py
@@ -244,7 +244,7 @@ def criterion(x, y):

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"{model.flop_count(16000, verbose=True) / 1e6:5.3f} MFLOPS")

if ref is not None:
pass
14 changes: 7 additions & 7 deletions dnn/torch/osce/utils/layers/fir.py
@@ -8,20 +8,20 @@
class FIR(nn.Module):
def __init__(self, numtaps, bands, desired, fs=2):
super().__init__()

if numtaps % 2 == 0:
print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
numtaps += 1

a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

self.weight = torch.from_numpy(a.astype(np.float32))

def forward(self, x):
num_channels = x.size(1)

weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)

y = F.conv1d(x, weight, groups=num_channels)

return y
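A minimal usage sketch of the FIR layer above; the band edges and desired gains are illustrative values for a lowpass least-squares design with fs=2 (frequencies normalized so that 1.0 is Nyquist).

import torch

fir = FIR(numtaps=15, bands=[0.0, 0.375, 0.5, 1.0], desired=[1.0, 1.0, 0.0, 0.0], fs=2)

x = torch.randn(1, 2, 320)   # (batch, channels, samples); each channel is filtered independently
y = fir(x)                   # grouped valid convolution: output is numtaps - 1 samples shorter
print(y.shape)               # torch.Size([1, 2, 306])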
16 changes: 11 additions & 5 deletions dnn/torch/osce/utils/layers/td_shaper.py
@@ -59,12 +59,18 @@ def __init__(self,
self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)

if self.innovate:
- self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
- self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
+ self.feature_alpha1b_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
+ self.feature_alpha1b_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
+ self.feature_alpha1c_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
+ self.feature_alpha1c_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))

self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))

if softquant:
self.feature_alpha1b_f = soft_quant(self.feature_alpha1b_f)
self.feature_alpha1c_f = soft_quant(self.feature_alpha1c_f)


def flop_count(self, rate):

@@ -73,7 +79,7 @@ def flop_count(self, rate):
shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

if self.innovate:
- inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
+ inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b_f, self.feature_alpha1b_t, self.feature_alpha2b, self.feature_alpha1c_f, self.feature_alpha1c_t, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
else:
inno_flops = 0

@@ -127,11 +133,11 @@ def forward(self, x, features, debug=False):
alpha = alpha.permute(0, 2, 1)

if self.innovate:
- inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
+ inno_alpha = F.leaky_relu(self.feature_alpha1b_f(f) + self.feature_alpha1b_t(t), 0.2)
inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
inno_alpha = inno_alpha.permute(0, 2, 1)

- inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
+ inno_x = F.leaky_relu(self.feature_alpha1c_f(f) + self.feature_alpha1c_t(t), 0.2)
inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
inno_x = inno_x.permute(0, 2, 1)

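The renamed feature_alpha1b/1c convolutions split a single Conv1d over the concatenated feature and envelope channels into two per-input convolutions whose outputs are summed. A small sketch with made-up dimensions showing why the two forms compute the same family of functions (the joint kernel is simply split along its input-channel axis, with the bias kept on one branch):

import torch
import torch.nn as nn

feature_dim, env_dim, frame_size, T = 93, 64, 80, 50   # made-up sizes for illustration

joint  = nn.Conv1d(feature_dim + env_dim, frame_size, 2)
conv_f = nn.Conv1d(feature_dim, frame_size, 2)
conv_t = nn.Conv1d(env_dim, frame_size, 2, bias=False)

with torch.no_grad():
    conv_f.weight.copy_(joint.weight[:, :feature_dim])
    conv_f.bias.copy_(joint.bias)
    conv_t.weight.copy_(joint.weight[:, feature_dim:])

f = torch.randn(1, feature_dim, T)
t = torch.randn(1, env_dim, T)

assert torch.allclose(joint(torch.cat((f, t), dim=1)), conv_f(f) + conv_t(t), atol=1e-4)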
9 changes: 6 additions & 3 deletions dnn/torch/osce/utils/lpcnet_features.py
@@ -3,8 +3,8 @@
import torch
import numpy as np

- def load_lpcnet_features(feature_file, version=2):
-     if version == 2:
+ def load_lpcnet_features(feature_file, version=3):
+     if version == 2 or version == 3:
layout = {
'cepstrum': [0,18],
'periods': [18, 19],
@@ -37,7 +37,10 @@ def load_lpcnet_features(feature_file, version=2):
)

lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
- periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()
+ if version <= 2:
+     periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()
+ else:
+     periods = torch.round(torch.clip(256./2**(raw_features[:, layout['periods'][0] : layout['periods'][1]]+1.5), 32, 255)).long()

return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}

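A small sketch comparing the two period mappings above on a few illustrative raw feature values (the version <= 2 branch vs. the new version-3 branch):

import torch

raw = torch.tensor([-0.5, 0.0, 0.5])

periods_v2 = (0.1 + 50 * raw + 100).long()                                      # version <= 2
periods_v3 = torch.round(torch.clip(256. / 2 ** (raw + 1.5), 32, 255)).long()   # version 3

print(periods_v2.tolist())   # [75, 100, 125]
print(periods_v3.tolist())   # [128, 91, 64]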
