-
Notifications
You must be signed in to change notification settings - Fork 2
/
training_graph.py
69 lines (50 loc) · 1.91 KB
/
training_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from enum import Enum
import tensorflow as tf
# NOTE(review): prints the TF version as an import-time side effect —
# useful for debugging TF1-vs-TF2 issues, but consider removing or
# guarding it so importing this module stays silent.
print(tf.__version__)
class RunMode(Enum):
    """Phase of an experiment a graph is assembled for.

    Used by TrainingGraph to decide whether a train op is built:
    only ``training`` gets one; the other modes run forward-only.
    """

    training = 1    # gradients are computed and applied
    validation = 2  # loss is evaluated, parameters stay fixed
    prediction = 3  # forward pass only
class TrainingGraph:
    """Wires a composition model into a trainable TF1 graph.

    Builds the loss (summed cosine distance between the model's normalized
    output and the looked-up "original" embeddings), an optional weighted
    regularization term, and -- in training mode only -- an Adagrad train op.
    Outside training, ``train_op`` is ``tf.no_op()`` so callers can run it
    unconditionally.
    """

    def __init__(self, composition_model, batch_size, learning_rate, run_mode, alpha=0.0):
        """
        :param composition_model: model exposing ``lookup`` (embedding table),
            ``_architecture_normalized`` (prediction tensor), ``_is_training``
            (placeholder/flag) and ``regularization()`` (penalty tensor)
        :param batch_size: examples per step; fixes the id-placeholder shape
        :param learning_rate: Adagrad learning rate (also logged to tensorboard)
        :param run_mode: a RunMode member; a real train op is built only for
            RunMode.training
        :param alpha: weight of the regularization term; 0.0 disables it
        """
        self._model = composition_model
        self._is_training = self._model._is_training
        # ids of the gold words whose embeddings the model should reconstruct
        self._original_vector = tf.placeholder(dtype=tf.int64, shape=[batch_size])
        original_embeddings = tf.nn.embedding_lookup(
            params=self._model.lookup, ids=self._original_vector)
        self._predictions = self._model._architecture_normalized
        # FIX: the `architecture` property read `self._architecture`, which was
        # never assigned, so any access raised AttributeError. Assign it here —
        # presumably the normalized architecture tensor is intended (the same
        # tensor as `predictions`); confirm against callers.
        self._architecture = self._model._architecture_normalized
        # _loss and _reg_loss start as the SAME tensor; _reg_loss is rebound
        # below when a regularization penalty is added.
        self._loss = self._reg_loss = tf.losses.cosine_distance(
            labels=original_embeddings,
            predictions=self._predictions,
            axis=1,
            reduction=tf.losses.Reduction.SUM)
        self._train_op = tf.no_op()  # no parameter updates outside training
        if run_mode is RunMode.training:
            if alpha > 0.0:
                # add the model's own regularization penalty, weighted by alpha
                self._reg_loss += alpha * composition_model.regularization()
            # When alpha == 0.0, _reg_loss still aliases _loss, so minimizing
            # _reg_loss covers both branches of the original if/else.
            self._train_op = tf.train.AdagradOptimizer(
                learning_rate=learning_rate).minimize(self._reg_loss)
        # summaries for tensorboard
        # FIX: summary names may not contain spaces; TF rewrote "learning rate"
        # to "learning_rate" with a warning — use the legal name directly.
        tf.summary.scalar("learning_rate", learning_rate)
        tf.summary.scalar("loss", self._loss)

    @property
    def is_training(self):
        # The model's training flag/placeholder, captured at construction.
        return self._is_training

    @property
    def model(self):
        # The wrapped composition model.
        return self._model

    @property
    def original_vector(self):
        # int64 placeholder of shape [batch_size] holding gold word ids.
        return self._original_vector

    @property
    def architecture(self):
        # Normalized architecture tensor (see FIX note in __init__).
        return self._architecture

    @property
    def loss(self):
        # Plain cosine-distance loss, without regularization.
        return self._loss

    @property
    def reg_loss(self):
        # Loss plus alpha-weighted regularization (same tensor as `loss`
        # when alpha == 0 or outside training construction).
        return self._reg_loss

    @property
    def predictions(self):
        # The model's normalized output tensor.
        return self._predictions

    @property
    def train_op(self):
        # Adagrad minimize op in training mode, tf.no_op() otherwise.
        return self._train_op