forked from wadhwasahil/Relation_Extraction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
117 lines (104 loc) · 5.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import CNN
from text_cnn import TextCNN
import data_helpers
import os
import numpy as np
import time
import tensorflow as tf
import datetime
with tf.Graph().as_default():
    # Wall-clock reference for the elapsed-time report printed when training ends.
    start_time = time.time()
    session_conf = tf.ConfigProto(
        allow_soft_placement=CNN.FLAGS.allow_soft_placement,
        log_device_placement=CNN.FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Per-token feature width is
        # embedding_size * window_size + 2 * distance_dim
        # (presumably a window of word embeddings plus two distance
        # embeddings -- confirm against the feature builder in CNN).
        cnn = TextCNN(
            filter_sizes=list(map(int, CNN.FLAGS.filter_sizes.split(","))),
            num_filters=CNN.FLAGS.num_filters,
            vec_shape=(CNN.FLAGS.sequence_length,
                       CNN.FLAGS.embedding_size * CNN.FLAGS.window_size
                       + 2 * CNN.FLAGS.distance_dim),
            l2_reg_lambda=CNN.FLAGS.l2_reg_lambda)

        # Define Training procedure: Adam with a fixed 1e-3 learning rate.
        # global_step is incremented by apply_gradients on every update.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Keep track of gradient values and sparsity (optional).
        # Variables that do not contribute to the loss have a None gradient
        # and are skipped.
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is None:
                continue
            grad_summaries.append(
                tf.summary.histogram("{}/grad/hist".format(v.name), g))
            grad_summaries.append(
                tf.summary.scalar("{}/grad/sparsity".format(v.name),
                                  tf.nn.zero_fraction(g)))
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # Output directory for models and summaries (one folder per run,
        # named after the integer Unix timestamp).
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "data", timestamp))
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries (loss, accuracy and the gradient statistics)
        train_summary_op = tf.summary.merge(
            [loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries (loss and accuracy only)
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already
        # exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables())

        # Initialize all variables
        sess.run(tf.global_variables_initializer())
def train_step(x_text_train, y_batch):
    """
    Runs one optimisation step on a single training batch.

    Feeds the batch with the configured dropout keep probability, applies
    the gradient update, prints the step's loss and accuracy, and records
    the training summaries for that step.

    x_text_train: batch of inputs fed to cnn.input_x
    y_batch: batch of labels fed to cnn.input_y
    Returns the loss value computed for this batch.
    """
    feed_dict = {
        cnn.input_x: x_text_train,
        cnn.input_y: y_batch,
        cnn.dropout_keep_prob: CNN.FLAGS.dropout_keep_prob
    }
    # Only fetch what is actually used below; the raw scores are not needed
    # here, so they are not copied out of the session.
    _, step, summaries, loss, accuracy = sess.run(
        [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
        feed_dict)
    time_str = datetime.datetime.now().isoformat()
    print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
    train_summary_writer.add_summary(summaries, step)
    return loss
def dev_step(x_text_dev, y_batch, writer=None):
    """
    Evaluates model on a dev set

    Runs a forward pass with dropout disabled (keep probability 1.0), so no
    training update is applied, and prints the resulting loss and accuracy.

    x_text_dev: inputs fed to cnn.input_x
    y_batch: labels fed to cnn.input_y
    writer: optional summary writer; when given, the dev summaries for the
        current global step are added to it
    Returns the loss value on the supplied data.
    """
    feed_dict = {
        cnn.input_x: x_text_dev,
        cnn.input_y: y_batch,
        cnn.dropout_keep_prob: 1.0
    }
    step, summaries, loss, accuracy = sess.run(
        [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
        feed_dict)
    time_str = datetime.datetime.now().isoformat()
    print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
    # Record the summaries before returning; callers that pass no writer
    # get the same printed output and return value as before.
    if writer is not None:
        writer.add_summary(summaries, step)
    return loss
# Main training loop: one optimisation step per batch, with periodic
# evaluation on the validation data and periodic checkpointing.
batch_iter = CNN.get_batches()
X_val, Y_val = CNN.get_validation_data()
# The validation data is the same for every evaluation; convert it once
# instead of on each evaluation pass.
X_val, Y_val = np.asarray(X_val), np.asarray(Y_val)
try:
    for batch in batch_iter:
        # Each batch is an iterable of (input, label) pairs.
        X_train, y_train = zip(*batch)
        X_train, Y_train = np.asarray(X_train), np.asarray(y_train)
        train_loss = train_step(X_train, Y_train)
        current_step = tf.train.global_step(sess, global_step)
        if current_step % CNN.FLAGS.evaluate_every == 0:
            print("Evaluation:")
            test_loss = dev_step(X_val, Y_val, writer=dev_summary_writer)
            # Early stopping: give up once the validation loss and the loss
            # of the latest training batch differ by more than the threshold.
            if abs(test_loss - train_loss) > CNN.FLAGS.early_threshold:
                print("Early stopping at step {}".format(current_step))
                # Leave the loop rather than killing the interpreter so the
                # summary files are flushed and the timing footer is printed.
                break
            print("")
        if current_step % CNN.FLAGS.checkpoint_every == 0:
            path = saver.save(sess, checkpoint_prefix, global_step=current_step)
            print("Saved model checkpoint to {}\n".format(path))
    print("-------------------")
    print("Finished in time %0.3f" % (time.time() - start_time))
finally:
    # Flush pending summary events and release the session even when
    # training is interrupted or raises.
    train_summary_writer.close()
    dev_summary_writer.close()
    sess.close()