compute_all_scores.py
import os
import torch
import random
import argparse
import numpy as np
from reader import MultiWOZReader
from types import SimpleNamespace
from utils.utils import load_json, save_json, get_or_create_logger
from evaluator import convert_results_format, MultiWozEvaluator
from train_lm import get_config_without_unknown, LMRunner, BertRunner
from lm_dataset import Lm_Reader, Bert_Reader
logger = get_or_create_logger(__name__)
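
# This script loads a JSON file of generated dialogues and computes
# corpus-level metrics over it: average utterance lengths, inform/success
# (plus BLEU in the offline setting) via MultiWozEvaluator, a GPT-2
# language-model score, and a BERT next-sentence-prediction (NSP) score.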

def compute_avg_len(data):
    """Log the max and average token counts of user and system turns."""
    total_usr_turn = 0
    total_sys_turn = 0
    total_usr_tokens = 0
    total_sys_tokens = 0
    max_len = 0
    for dial_id in data:
        for turn in data[dial_id]:
            total_sys_turn += 1
            total_usr_turn += 1
            total_usr_tokens += len(turn['user'].split())
            total_sys_tokens += len(turn['resp_gen'].split())
            max_len = max(max_len, len(turn['user'].split()))
            max_len = max(max_len, len(turn['resp_gen'].split()))
    logger.info(
        'Max len: {}; Avg len: {}; Avg usr len: {}; Avg sys len: {};'.format(
            max_len,
            (total_usr_tokens + total_sys_tokens) / (total_sys_turn + total_usr_turn),
            total_usr_tokens / total_usr_turn,
            total_sys_tokens / total_sys_turn))
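
# A sketch of the expected `data` layout, inferred from the fields accessed
# above (the dialogue id and values are illustrative; other turn-level keys
# may also be present):
# {
#     'pmul1234': [
#         {'user': 'i need a cheap hotel', 'resp_gen': 'there are [value_count] ...'},
#         ...
#     ],
#     ...
# }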

def compute_success_and_inform_rate(args, cfg, data):
    """Compute match (inform), success, and BLEU with the MultiWOZ evaluator."""
    reader = MultiWOZReader(cfg, cfg.version)
    evaluator = MultiWozEvaluator(reader, args.data_type)
    if args.eval_type == 'offline':
        bleu, success, match = evaluator.e2e_eval(
            data, eval_dial_list=None, add_auxiliary_task=cfg.add_auxiliary_task,
            add_success_rate=True)
        # Standard MultiWOZ combined score: 0.5 * (inform + success) + BLEU.
        score = 0.5 * (success + match) + bleu
        logger.info('Offline Evaluation: match: %2.2f; success: %2.2f; bleu: %2.2f; score: %.2f',
                    match, success, bleu, score)
    elif args.eval_type == 'online':
        success, match = evaluator.e2e_eval(
            data, eval_dial_list=None, add_auxiliary_task=cfg.add_auxiliary_task,
            online_eval=True, add_success_rate=True)
        logger.info('Online Evaluation: match: %2.2f; success: %2.2f;', match, success)
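
# Note: the online branch reports no BLEU, presumably because interactively
# generated dialogues have no reference responses to score against.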

def compute_gptscore(args, data, lm_ckpt, agent=None):
    """Score dialogues with a fine-tuned GPT-2 language model.

    When `gpt_score_singe_side` is set, only the given agent's side
    ('sys' or 'usr') is scored; otherwise both sides are scored together.
    """
    cfg = get_config_without_unknown()
    cfg.backbone = 'gpt2'
    cfg.ckpt = lm_ckpt
    cfg.version = args.version
    cfg.ppl_level = 'bart_score'
    cfg.compute_for_single = True
    cfg.task = 'ppl'
    cfg.device = args.device
    setattr(cfg, 'gpt_score_singe_side', args.gpt_score_singe_side)
    setattr(cfg, 'agent', agent)
    reader = Lm_Reader(cfg)
    runner = LMRunner(cfg, reader)
    gptscore = runner.evaluation_for_single(data, cfg.gpt_score_normalize)
    if cfg.gpt_score_singe_side:
        logger.info('Online Evaluation: GPT Score for %s: %f;', cfg.agent, gptscore)
    else:
        logger.info('Online Evaluation: GPT Score: %f;', gptscore)
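
# Usage sketch (mirrors the calls in the main block below; checkpoint paths
# come from the CLI defaults):
# compute_gptscore(args, data, args.lm_ckpt_sys, agent='sys')  # system side only
# compute_gptscore(args, data, args.lm_ckpt)                   # both sides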

def compute_nsp_score(args, data, nsp_ckpt):
    """Score session-level coherence with a BERT next-sentence-prediction model."""
    cfg = get_config_without_unknown()
    cfg.backbone = 'bert-base-uncased'
    cfg.ckpt = nsp_ckpt
    cfg.version = args.version
    cfg.compute_for_single = True
    cfg.task = 'nsp'
    cfg.device = args.device
    reader = Bert_Reader(cfg)
    runner = BertRunner(cfg, reader)
    nspscore = runner.evaluation_for_single(data)
    logger.info('Online Evaluation: NSP Score: %f;', nspscore)

def save_results_with_metrics(args, data):
    # Strip the trailing '.json' and append a suffix, e.g.
    # './outputs/results.json' -> './outputs/results_with_metrics.json'.
    save_path = args.output_result_path[:-5] + '_with_metrics.json'
    save_json(data, save_path)

def str2bool(value):
    # argparse's `type=bool` treats any non-empty string (including 'False')
    # as True, so parse boolean flags explicitly.
    return str(value).lower() in ('true', '1', 'yes')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Argument for evaluation")
    parser.add_argument("-data_type", type=str, default="test", choices=["dev", "test"])
    parser.add_argument("-eval_type", type=str, default='online', choices=['offline', 'online'])
    parser.add_argument("-output_result_path", type=str, required=True)
    parser.add_argument("-config_dir", type=str, required=True)
    parser.add_argument("-use_inform_success", type=str2bool, default=True)
    parser.add_argument("-use_gptscore", type=str2bool, default=True)
    parser.add_argument("-use_nspscore", type=str2bool, default=True)
    parser.add_argument("-seed", type=int, default=42)
    parser.add_argument("-gpt_score_normalize", action='store_true')
    parser.add_argument("-gpt_score_singe_side", action='store_true')
    parser.add_argument("-test_model_name", type=str, default=None)
    parser.add_argument("-lm_ckpt_sys", type=str, default='./bart_score_gpt_lm_model_lr_1e_4_sys_side/ckpt-epoch6',
                        help='sentence score only for system responses')
    parser.add_argument("-lm_ckpt_usr", type=str, default='./bart_score_gpt_lm_model_lr_1e_4_usr_side/ckpt-epoch6',
                        help='sentence score only for user utterances')
    parser.add_argument("-lm_ckpt", type=str, default='./bart_score_gpt_lm_model_lr_1e_4/ckpt-epoch6',
                        help='sentence score model')
    parser.add_argument("-nsp_ckpt", type=str, default='./bert_nsp_model_lr_1e_5_1/ckpt-epoch9',
                        help='session score model')
    args = parser.parse_args()
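
    # Example invocation (paths are illustrative):
    #   python compute_all_scores.py -output_result_path ./outputs/results.json \
    #       -config_dir ./checkpoints/run1 -eval_type online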

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    setattr(args, 'device', device)

    cfg_path = os.path.join(args.config_dir, "run_config.json")
    cfg = SimpleNamespace(**load_json(cfg_path))
    setattr(args, 'version', cfg.version)
    if args.test_model_name is not None:
        cfg.model_name = args.test_model_name

    original_data = load_json(args.output_result_path)
    # Online results are first converted into the dialogue-id -> turn-list
    # format expected by the metric functions above.
    if args.eval_type == 'online':
        data = convert_results_format(original_data)
    else:
        data = original_data

    logger.info('Compute all metrics for {}'.format(args.output_result_path))
    compute_avg_len(data)
    if args.use_inform_success:
        compute_success_and_inform_rate(args, cfg, data)
    if args.use_gptscore:
        if args.gpt_score_singe_side:
            # Score each side separately first ...
            compute_gptscore(args, data, args.lm_ckpt_sys, agent='sys')
            compute_gptscore(args, data, args.lm_ckpt_usr, agent='usr')
        # ... then always compute the combined (both-sides) score.
        args.gpt_score_singe_side = False
        compute_gptscore(args, data, args.lm_ckpt)
    if args.use_nspscore:
        compute_nsp_score(args, data, args.nsp_ckpt)
    save_results_with_metrics(args, data)