-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_eval_sutter_sorted.py
66 lines (61 loc) · 2.02 KB
/
run_eval_sutter_sorted.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from models.processor import Processor
from models.leap import LEAPModel
from exp.coverage import config_mimic as config
from utils.data import dump
# Build the run configuration, the data processor and the LEAP model, then
# restore previously saved weights and write the model's generations to a
# text file. Training (model.do_train()) and reinforcement
# (model.do_reinforce(scorer)) are intentionally disabled in this script.
config = config.get_config()
# Echo only the last path component of the configured model file.
print(config.saved_model_file.rpartition('/')[2])
p = Processor(config)
model = LEAPModel(p, config)
# NOTE(review): the name printed above comes from config.saved_model_file, but
# the weights are loaded from this hard-coded path; confirm the two agree.
model.load_params(
    'build/sutter_seq2seq_sorted_seed13_100d_lr0.001_h256.model'
)
# The output file name below is re-read by the parsing code further down.
model.do_eval(
    training=False,
    filename='sutter_sorted_seq2seq.h256.txt',
    max_batch=5000,
)
# model.load_params('../models/resume_seed13_100d_lr0.001_h256.model')
# ret = model.do_generate(data)
#
# from utils.eval import Evaluator
# eva = Evaluator()
# cnt = 0
# truth = []
# sum_jaccard = 0
# for line in open("seq2seq.h256.txt"):
# if cnt % 3 == 1:
# truth = set(line.strip().split("T: ")[1].split(" "))
# if cnt % 3 == 2:
# result = set(line.strip().split("Gen: ")[1].replace("END", "").strip().split(" "))
# jaccard = eva.get_jaccard_k(truth, result)
# sum_jaccard += jaccard
# cnt += 1
#
# print(sum_jaccard * 3 / cnt)
#
# cnt = 0
# truth_list = []
# prediction_list = []
# for line in open("seq2seq.h256.txt"):
# if cnt % 3 == 1:
# truth = set(line.strip().split("T: ")[1].split(" "))
# truth_list.append(truth)
# if cnt % 3 == 2:
# result = set(line.strip().split("Gen: ")[1].replace("END", "").strip().split(" "))
# prediction_list.append(result)
# cnt += 1
#
# Parse the generations written by do_eval above into (input, truth, result)
# token-set tuples and pickle them for later scoring.
EVAL_OUTPUT_FILE = "sutter_sorted_seq2seq.h256.txt"
RESULT_FILE = "sutter_sorted_result_seq2seq.pkl"


def _payload(line, tag, lineno):
    """Return the stripped text that follows ``tag`` on ``line``.

    An empty string is returned when nothing follows the tag (e.g. a bare
    "T:" or "Gen:" line), instead of raising IndexError.

    Raises:
        ValueError: if ``tag`` does not occur on the line at all. That means
            the file is not in the expected three-line S/T/Gen layout, and
            positional parsing would silently pair the wrong lines.
    """
    _, sep, tail = line.strip().partition(tag)
    if not sep:
        raise ValueError(
            "line %d: expected a %r line, got %r"
            % (lineno, tag, line.rstrip("\n"))
        )
    return tail.strip()


def _tokens(text):
    """Split ``text`` on single spaces into a set, dropping empty tokens.

    Dropping empties keeps an empty field (e.g. a generation that was only
    "END") from becoming the bogus prediction {""}.
    """
    return {tok for tok in text.split(" ") if tok}


def parse_eval_output(lines):
    """Group ``lines`` into three-line records and return their token sets.

    Lines are grouped positionally. Line 3k must carry "S:" (input tokens),
    line 3k+1 "T:" (ground-truth tokens) and line 3k+2 "Gen:" (generated
    tokens). Every "END" substring is removed from the generated text.
    Records whose truth set is empty are skipped.

    Args:
        lines: an iterable of text lines, e.g. an open file.

    Returns:
        A list of ``(input, truth, result)`` tuples, each element a set of str.

    Raises:
        ValueError: if a line lacks the tag expected at its position.
    """
    records = []
    source, truth = set(), set()
    for lineno, line in enumerate(lines, 1):
        # Position is derived from enumerate so every line, including an
        # empty truth line, advances it. The previous code skipped its
        # counter increment on an empty "T:" line, which desynchronized the
        # layout and paired the next record's truth with a stale input.
        position = (lineno - 1) % 3
        if position == 0:
            source = _tokens(_payload(line, "S:", lineno))
        elif position == 1:
            truth = _tokens(_payload(line, "T:", lineno))
        else:
            generated = _payload(line, "Gen:", lineno).replace("END", "")
            if truth:
                records.append((source, truth, _tokens(generated)))
    return records


with open(EVAL_OUTPUT_FILE) as eval_file:
    results = parse_eval_output(eval_file)
dump(results, RESULT_FILE)