This repository has been archived by the owner on Nov 15, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 38
/
common.py
96 lines (81 loc) · 3.11 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import time
import numpy as np
import tensorflow as tf
def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    """Build a device function that puts variables on ``ps_dev`` and all other ops on a GPU.

    The returned callable is suitable for ``tf.device(...)``.

    Args:
        gpu: Index of the GPU that non-variable ops are assigned to.
        ps_dev: Device string that variable ops are assigned to.

    Returns:
        A function mapping an op (a ``tf.NodeDef`` or an object exposing
        ``.node_def``, e.g. ``tf.Operation``) to a device string.
    """
    # Op types that own variable storage. "Variable" is the legacy op type;
    # tf.Variable creates "VariableV2" ops, and resource variables create
    # "VarHandleOp" ops. Matching only "Variable" would leave modern
    # variables on the GPU instead of ps_dev.
    variable_op_types = ("Variable", "VariableV2", "VarHandleOp")

    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        if node_def.op in variable_op_types:
            return ps_dev
        return "/gpu:%d" % gpu

    return _assign
def find_trainable_variables(key):
    """Return the trainable variables whose names contain ``key``.

    ``key`` is embedded unescaped into the pattern ``.*<key>.*`` that is
    passed as the ``scope`` filter of ``tf.get_collection``, so any regex
    metacharacters in ``key`` are interpreted as regex syntax.
    """
    collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
    scope_pattern = ".*{}.*".format(key)
    return tf.get_collection(collection_key, scope_pattern)
def load_from_checkpoint(saver, logdir):
    """Restore the checkpoint recorded in ``logdir`` into the default session.

    Args:
        saver: Object whose ``restore(sess, path)`` method performs the restore.
        logdir: Directory handed to ``tf.train.get_checkpoint_state``.

    Returns:
        True if a checkpoint path was found and ``saver.restore`` was called,
        False if no checkpoint state (or no checkpoint path) exists.
    """
    session = tf.get_default_session()
    state = tf.train.get_checkpoint_state(logdir)
    if not (state and state.model_checkpoint_path):
        return False

    checkpoint_path = state.model_checkpoint_path
    # A relative path recorded in the checkpoint state is resolved against
    # logdir; an absolute path is used unchanged.
    if not os.path.isabs(checkpoint_path):
        checkpoint_path = os.path.join(logdir, checkpoint_path)
    saver.restore(session, checkpoint_path)
    return True
class CheckpointLoader(object):
    """Poll a log directory until a checkpoint newer than the last one loaded is restored.

    Each successful ``load_checkpoint`` call restores variables through
    ``load_from_checkpoint`` and records the restored global step. The next
    call then waits for a checkpoint with a strictly larger step.
    """

    def __init__(self, saver, global_step, logdir, poll_secs=60):
        """
        Args:
            saver: Saver passed to ``load_from_checkpoint``.
            global_step: Tensor evaluated after each restore to read the
                checkpoint's step.
            logdir: Directory searched for checkpoints.
            poll_secs: Seconds to sleep between polls when no checkpoint, or
                no newer one, is available. Defaults to 60, the previously
                hard-coded value.
        """
        self.saver = saver
        self.global_step_tensor = global_step
        self.logdir = logdir
        self.poll_secs = poll_secs
        # TODO(rafal): make it restart-proof?
        # Only steps strictly greater than this are accepted, so a checkpoint
        # saved at step 0 is never returned by load_checkpoint().
        self.last_global_step = 0

    def load_checkpoint(self):
        """Block until a checkpoint newer than the last loaded one is restored.

        Loops indefinitely, sleeping ``poll_secs`` between attempts. Note that
        a stale checkpoint is still restored before its step is compared, so
        every poll may re-restore the same checkpoint.

        Returns:
            True once a newer checkpoint has been restored (never False).
        """
        while True:
            if not load_from_checkpoint(self.saver, self.logdir):
                print("No checkpoint file found. Waiting...")
                time.sleep(self.poll_secs)
                continue
            global_step = int(self.global_step_tensor.eval())
            if global_step > self.last_global_step:
                break
            # Restored checkpoint is not newer than the last one; wait.
            time.sleep(self.poll_secs)
        print("Succesfully loaded model at step=%s." % global_step)
        self.last_global_step = global_step
        return True
def average_grads(tower_grads):
    """Average per-tower gradients of variables shared across towers.

    Args:
        tower_grads: One list of ``(gradient, variable)`` pairs per tower.
            Pairs are matched positionally across towers via ``zip``, so
            presumably every tower lists the same variables in the same
            order -- confirm at call sites.

    Returns:
        A list of ``(averaged_gradient, variable)`` pairs, using the first
        tower's variable. The gradient is None when the first tower's
        gradient for that position is None. ``tf.IndexedSlices`` gradients
        are averaged sparsely; all others are averaged as dense values.
    """

    def average_dense(grad_and_vars):
        # Elementwise mean of the towers' dense gradients.
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]
        grad = grad_and_vars[0][0]
        for g, _ in grad_and_vars[1:]:
            # Plain "+" rather than "+=" so the first tower's gradient object
            # is never modified in place (matters for mutable arrays).
            grad = grad + g
        return grad / len(grad_and_vars)

    def average_sparse(grad_and_vars):
        # Concatenate the towers' slices. Duplicate indices in an
        # IndexedSlices are summed when it is applied, so the values are
        # scaled by 1/num_towers to produce a mean, matching average_dense.
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]
        indices = tf.concat([g.indices for g, _ in grad_and_vars], 0)
        values = tf.concat([g.values for g, _ in grad_and_vars], 0)
        values = values / len(grad_and_vars)
        return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape)

    averaged = []
    for grad_and_vars in zip(*tower_grads):
        first_grad = grad_and_vars[0][0]
        if first_grad is None:
            grad = None
        elif isinstance(first_grad, tf.IndexedSlices):
            grad = average_sparse(grad_and_vars)
        else:
            grad = average_dense(grad_and_vars)
        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        averaged.append((grad, v))
    return averaged