model_utils.py
import math
import tensorflow as tf


def variable_summaries(var, groupname, name):
    """Attach a set of summaries (mean, stddev, min, max, histogram) to a Tensor.

    Note that computing these summaries is fairly expensive.
    """
    with tf.device("/cpu:0"), tf.name_scope(None):
        s_var = tf.cast(var, tf.float32)
        amean = tf.reduce_mean(tf.abs(s_var))
        tf.summary.scalar(groupname + '/amean/' + name, amean)
        mean = tf.reduce_mean(s_var)
        tf.summary.scalar(groupname + '/mean/' + name, mean)
        # Use reduce_mean (not reduce_sum) so this is the actual standard deviation.
        stddev = tf.sqrt(tf.reduce_mean(tf.square(s_var - mean)))
        tf.summary.scalar(groupname + '/stddev/' + name, stddev)
        tf.summary.scalar(groupname + '/max/' + name, tf.reduce_max(s_var))
        tf.summary.scalar(groupname + '/min/' + name, tf.reduce_min(s_var))
        tf.summary.histogram(groupname + "/" + name, s_var)
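
# Usage sketch (illustrative only; the tensor and group names below are made up,
# not taken from this repository): call variable_summaries() while building the
# graph, then merge and write the summaries from the training loop.
#
#   w = tf.get_variable("softmax_W", [512, 10000])
#   variable_summaries(w, groupname="softmax", name="W")
#   merged = tf.summary.merge_all()
#   writer = tf.summary.FileWriter("/tmp/logdir")
#   # inside a tf.Session(): writer.add_summary(sess.run(merged), global_step)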


def getdtype(hps, is_rnn=False):
    """Return the dtype to use, based on the float16 flags in the hyperparameters."""
    if is_rnn:
        return tf.float16 if hps.float16_rnn else tf.float32
    else:
        return tf.float16 if hps.float16_non_rnn else tf.float32


def linear(x, size, name):
    """Affine transform x @ W + b, with W and b created under the given name."""
    w = tf.get_variable(name + "/W", [x.get_shape()[-1], size])
    b = tf.get_variable(name + "/b", [1, size], initializer=tf.zeros_initializer)
    return tf.matmul(x, w) + b
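
# Usage sketch (illustrative; the shapes and scope name are assumptions, not from
# this repository): project a batch of 128 feature vectors from 256 to 64 units.
#
#   x = tf.placeholder(tf.float32, [128, 256])
#   y = linear(x, 64, name="proj")   # y has shape [128, 64]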


def sharded_variable(name, shape, num_shards, dtype=tf.float32, transposed=False):
    # The final size of the sharded variable may be larger than requested.
    # This should be fine for embeddings.
    shard_size = int((shape[0] + num_shards - 1) / num_shards)
    # Note: both branches currently pick the same initializer; `transposed` only
    # documents the caller's intent.
    if transposed:
        initializer = tf.uniform_unit_scaling_initializer(dtype=dtype)
    else:
        initializer = tf.uniform_unit_scaling_initializer(dtype=dtype)
    return [tf.get_variable(name + "_" + str(i), [shard_size, shape[1]],
                            initializer=initializer, dtype=dtype)
            for i in range(num_shards)]
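
# Worked example (illustrative values, not from this repository): with
# shape=[10, 4] and num_shards=3, shard_size = ceil(10 / 3) = 4, so three
# [4, 4] shards are created and the concatenated size (12 rows) exceeds the
# requested 10 rows, as the comment above notes.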


# XXX(rafal): Code below copied from rnn_cell.py
def _get_sharded_variable(name, shape, dtype, num_shards):
    """Get a list of sharded variables with the given dtype."""
    if num_shards > shape[0]:
        raise ValueError("Too many shards: shape=%s, num_shards=%d" %
                         (shape, num_shards))
    unit_shard_size = int(math.floor(shape[0] / num_shards))
    remaining_rows = shape[0] - unit_shard_size * num_shards
    shards = []
    for i in range(num_shards):
        current_size = unit_shard_size
        if i < remaining_rows:
            current_size += 1
        shards.append(tf.get_variable(name + "_%d" % i,
                                      [current_size] + shape[1:], dtype=dtype))
    return shards
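
# Worked example (illustrative values, not from this repository): unlike
# sharded_variable() above, this splits the rows exactly: with shape=[10, 4]
# and num_shards=3 the shards get 4, 3 and 3 rows (the remainder goes to the
# first shards), and concatenating them recovers the full [10, 4] matrix.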


def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor."""
    _sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(_sharded_variable) == 1:
        return _sharded_variable[0]
    return tf.concat(_sharded_variable, 0)


class FLSTMCell(tf.contrib.rnn.RNNCell):
    """LSTM cell whose input-to-gates matrix is factorized into two smaller matrices.

    Instead of a single [input_size + num_proj, 4 * num_units] matrix W, the cell
    uses W1 [input_size + num_proj, factor_size] and W2 [factor_size, 4 * num_units],
    which reduces the parameter count when factor_size is small. If factor_size is
    None, the standard (unfactorized) LSTM matrix is used instead.
    """

    def __init__(self, num_units, input_size, initializer=None,
                 num_proj=None, num_shards=1, factor_size=None,
                 fnon_linearity=None, dtype=tf.float32):
        self._num_units = num_units
        self._initializer = initializer
        self._num_proj = num_proj
        self._num_unit_shards = num_shards
        self._num_proj_shards = num_shards
        self._forget_bias = 1.0
        if factor_size:
            self._factor_size = int(factor_size)
        else:
            self._factor_size = None
        self._fnon_linearity = fnon_linearity

        if num_proj:
            self._state_size = num_units + num_proj
            self._output_size = num_proj
        else:
            self._state_size = 2 * num_units
            self._output_size = num_units

        with tf.variable_scope("LSTMCell"):
            if self._factor_size:
                # Factorized parameterization: W is replaced by W1 @ W2.
                self._concat_w1 = _get_concat_variable(
                    "W1", [input_size + num_proj, self._factor_size],
                    dtype, self._num_unit_shards)
                self._concat_w2 = _get_concat_variable(
                    "W2", [self._factor_size, 4 * self._num_units],
                    dtype, self._num_unit_shards)
                if self._fnon_linearity:
                    # Extra bias for the intermediate non-linearity between W1 and W2.
                    self._b1 = tf.get_variable(name="b1", shape=[self._factor_size])
            else:
                self._concat_w = _get_concat_variable(
                    "W", [input_size + num_proj, 4 * self._num_units],
                    dtype, self._num_unit_shards)

            # Gate bias, shared by the factorized and unfactorized paths.
            self._b = tf.get_variable(
                "B", shape=[4 * self._num_units])

            self._concat_w_proj = _get_concat_variable(
                "W_P", [self._num_units, self._num_proj],
                dtype, self._num_proj_shards)

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def __call__(self, inputs, state, scope=None):
        num_proj = self._num_units if self._num_proj is None else self._num_proj

        c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
        m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])

        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
        with tf.variable_scope(type(self).__name__,
                               initializer=self._initializer):  # "LSTMCell"
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            cell_inputs = tf.concat([inputs, m_prev], 1)
            if self._factor_size:
                if self._fnon_linearity:
                    lstm_matrix = tf.nn.bias_add(tf.matmul(
                        self._fnon_linearity(tf.nn.bias_add(
                            tf.matmul(cell_inputs, self._concat_w1), self._b1)),
                        self._concat_w2), self._b)
                else:
                    lstm_matrix = tf.nn.bias_add(
                        tf.matmul(tf.matmul(cell_inputs, self._concat_w1),
                                  self._concat_w2),
                        self._b)
            else:
                lstm_matrix = tf.matmul(cell_inputs, self._concat_w) + self._b
            i, j, f, o = tf.split(lstm_matrix, 4, 1)

            c = tf.sigmoid(f + 1.0) * c_prev + tf.sigmoid(i) * tf.tanh(j)
            m = tf.sigmoid(o) * tf.tanh(c)

            if self._num_proj is not None:
                m = tf.matmul(m, self._concat_w_proj)

        new_state = tf.concat([c, m], 1)
        return m, new_state
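

# Minimal usage sketch (an assumption for illustration, not part of the original
# module): build a single FLSTM step in TF 1.x graph mode. All sizes below are
# made up; the real training code configures the cell from its hyperparameters.
if __name__ == "__main__":
    batch_size, emb_size = 8, 32
    cell = FLSTMCell(num_units=64, input_size=emb_size, num_proj=16,
                     factor_size=24, num_shards=1)
    x = tf.placeholder(tf.float32, [batch_size, emb_size])
    state = tf.zeros([batch_size, cell.state_size])
    output, new_state = cell(x, state)
    print(output.get_shape(), new_state.get_shape())  # (8, 16) (8, 80)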