-
Notifications
You must be signed in to change notification settings - Fork 6
/
test.py
186 lines (142 loc) · 6.84 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import datetime
import logging
import os
import pdb
import sys
import numpy as np
import torch
from transformers import BertTokenizer
from sklearn.metrics import classification_report, recall_score, f1_score, precision_score
from torch import nn
from torch.utils.data import DataLoader
from model import (KvretConfig, KvretDataset, MTSIAdapterDataset, MTSIBert,
MTSIKvretConfig, TwoSepTensorBuilder)
def get_eos(turns, win_size, windows_per_dialogue):
    """Build one-hot end-of-session labels for a batch of dialogues.

    For each dialogue, count the turns whose actor flag equals 1 and mark
    column ``count - 1`` of that dialogue's row with 1; all other entries
    are 0. The count restarts for every dialogue in the batch.

    Args:
        turns: sequence of dialogues (``len(turns)`` must work), each an
            iterable of per-turn flags; a flag equal to 1 is counted.
            NOTE(review): presumably 1 marks a user turn — confirm against
            MTSIAdapterDataset.
        win_size: unused; kept only for interface compatibility.
        windows_per_dialogue: number of columns of the label matrix.

    Returns:
        (labels, last_idx): ``labels`` is a ``torch.long`` tensor of shape
        ``(len(turns), windows_per_dialogue)``; ``last_idx`` is the column
        marked for the last dialogue in ``turns`` (-1 if ``turns`` is empty).

    Raises:
        IndexError: if a dialogue has more counted turns than
        ``windows_per_dialogue``.
    """
    res = torch.zeros((len(turns), windows_per_dialogue), dtype=torch.long)
    last_idx = -1
    for idx, dialogue in enumerate(turns):
        # Count per dialogue; a counter shared across the batch would shift
        # every row after the first one.
        user_turns = sum(1 for t in dialogue if t == 1)
        last_idx = user_turns - 1
        # NOTE(review): a dialogue with no counted turns gives last_idx == -1,
        # which marks the LAST column (negative indexing) — as before.
        res[idx][last_idx] = 1
    return res, last_idx
def remove_dataparallel(load_checkpoint_path, map_location=None):
    """Load a state dict and drop the ``module.`` prefix added by DataParallel.

    Args:
        load_checkpoint_path: path of a state dict saved with ``torch.save``.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            a GPU-saved checkpoint on a machine without CUDA. ``None`` (the
            default) keeps ``torch.load``'s default behaviour.

    Returns:
        OrderedDict with the same values. Keys starting with ``module.`` have
        that prefix removed; all other keys are kept unchanged.
    """
    from collections import OrderedDict
    prefix = 'module.'
    # original saved file, possibly from an nn.DataParallel-wrapped model
    state_dict = torch.load(load_checkpoint_path, map_location=map_location)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip only when present: blindly slicing k[7:] mangled the keys of
        # checkpoints saved without DataParallel.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict
def _print_scores(y_true, y_pred, average):
    """Print precision, recall and F1 for one task, one metric per line.

    ``average`` is passed unchanged to scikit-learn's ``precision_score``,
    ``recall_score`` and ``f1_score``.
    """
    scorers = (('precision: ', precision_score),
               ('recall: ', recall_score),
               ('f1: ', f1_score))
    for prefix, scorer in scorers:
        print(prefix + str(scorer(y_true, y_pred, average=average)))


def compute_f1(model, data_generator, device):
    """Run ``model`` over ``data_generator`` and print EOS/action/intent scores.

    All forward passes run under ``torch.no_grad()``. EOS and action scores
    use ``average='macro'``; intent scores use ``average='micro'``.

    NOTE(review): the printed header says "macro scores:" although intent is
    scored with micro averaging — confirm which is intended.
    NOTE(review): ``.item()`` on the action/intent predictions and the use of
    only ``eos_label[0]`` assume each batch holds a single dialogue (batch
    size 1) — confirm.

    Args:
        model: called as ``model(batch, turns, dialogue_ids, tensor_builder,
            device)`` and expected to return three dicts (eos, intent,
            action), each with a ``'prediction'`` tensor whose last axis is
            reduced by argmax to a class id.
        data_generator: iterable of
            ``(batch, turns, intents, actions, dialogue_ids)`` tuples.
        device: torch device the inputs and EOS targets are moved to.
    """
    eos_true, eos_pred = [], []
    action_true, action_pred = [], []
    intent_true, intent_pred = [], []
    builder = TwoSepTensorBuilder()

    with torch.no_grad():
        for batch, turns, intents, actions, dialogue_ids in data_generator:
            # EOS targets: 0 = intra dialogue ; 1 = eos
            eos_target, last_idx = get_eos(
                turns,
                MTSIKvretConfig._WINDOW_SIZE,
                windows_per_dialogue=KvretConfig._KVRET_MAX_USER_SENTENCES_PER_TRAIN_DIALOGUE + 2)

            batch = batch.to(device)
            intents = intents.to(device)
            actions = actions.to(device)
            eos_target = eos_target.to(device)

            eos_out, intent_out, action_out = model(batch, turns, dialogue_ids, builder, device)

            # Keep the first dialogue's EOS targets up to last_idx (inclusive).
            eos_true.extend(eos_target[0][:last_idx + 1].tolist())
            eos_pred.extend(torch.argmax(eos_out['prediction'], dim=-1).tolist())
            action_true.extend(actions.tolist())
            action_pred.append(torch.argmax(action_out['prediction'], dim=-1).item())
            intent_true.extend(intents.tolist())
            intent_pred.append(torch.argmax(intent_out['prediction'], dim=-1).item())

    print('macro scores:')
    print('--EOS score:')
    _print_scores(eos_true, eos_pred, 'macro')
    print('--Action score:')
    _print_scores(action_true, action_pred, 'macro')
    print('--Intent score:')
    _print_scores(intent_true, intent_pred, 'micro')
