-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
80 lines (68 loc) · 2.67 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from torch.utils.data import Dataset
import os
import numpy as np
import librosa
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
from torch import dot
from torch.linalg import norm
from sklearn import metrics
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import pandas as pd
import sys
from model import DvectorModel
from test_dataset import TestDataset
def get_eer(test_dataset, test_path, trial_path, device='cuda:1'):
    """Compute the equal error rate (EER) and its threshold for a trial list.

    The file at ``trial_path`` is read as space-separated rows of
    ``<label> <file1> <file2>``; rows whose label is ``1`` are the positive
    class.  For every row, both files are looked up through
    ``test_dataset.get_embedding`` (each name is joined onto ``test_path``)
    and the pair is scored with ``cosine_similarity``.

    Args:
        test_dataset: Object exposing ``get_embedding(path)``.  Element
            ``[0]`` of the returned value is used as the embedding vector.
        test_path: Directory prefix joined onto every file name in the
            trial list.
        trial_path: Path of the trial list file.
        device: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(eer, threshold)``: ``eer`` is the false-positive rate at
        which ``fpr == 1 - tpr`` on the interpolated ROC curve, and
        ``threshold`` is the score threshold interpolated at that rate.
    """
    trials = pd.read_csv(
        trial_path, names=['positive', 'file1', 'file2'], sep=' ')
    labels = trials.positive

    def embedding_of(relative_path):
        # Resolve a trial entry to its embedding vector.
        return test_dataset.get_embedding(
            os.path.join(test_path, relative_path))[0]

    file1_embeddings = [embedding_of(name) for name in trials.file1]
    file2_embeddings = [embedding_of(name) for name in trials.file2]

    # Convert every score to a plain Python float: scikit-learn cannot turn
    # tensors that live on a GPU or that track gradients into NumPy arrays.
    cos_sims = [
        float(cosine_similarity(first, second))
        for first, second in zip(file1_embeddings, file2_embeddings)
    ]

    fpr, tpr, thresholds = metrics.roc_curve(labels, cos_sims, pos_label=1)

    # Build the ROC interpolator once instead of on every brentq evaluation.
    tpr_at = interp1d(fpr, tpr)
    eer = brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)
    threshold = interp1d(fpr, thresholds)(eer)
    return eer, threshold
def cosine_similarity(a, b):
    """Return the cosine of the angle between two vectors.

    The result is ``torch.dot(a, b)`` divided by the product of the two
    vector norms (``torch.linalg.norm``).  No guard is applied for a zero
    norm, so the division is performed as-is in that case.

    Args:
        a: First vector (a tensor accepted by ``torch.dot``).
        b: Second vector, same requirements as ``a``.

    Returns:
        The similarity as a tensor.
    """
    inner_product = dot(a, b)
    magnitude_product = norm(a) * norm(b)
    return inner_product / magnitude_product
# Return the cosine similarity between two speakers' embeddings.
# Not used.
def score(embedding1, embedding2):
    """Score a pair of embeddings by delegating to ``cosine_similarity``.

    Args:
        embedding1: First embedding vector.
        embedding2: Second embedding vector.

    Returns:
        Whatever ``cosine_similarity(embedding1, embedding2)`` returns.
    """
    similarity = cosine_similarity(embedding1, embedding2)
    return similarity
# Compute the embedding of a single audio file.
# Not used.
def enroll(checkpoint_path, audio_file, sample_rate=16000, duration_sec=4):
    """Return the model output for one audio file, evaluated on the CPU.

    The waveform is loaded from ``audio_file``, brought to exactly
    ``sample_rate * duration_sec`` samples (a random crop when it is long
    enough, zero-padding at the end otherwise), and passed through the model
    stored at ``checkpoint_path``.

    Args:
        checkpoint_path: Path of a checkpoint holding a whole pickled model
            object (it is loaded with ``torch.load`` and called directly).
        audio_file: Path of the audio file to embed.
        sample_rate: Rate the audio is resampled to on load.  The default of
            16000 matches the ``16000*4`` length used originally.
            NOTE(review): confirm it matches the preprocessing used when the
            model was trained.
        duration_sec: Length, in seconds, of the segment fed to the model.

    Returns:
        The first element of the tuple returned by the model, for a batch
        containing the single prepared waveform.
    """
    import random  # local import: ``random`` is not imported at module level

    frames = sample_rate * duration_sec
    wav, _ = librosa.load(audio_file, sr=sample_rate)

    if wav.shape[0] >= frames:
        # Long enough: take a random window of exactly ``frames`` samples.
        start = random.randrange(0, wav.shape[0] - frames + 1)
        wav = wav[start:start + frames]
    else:
        # Too short: pad the tail with silence up to ``frames`` samples.
        wav = np.append(wav, np.zeros(frames - wav.shape[0]))

    device = torch.device('cpu')
    # NOTE: torch.load unpickles arbitrary objects; only load trusted
    # checkpoints.  map_location lets a GPU-saved checkpoint load on the CPU.
    model = torch.load(checkpoint_path, map_location=device)
    model.to(device)
    model.eval()

    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        x = torch.FloatTensor(wav).to(device)
        x = torch.unsqueeze(x, 0)  # add the batch dimension
        x, _ = model(x)
    return x
# Script entry point: load a checkpoint, embed the test set, and print the
# equal error rate and its threshold for the trial list.
# Usage: <script> <checkpoint_path> <embedding_size>
if __name__ == '__main__':
    if len(sys.argv) < 3:
        sys.exit('Usage: ' + sys.argv[0] + ' <checkpoint_path> <embedding_size>')

    train_data_path = '/data/train'  # not used by the evaluation below
    test_data_path = '/data/test'
    trial_path = '/data/trials/trials.txt'

    # Prefer the second GPU as before, but fall back to the first GPU and
    # then to the CPU instead of handing ``None`` to torch.device.
    if torch.cuda.device_count() > 1:
        device = torch.device('cuda:1')
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        print('No GPU')
        device = torch.device('cpu')

    checkpoint_path = sys.argv[1]
    embedding_size = int(sys.argv[2])

    # NOTE: torch.load unpickles arbitrary objects; only load trusted
    # checkpoints.  map_location makes the load work on the chosen device
    # even when the checkpoint was saved from a different one.
    model = torch.load(checkpoint_path, map_location=device).to(device)

    test_data = TestDataset(test_data_path)
    test_data.update_embeddings(model, embedding_size, device)

    eer, threshold = get_eer(test_data, test_data_path, trial_path)
    print('EER: ' + str(eer))
    print('Threshold: ' + str(threshold))