# run_reinforce.py
# import score
#
# scorer = score.Scorer()
# print 'scorer loaded'
from models.processor import Processor
from models.leap import LEAPModel
from exp.coverage import config_sutter as config
from utils.data import dump
# Replace the imported config module with the concrete config object it builds.
config = config.get_config()
# Directory prefix for model files; named `build_dir` so the builtin `dir()`
# is not shadowed at module level.
build_dir = 'build/'
config.saved_model_file = build_dir + 'rf_sutter_new_sorted_seq2seq.model'
# Echo only the file name (last path component) of the configured model file.
print(config.saved_model_file.split('/')[-1])
p = Processor(config)
model = LEAPModel(p, config)
# model.do_train()
class Scorer(object):
    """Scores pairs of collections by their Jaccard similarity."""

    def jaccard(self, s0, s1):
        """Return the Jaccard similarity of ``s0`` and ``s1``.

        Both arguments are converted to sets first, so duplicate items and
        item order do not affect the result.

        Args:
            s0: Iterable of hashable items.
            s1: Iterable of hashable items.

        Returns:
            ``len(intersection) / len(union)`` as a float, or ``0.0`` when
            the union is empty (i.e. both inputs are empty).
        """
        s0 = set(s0)
        s1 = set(s1)
        union = len(s0 | s1)
        if union == 0:
            # Both inputs are empty; avoid dividing by zero.
            return 0.0
        # float() forces true division even where `/` on ints would truncate.
        return float(len(s0 & s1)) / union

    def predict(self, instances):
        """Return one Jaccard score per instance.

        Args:
            instances: Iterable of indexable items; for each one, element 0
                is scored against element 1.
                NOTE(review): presumably (reference, generated) pairs --
                confirm against the caller.

        Returns:
            A list of float scores, in the same order as ``instances``.
        """
        return [self.jaccard(instance[0], instance[1]) for instance in instances]
# Load parameters from an existing seq2seq model file before reinforcement.
# NOTE(review): presumably this restores previously trained weights -- confirm
# in LEAPModel.load_params.
model.load_params('build/sutter_seq2seq_sorted_seed13_100d_lr0.001_h256.model')
# Run the reinforce step, handing the model a Scorer instance.
# NOTE(review): presumably do_reinforce calls Scorer.predict() to obtain
# rewards -- confirm in LEAPModel.do_reinforce.
model.do_reinforce(Scorer())
# model.do_eval(training = False, filename = 'mimic_seq2seq.h256.txt', max_batch = 5000)
# model.load_params('../models/resume_seed13_100d_lr0.001_h256.model')
# ret = model.do_generate(data)
#
# from utils.eval import Evaluator
# eva = Evaluator()
# cnt = 0
# truth = []
# sum_jaccard = 0
# for line in open("seq2seq.h256.txt"):
# if cnt % 3 == 1:
# truth = set(line.strip().split("T: ")[1].split(" "))
# if cnt % 3 == 2:
# result = set(line.strip().split("Gen: ")[1].replace("END", "").strip().split(" "))
# jaccard = eva.get_jaccard_k(truth, result)
# sum_jaccard += jaccard
# cnt += 1
#
# print(sum_jaccard * 3 / cnt)
#
# truth_list = []
# prediction_list = []
# for line in open("seq2seq.h256.txt"):
# if cnt % 3 == 1:
# truth = set(line.strip().split("T: ")[1].split(" "))
# truth_list.append(truth)
# if cnt % 3 == 2:
# result = set(line.strip().split("Gen: ")[1].replace("END", "").strip().split(" "))
# prediction_list.append(result)
# cnt += 1
#
# cnt = 0
# results = []
# input = []
# truth = []
# for line in open("seq2seq.h256.txt"):
# if cnt % 3 == 0:
# input = set(line.strip().split("S: ")[1].split(" "))
# if cnt % 3 == 1:
# truth = set(line.strip().split("T: ")[1].split(" "))
# if cnt % 3 == 2:
# result = set(line.strip().split("Gen: ")[1].replace("END", "").strip().split(" "))
# results.append((input, truth, result))
# cnt += 1
# dump(results, "sutter_result_seq2seq_1.30.pkl")