-
Notifications
You must be signed in to change notification settings - Fork 4
/
test.py
42 lines (35 loc) · 1004 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
from model import MusicModel
from config import get_config
from dataloader import get_loader
from tqdm import tqdm
def test(model_path, test_path, bs):
    """Evaluate a saved binary classifier on a test set and print its accuracy.

    The checkpoint at ``model_path`` is loaded with ``torch.load``, the
    batches produced by ``get_loader(test_path, bs)`` are fed through the
    model, every output score is thresholded at 0.5 and compared with its
    label (label ``1`` is the positive class).  The result is printed as
    ``ACC:<percent>`` with the percentage rounded to two decimals.

    Args:
        model_path: Path handed to ``torch.load``.
        test_path: Passed through unchanged to ``get_loader``.
        bs: Batch size passed to ``get_loader``.  Any batch size works as
            long as the model emits exactly one score per label.

    Raises:
        ValueError: If a batch yields a different number of scores than
            labels, so predictions cannot be paired with labels.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects -- only load
    # checkpoints from a trusted source.  Recent PyTorch releases default to
    # weights_only=True, which rejects fully pickled modules; confirm which
    # format the training script saves before changing this call.
    loaded = torch.load(model_path)
    if isinstance(loaded, dict):
        # The checkpoint holds a state_dict rather than a whole module, so
        # the weights need a freshly constructed network to live in.
        model = MusicModel()
        model.load_state_dict(loaded)
    else:
        model = loaded

    if isinstance(model, nn.Module):
        # A pickled module keeps whatever train/eval flag it was saved with;
        # dropout / batch-norm must be in inference mode for evaluation.
        model.eval()

    test_loader = get_loader(test_path, bs)

    correct = 0
    total = 0
    desc = ' - (Testing) - '
    # Gradients are never needed here; disabling them saves memory and time.
    with torch.no_grad():
        for data, label in tqdm(test_loader, desc=desc, ncols=80):
            # Flatten so the same code handles one sample or a whole batch.
            scores = model(data).reshape(-1)
            labels = torch.as_tensor(label, device=scores.device).reshape(-1)
            if scores.numel() != labels.numel():
                raise ValueError(
                    'model produced %d scores for %d labels'
                    % (scores.numel(), labels.numel())
                )
            predicted_positive = scores >= 0.5
            # Truncate to integers first, mirroring int(label) semantics.
            actual_positive = labels.to(torch.long) == 1
            correct += int((predicted_positive == actual_positive).sum())
            total += labels.numel()

    if total:
        acc = round(float(correct) / float(total) * 100, 2)
    else:
        # An empty test set has no defined accuracy; report NaN instead of
        # crashing with ZeroDivisionError.
        acc = float('nan')
    print('ACC:' + str(acc))
if __name__ == "__main__":
    # Script entry point: read the model path, test-set path and test batch
    # size from the object returned by get_config(), then run the evaluation.
    settings = get_config()
    saved_model = settings.save_path
    test_data = settings.test_path
    batch_size = settings.test_batch_size
    test(saved_model, test_data, batch_size)