-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmyTrain.py
69 lines (56 loc) · 2.05 KB
/
myTrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from tqdm import tqdm
import torch.nn as nn
from utils.config import *
from utils.utils_multiWOZ_DST import *
from models.TRADE import *
# from utils.utils_multiWOZ_DST_Bert import *
# from models.BERT import *
'''
python myTrain.py -dec= -bsz= -hdd= -dr= -lr=
'''
# Usage sketch above (flag values elided). The flags presumably map to the
# args keys read further down ('decoder', 'batch', 'hidden', 'drop', 'learn');
# confirm against utils/config.py.
import sys

early_stop = args['earlyStop']
if args['dataset'] == 'multiwoz':
    from utils.utils_multiWOZ_DST import *
    # NOTE: this unconditionally overrides the --earlyStop value read above, so
    # for MultiWOZ the flag is effectively ignored and model.evaluate / the
    # early-stop check receive early_stop=None.
    early_stop = None
else:
    # Report the problem on stderr and use sys.exit: the builtin exit() is
    # injected by the `site` module for interactive use and is not guaranteed
    # to exist (e.g. under `python -S` or in frozen executables).
    print("You need to provide the --dataset information", file=sys.stderr)
    sys.exit(1)
# Configure models and load data
# Early-stopping bookkeeping: best dev score seen so far, number of
# evaluations since that score last improved, and the latest dev score.
avg_best = 0.0
cnt = 0
acc = 0.0

# NOTE(review): the positional True / False flags are opaque from here
# (presumably a "training" switch and a dataset-variant switch); confirm
# against prepare_data_seq's signature before changing them.
(train, dev, test, test_special,
 lang, SLOTS_LIST, gating_dict, max_word) = prepare_data_seq(
    True,
    args['task'],
    False,
    batch_size=int(args['batch']),
    use_bert=True,
)

# The model class is chosen by name from --decoder and looked up among this
# module's globals, i.e. it must be a name brought in by the imports above.
decoder_cls = globals()[args['decoder']]
model = decoder_cls(
    hidden_size=int(args['hidden']),
    lang=lang,
    path=args['path'],
    task=args['task'],
    lr=float(args['learn']),
    dropout=float(args['drop']),
    slots=SLOTS_LIST,
    gating_dict=gating_dict,
    nb_train_vocab=max_word,
)
# print("[Info] Slots include ", SLOTS_LIST)
# print("[Info] Unpointable Slots include ", gating_dict)

# Train for up to 200 epochs. Every `evalp` epochs the model is evaluated on
# the dev split; training stops once the dev score has failed to improve for
# `patience` consecutive evaluations, or (when early_stop is None) as soon as
# the dev score reaches exactly 1.0.

# These CLI values are constant for the whole run, so convert them once
# instead of on every epoch / batch.
eval_period = int(args['evalp'])
grad_clip = args['clip']            # passed unchanged to model.optimize
batch_clip = int(grad_clip)         # integer form passed to train_batch
patience = int(args["patience"])

for epoch in range(200):
    print("Epoch:{}".format(epoch))
    # Run the train function
    pbar = tqdm(enumerate(train), total=len(train))
    for i, data in pbar:
        # reset is True only for the first batch of the epoch (presumably
        # resets the model's running loss statistics; TODO confirm).
        model.train_batch(data, batch_clip, SLOTS_LIST[1], reset=(i == 0))
        model.optimize(grad_clip)
        pbar.set_description(model.print_loss())

    if (epoch + 1) % eval_period == 0:
        acc = model.evaluate(dev, avg_best, SLOTS_LIST[2], early_stop)
        # The scheduler is stepped with the dev score (presumably a
        # ReduceLROnPlateau-style scheduler; TODO confirm its mode is 'max').
        model.scheduler.step(acc)

        if acc >= avg_best:
            avg_best = acc
            cnt = 0
            # NOTE: this binds another name to the same object, not a
            # snapshot; later training keeps mutating it.
            best_model = model
        else:
            cnt += 1

        # `>=` (rather than `==`) also terminates sensibly for a
        # non-positive patience value.
        if cnt >= patience or (acc == 1.0 and early_stop is None):
            print("Ran out of patient, early stop...")
            break