-
Notifications
You must be signed in to change notification settings - Fork 5
/
run.py
67 lines (60 loc) · 1.96 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Train the LEAP model on the MIMIC coverage configuration.

Running ``python run.py`` builds the config from ``exp.coverage.config_mimic``,
prints the basename of the configured ``saved_model_file``, wraps the config in
a ``Processor`` and trains a ``LEAPModel``.

The helpers below replace an older commented-out recipe for scoring a
generation dump: a text file made of consecutive three-line records::

    S: <input tokens>
    T: <ground-truth tokens>
    Gen: <generated tokens> END
"""
import os

# Prefixes of the three lines that make up one record in a generation dump.
_SOURCE_PREFIX = "S: "
_TRUTH_PREFIX = "T: "
_GEN_PREFIX = "Gen: "


def _payload(line, prefix, lineno):
    """Return the text after the first *prefix* on the stripped *line*.

    Raises ValueError naming the 1-based *lineno* when *prefix* is missing, so
    a misaligned file fails loudly instead of silently pairing the wrong lines.
    """
    _, found, rest = line.strip().partition(prefix)
    if not found:
        raise ValueError(
            f"line {lineno}: expected {prefix.strip()!r} record, got {line!r}"
        )
    return rest


def parse_generation_lines(lines):
    """Parse S:/T:/Gen: triples into ``(input, truth, generated)`` token sets.

    Args:
        lines: any iterable of strings, e.g. an open text file.

    Returns:
        A list with one ``(input, truth, generated)`` tuple of ``set[str]`` per
        complete record. Tokens are split on single spaces; every occurrence
        of ``END`` is removed from the generated line before splitting. A
        trailing incomplete record (fewer than three lines) is ignored.

    Raises:
        ValueError: if a line does not carry the prefix expected at its
            position within the record.
    """
    records = []
    line_iter = iter(lines)
    # zip() over the same iterator three times yields consecutive line triples
    # and drops an incomplete final group.
    for index, (src, tgt, gen) in enumerate(zip(line_iter, line_iter, line_iter)):
        first = 3 * index + 1  # 1-based line number of this record's S: line
        source = set(_payload(src, _SOURCE_PREFIX, first).split(" "))
        truth = set(_payload(tgt, _TRUTH_PREFIX, first + 1).split(" "))
        generated = set(
            _payload(gen, _GEN_PREFIX, first + 2).replace("END", "").strip().split(" ")
        )
        records.append((source, truth, generated))
    return records


def read_generation_file(path):
    """Read the generation dump at *path*; see ``parse_generation_lines``."""
    with open(path, encoding="utf-8") as handle:
        return parse_generation_lines(handle)


def mean_jaccard(path):
    """Return the mean ``Evaluator.get_jaccard_k(truth, generated)`` over *path*.

    Raises:
        ValueError: if *path* holds no complete record (the old recipe divided
            by zero here).
    """
    records = read_generation_file(path)
    if not records:
        raise ValueError(f"{path}: no complete S:/T:/Gen: records")

    from utils.eval import Evaluator  # project-local; only needed here

    evaluator = Evaluator()
    total = sum(
        evaluator.get_jaccard_k(truth, generated) for _, truth, generated in records
    )
    return total / len(records)


def dump_generation_results(path, out_path):
    """Parse *path* and persist the records to *out_path* via ``utils.data.dump``."""
    from utils.data import dump  # project-local; only needed here

    records = read_generation_file(path)
    dump(records, out_path)
    return records


def main():
    """Build the MIMIC config, then train a LEAP model on it."""
    # Project imports live here so importing this module (e.g. to reuse the
    # parsing helpers) does not pull in the model stack; order kept as before.
    from models.processor import Processor
    from models.leap import LEAPModel
    from exp.coverage import config_mimic

    cfg = config_mimic.get_config()
    print(os.path.basename(cfg.saved_model_file))
    processor = Processor(cfg)
    model = LEAPModel(processor, cfg)
    model.do_train()
    # NOTE(review): alternative entry points that used to sit here as
    # commented-out code; calls copied verbatim, not exercised - verify first:
    #   model.load_params('build/mimic_seq2seq__seed13_100d_lr0.001_h256.model_1.31')
    #   model.do_reinforce(scorer)  # with scorer = score.Scorer()
    #   model.do_eval(training=False, filename='mimic_seq2seq.h256.txt', max_batch=5000)
    #   model.do_generate(data)
    return model


if __name__ == "__main__":
    main()