-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathEWC_train.py
131 lines (109 loc) · 4.6 KB
/
EWC_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from utils.config import *
from models.TRADE import *
from torch import autograd
from copy import deepcopy
import pickle
import os.path
#### LOAD MODEL path
# Recover hyper-parameters from the checkpoint path. The third "/"-separated
# component of args['path'] is parsed as "...HDD<hidden>BSZ<batch>DR..."
# (e.g. something like "save/TRADE-xxx/HDD400BSZ32DR0.2" — TODO confirm the
# exact naming scheme). --batch, when given, overrides the encoded batch size.
except_domain = args['except_domain']
directory = args['path'].split("/")
HDD = directory[2].split('HDD')[1].split('BSZ')[0]
BSZ = int(args['batch']) if args['batch'] else int(directory[2].split('BSZ')[1].split('DR')[0])
args["decoder"] = "TRADE"
args["HDD"] = HDD
if args['dataset'] == 'multiwoz':
    from utils.utils_multiWOZ_DST import *
else:
    # Without a dataset module, `prepare_data_seq` is never defined. Continuing
    # would only fail later with a confusing NameError, so stop here.
    print("You need to provide the --dataset information")
    raise SystemExit(1)
# Fisher information for EWC. It is cached next to the checkpoint as
# "<path>fisher<fisher_sample>". If the cache exists it is loaded. Otherwise it
# is computed from single-example batches, saved, and the script exits, so a
# second run performs the actual fine-tuning.
filename_fisher = args['path'] + "fisher{}".format(args["fisher_sample"])
if os.path.isfile(filename_fisher):
    print("Load Fisher Matrix" + filename_fisher)
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only load Fisher caches produced by this script.
    with open(filename_fisher, 'rb') as f_fisher:
        fisher, optpar = pickle.load(f_fisher)
else:
    train, dev, test, test_special, lang, SLOTS_LIST, gating_dict, max_word = prepare_data_seq(True, args['task'], False, batch_size=1)
    model = globals()[args["decoder"]](
        int(HDD),
        lang=lang,
        path=args['path'],
        task=args["task"],
        lr=args["learn"],
        dropout=args["drop"],
        slots=SLOTS_LIST,
        gating_dict=gating_dict)
    print("Computing Fisher Matrix ")
    fisher = {}
    optpar = {}
    for n, p in model.named_parameters():
        # Snapshot the trained weights (the EWC anchor) and start each Fisher
        # accumulator at zero. The model's own parameters must NOT be zeroed:
        # the gradients below must be taken at the trained weights.
        optpar[n] = p.detach().clone().cuda()
        fisher[n] = torch.zeros_like(p.detach()).cuda()
    # Number of batches to average over; 0/None means "use the whole train set".
    max_samples = int(args["fisher_sample"] or 0)
    n_samples = 0
    pbar = tqdm(enumerate(train), total=len(train))
    for i, data_o in pbar:
        model.train_batch(data_o, int(args['clip']), SLOTS_LIST[1], reset=(i == 0))
        model.loss_ptr_to_bp.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
                # Clear the gradient so the next sample's squared gradient is
                # not mixed with this one.
                p.grad.zero_()
        n_samples += 1
        if max_samples > 0 and n_samples >= max_samples:
            break
    # Empirical Fisher: mean squared gradient over the batches actually used
    # (guarded against an empty train set).
    for name_f in fisher:
        fisher[name_f] /= max(n_samples, 1)
    print("Saving Fisher Matrix in ", filename_fisher)
    with open(filename_fisher, 'wb') as f_fisher:
        pickle.dump([fisher, optpar], f_fisher)
    raise SystemExit(0)
### LOAD DATA
# Loaders for the configuration given on the command line (presumably with
# `except_domain` held out — confirm in prepare_data_seq).
train, dev, test, test_special, lang, SLOTS_LIST, gating_dict, max_word = prepare_data_seq(True, args['task'], False, batch_size=BSZ)
# prepare_data_seq reads its settings from the global `args`. Temporarily
# switch to the single held-out domain, load it, then restore except_domain.
_single_domain_args = {
    'only_domain': except_domain,
    'except_domain': '',
    "fisher_sample": 0,
    "data_ratio": 1,
}
for _arg_name, _arg_value in _single_domain_args.items():
    args[_arg_name] = _arg_value
