-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
59 lines (45 loc) · 2.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python
import os
import logging
import threading
import tensorflow as tf
import custom_init_ops
from evaluation import *
from util import *
from model import *
from config import *
class SupertaggerTrainer(object):
    """Drives the supertagger training loop and logs per-epoch loss summaries."""

    def __init__(self, logdir):
        # logdir: directory where summaries are written; the writer flushes
        # to disk every 20 seconds.
        self.logdir = logdir
        self.writer = tf.train.SummaryWriter(logdir, flush_secs=20)

    def train(self, config, data, params):
        """Train until the evaluator signals a stop.

        config -- model configuration passed through to SupertaggerModel
        data   -- dataset wrapper; provides train/tritrain/dev sentences,
                  batch_size, and the training-queue population routine
        params -- pretrained parameters assigned into the model after
                  variable initialization
        """
        with tf.Session() as session, Timer("Training") as timer:
            # Build the training graph, then a dev-eval graph that reuses
            # the same variables (TF 0.x variable_scope sharing).
            with tf.variable_scope("model", initializer=custom_init_ops.dyer_initializer()):
                train_model = SupertaggerModel(config, data, is_training=True)
            with tff.variable_scope("model", reuse=True) if False else tf.variable_scope("model", reuse=True):
                dev_model = SupertaggerModel(config, data, is_training=False)
            session.run(tf.initialize_all_variables())
            with tf.variable_scope("model", reuse=True):
                params.assign_pretrained(session)

            # Feed the training input queue from a background thread.
            # Daemonize it so a blocked enqueue cannot keep the process
            # alive after training finishes (the thread is never joined).
            population_thread = threading.Thread(
                target=data.populate_train_queue, args=(session, train_model))
            population_thread.daemon = True
            population_thread.start()

            evaluator = SupertaggerEvaluator(session, data.dev_data, dev_model,
                                             train_model.global_step,
                                             self.writer, self.logdir)

            # Loop-invariant batches-per-epoch, hoisted out of the loop.
            # Floor division (//) keeps the comparison integral under
            # Python 3, where the original / would produce a float and
            # shift the epoch boundary by one step.
            steps_per_epoch = (len(data.train_sentences) +
                               len(data.tritrain_sentences)) // data.batch_size

            i = 0
            epoch = 0
            train_loss = 0.0
            # Evaluator tells us if we should stop.
            while evaluator.maybe_evaluate():
                i += 1
                _, loss = session.run([train_model.optimize,
                                       train_model.loss])
                train_loss += loss
                if i % 100 == 0:
                    timer.tick("{} training steps".format(i))
                if i >= steps_per_epoch:
                    train_loss = train_loss / i
                    logging.info("Epoch {} complete(steps={}, loss={:.3f}).".format(epoch, i, train_loss))
                    self.writer.add_summary(
                        tf.Summary(value=[tf.Summary.Value(tag="Train Loss",
                                                           simple_value=train_loss)]),
                        tf.train.global_step(session, train_model.global_step))
                    i = 0
                    epoch += 1
                    train_loss = 0.0