forked from dgaddy/silent_speech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: align.py
38 lines (31 loc) · 1008 Bytes
/
align.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
import scipy
import matplotlib.pyplot as plt
from numba import jit
import torch
@jit
def time_warp(costs):
    """Accumulate a dynamic-time-warping cost matrix.

    For ``i, j >= 1``::

        dtw[i, j] = costs[i, j] + min(dtw[i-1, j], dtw[i, j-1], dtw[i-1, j-1])

    The boundary is fixed rather than accumulated:

    - ``dtw[0, 0]`` is 0.
    - The rest of row 0 and column 0 is ``inf``.

    As a result, every finite path enters the interior by the diagonal step
    (0, 0) -> (1, 1). The values in ``costs[0, :]`` and ``costs[:, 0]`` never
    contribute to the result.

    Args:
        costs: 2-D array of local matching costs (row = position in the first
            sequence, column = position in the second).

    Returns:
        A float64 array with the same shape as ``costs`` holding the
        accumulated costs. If either dimension is 0, the empty array is
        returned as-is.
    """
    n_rows, n_cols = costs.shape
    # Always accumulate in float64. np.zeros_like on an integer matrix would
    # give an integer buffer, which cannot represent the inf sentinels below.
    dtw = np.zeros((n_rows, n_cols))
    if n_rows == 0 or n_cols == 0:
        # Nothing to accumulate. Returning here also avoids writing row/column
        # 0 of an empty array, which numba does not bounds-check by default.
        return dtw
    dtw[0, 1:] = np.inf
    dtw[1:, 0] = np.inf
    for i in range(1, n_rows):
        for j in range(1, n_cols):
            dtw[i, j] = costs[i, j] + min(
                dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1]
            )
    return dtw
def align_from_distances(distance_matrix, debug=False):
    """Monotonically align the rows of ``distance_matrix`` to its columns.

    Runs ``time_warp`` on the matrix to get accumulated costs. It then
    backtracks from the bottom-right cell towards the top-left. At each step
    it moves to whichever of the up, left, or up-left neighbour has the
    smallest accumulated cost. Ties go to up, then left, then diagonal.

    Args:
        distance_matrix: 2-D array of pairwise distances. The row index is a
            position in the first sequence; the column index is a position in
            the second.
        debug: if True, show the resulting alignment as a 0/1 image with
            matplotlib.

    Returns:
        A list with one column index per row.
        - If the path visits a row at several columns, the smallest of them
          is kept, because it is written last during the backtrack.
        - Row 0 is never written and stays 0.
        - Any row the backtrack does not reach also stays 0.
    """
    dtw = time_warp(distance_matrix)
    n_rows, n_cols = distance_matrix.shape
    matches = [0] * n_rows

    row, col = n_rows - 1, n_cols - 1
    while row > 0 and col > 0:
        matches[row] = col
        # Candidate predecessors in tie-breaking order: up, left, diagonal.
        predecessors = ((row - 1, col), (row, col - 1), (row - 1, col - 1))
        row, col = min(predecessors, key=lambda cell: dtw[cell])

    if debug:
        visual = np.zeros_like(dtw)
        visual[np.arange(n_rows), matches] = 1
        plt.matshow(visual)
        plt.show()

    return matches