(train_single, dev_single, test_single, _, _,
 SLOTS_LIST_single, _, _) = prepare_data_seq(True, args['task'], False, batch_size=BSZ)
args['except_domain'] = except_domain
#### LOAD MODEL
# Instantiate the decoder class named by args["decoder"], looked up in this
# module's globals. Presumably it restores its weights from args['path'] —
# confirm in the decoder class.
decoder_cls = globals()[args["decoder"]]
model = decoder_cls(int(HDD), lang=lang, path=args['path'], task=args["task"],
                    lr=args["learn"], dropout=args["drop"],
                    slots=SLOTS_LIST, gating_dict=gating_dict)
# Early-stopping bookkeeping:
#   avg_best - best dev score so far
#   cnt      - evaluations since the last improvement
#   acc      - latest dev score
avg_best = 0.0
cnt = 0
acc = 0.0
# Start from a copy of the initial weights so there is always something to restore.
weights_best = deepcopy(model.state_dict())
# Fine-tune on the single held-out domain for up to 100 epochs. The training
# loss gets an EWC penalty, lambda_ewc * fisher * (p - optpar)^2, that pulls
# each trainable parameter toward its value before fine-tuning.
# Dev is evaluated every `evalp` epochs. Training stops after `patience`
# evaluations without improvement, or on a perfect score when earlyStop is
# unset. Ctrl-C ends training early and falls through to the evaluation code.
patience = 6
# Move the EWC statistics to the GPU once instead of on every batch.
fisher = {n: f.cuda() for n, f in fisher.items()}
optpar = {n: w.cuda() for n, w in optpar.items()}
try:
    for epoch in range(100):
        print("Epoch:{}".format(epoch))
        # Run the train function
        pbar = tqdm(enumerate(train_single), total=len(train_single))
        for i, data in pbar:
            model.train_batch(data, int(args['clip']), SLOTS_LIST_single[1], reset=(i == 0))
            ### EWC loss
            # Gate on requires_grad, not on `p.grad is not None`. Grads are None
            # before the first backward pass. They are also None after every
            # step if grads are cleared with zero_grad(set_to_none=True), the
            # PyTorch 2.x default. Either case would silently drop the penalty.
            for name, p in model.named_parameters():
                if p.requires_grad:
                    penalty = args['lambda_ewc'] * fisher[name] * (p - optpar[name]).pow(2)
                    model.loss_grad += penalty.sum()
            model.optimize(args['clip'])
            pbar.set_description(model.print_loss())
        if (epoch + 1) % int(args['evalp']) == 0:
            acc = model.evaluate(dev_single, avg_best, SLOTS_LIST_single[2], args["earlyStop"])
            model.scheduler.step(acc)
            if acc >= avg_best:
                avg_best = acc
                cnt = 0
                weights_best = deepcopy(model.state_dict())
            else:
                cnt += 1
            if cnt == patience or (acc == 1.0 and args["earlyStop"] is None):
                print("Ran out of patient, early stop...")
                break
except KeyboardInterrupt:
    # Deliberate: a manual interrupt stops training without aborting the run.
    print("Training interrupted; continuing with the best weights so far.")
# Restore weights_best: the best dev-set state seen during fine-tuning, or the
# initial weights if no dev evaluation happened. Then switch to eval mode.
model.load_state_dict(dict(weights_best))
model.eval()
# 1e7 is passed as the "best so far" score, presumably so that test-set
# evaluation never registers as a new best — confirm in model.evaluate.
_test_best_sentinel = 1e7
# After Fine tuning...
print("[Info] After Fine Tune ...")
print("[Info] Test Set on 4 domains...")
acc_test_4d = model.evaluate(test_special, _test_best_sentinel, SLOTS_LIST[2])
print("[Info] Test Set on 1 domain {} ...".format(except_domain))
acc_test = model.evaluate(test_single, _test_best_sentinel, SLOTS_LIST[3])