def _load_checkpoint(load_checkpoint_path, device):
    """Load a saved state dict onto ``device`` for the bare (unwrapped) model.

    ``map_location=device`` lets a GPU-saved checkpoint load on a CPU-only
    machine. If every key starts with ``module.`` (the prefix an
    ``nn.DataParallel`` wrapper contributes), that prefix is removed so the
    keys match the unwrapped model.
    """
    from collections import OrderedDict
    prefix = 'module.'
    state_dict = torch.load(load_checkpoint_path, map_location=device)
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = OrderedDict(
            (key[len(prefix):], value) for key, value in state_dict.items())
    return state_dict


def _build_generator(data_path, tokenizer, params):
    """Return a DataLoader over the KVRET split stored at ``data_path``."""
    dataset = KvretDataset(data_path)
    dataset.remove_subsequent_actor_utterances()
    # pass max_len + 1 (to pad of 1 also the longest sentence, a sort of EOS)
    # + 1 (random last sentence from other)
    adapter = MTSIAdapterDataset(dataset,
                                 tokenizer,
                                 KvretConfig._KVRET_MAX_BERT_TOKENS_PER_TRAIN_SENTENCE + 1,
                                 KvretConfig._KVRET_MAX_BERT_SENTENCES_PER_TRAIN_DIALOGUE + 2)
    return DataLoader(adapter, **params)


def test(load_checkpoint_path):
    """
    Test utility: evaluate a saved MTSIBert checkpoint on the KVRET test and
    validation splits, printing EOS/action/intent scores for each.

    Args:
        load_checkpoint_path: path of the state dict to evaluate. Checkpoints
            saved with or without ``nn.DataParallel`` are accepted, and a
            GPU-saved checkpoint can be loaded when CUDA is unavailable.
    """
    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print('active device = '+str(device))

    # Bert adapter for dataset
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case = False)

    # Model preparation
    model = MTSIBert(num_layers_encoder = MTSIKvretConfig._ENCODER_LAYERS_NUM,
                     num_layers_eos = MTSIKvretConfig._EOS_LAYERS_NUM,
                     n_intents = MTSIKvretConfig._N_INTENTS,
                     batch_size = MTSIKvretConfig._BATCH_SIZE,
                     pretrained = 'bert-base-cased',
                     seed = MTSIKvretConfig._SEED,
                     window_size = MTSIKvretConfig._WINDOW_SIZE)

    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print('active devices = '+str(n_gpus))
    print('model loaded from: '+load_checkpoint_path)
    # Load into the bare model before any DataParallel wrapping, so key names
    # no longer depend on how the checkpoint was saved vs. the current GPU count.
    model.load_state_dict(_load_checkpoint(load_checkpoint_path, device))
    # work on multiple GPUs when available
    if n_gpus > 1:
        model = nn.DataParallel(model)
    model.to(device)
    model.eval()

    # Parameters
    params = {'batch_size': MTSIKvretConfig._BATCH_SIZE,
              'shuffle': False,
              'num_workers': 0}

    # f1-score on test set
    test_generator = _build_generator(KvretConfig._KVRET_TEST_PATH, tokenizer, params)
    print('### TEST SET:')
    compute_f1(model, test_generator, device)

    # f1-score on validation set
    val_generator = _build_generator(KvretConfig._KVRET_VAL_PATH, tokenizer, params)
    print('### VALIDATION SET:')
    compute_f1(model, val_generator, device)
if __name__ == '__main__':
    # Optional CLI override: `python test.py <state_dict.pt>`; without an
    # argument the original hard-coded checkpoint is evaluated.
    _DEFAULT_CHECKPOINT = 'dict_archive/MINI_BATCH16/100epochs/deep/state_dict.pt'
    test(load_checkpoint_path=sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_CHECKPOINT)