forked from bzhangGo/sltunet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
func.py
293 lines (230 loc) · 9.64 KB
/
func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
from utils import util, dtype
def linear(x, dim, bias=True, ln=False,
           weight_initializer=None,
           bias_initializer=tf.zeros_initializer(),
           scope=None):
    """
    Affine projection of one or several inputs onto one or several sizes.

    :param x: input tensor, or a list/tuple of input tensors
    :param dim: output dimension, or a list/tuple of output dimensions
    :param bias: whether a bias vector is added to every output
    :param ln: whether each per-input projection is layer-normalized before
        the projections are summed; when disabled, all inputs are first
        concatenated along their last axis and projected jointly
    :param weight_initializer: initializer forwarded to the weight variables
    :param bias_initializer: initializer forwarded to the bias variables
    :param scope: variable scope name, "linear" when omitted
    :return: a single tensor when exactly one output dimension is requested,
        otherwise a list holding one tensor per requested dimension
    """
    with tf.variable_scope(scope or "linear", values=[x],
                           dtype=tf.as_dtype(dtype.floatx())):
        inputs = x if isinstance(x, (list, tuple)) else [x]
        out_dims = dim if isinstance(dim, (list, tuple)) else [dim]

        if not ln:
            # without layer normalization the inputs are simply concatenated
            inputs = [tf.concat(inputs, -1)]

        projected = []
        for out_idx, out_size in enumerate(out_dims):
            partials = []
            for in_idx, tensor in enumerate(inputs):
                in_size = util.shape_list(tensor)[-1]
                weight = tf.get_variable(
                    "W_{}_{}".format(out_idx, in_idx),
                    [in_size, out_size],
                    initializer=weight_initializer)

                # collapse all leading axes so a plain 2-D matmul can be used
                flat = tf.reshape(tensor, [-1, in_size])
                part = tf.matmul(flat, weight)
                if ln:
                    part = layer_norm(
                        part, scope="ln_{}_{}".format(out_idx, in_idx))
                partials.append(part)

            merged = tf.add_n(partials)
            if bias:
                bias_var = tf.get_variable(
                    "b_{}".format(out_idx), [out_size],
                    initializer=bias_initializer)
                merged = tf.nn.bias_add(merged, bias_var)

            # restore the leading axes of the first input
            lead_shape = util.shape_list(inputs[0])[:-1]
            merged = tf.reshape(merged, tf.concat([lead_shape, [out_size]], 0))
            projected.append(merged)

        if len(projected) == 1:
            return projected[0]
        return projected
def split_heads(inputs, num_heads, name=None):
    """ Break the channel axis into separate attention heads
    :param inputs: A tensor with shape [batch, length, channels]
    :param num_heads: An integer, the number of heads
    :param name: An optional name scope, "split_heads" when omitted
    :returns: A tensor with shape [batch, heads, length, channels / heads]
    """
    with tf.name_scope(name or "split_heads"):
        static_dims = inputs.get_shape().dims
        depth = static_dims[-1]
        per_head = depth // num_heads if depth else None
        target_static = static_dims[:-1] + [num_heads] + [per_head]

        # reshape with the dynamic shape, then re-attach the static shape
        dynamic_shape = tf.concat([tf.shape(inputs)[:-1], [num_heads, -1]], 0)
        reshaped = tf.reshape(inputs, dynamic_shape)
        reshaped.set_shape(target_static)

        # move the head axis in front of the length axis
        return tf.transpose(reshaped, [0, 2, 1, 3])
def combine_heads(inputs, name=None):
    """ Merge the attention heads back into a single channel axis
    :param inputs: A tensor with shape [batch, heads, length, channels]
    :param name: An optional name scope, "combine_heads" when omitted
    :returns: A tensor with shape [batch, length, heads * channels]
    """
    with tf.name_scope(name or "combine_heads"):
        # put the head axis next to the channel axis before merging them
        swapped = tf.transpose(inputs, [0, 2, 1, 3])

        static_dims = swapped.get_shape().dims
        heads, depth = static_dims[-2:]
        merged_depth = heads * depth if heads and depth else None

        dynamic_shape = tf.concat([tf.shape(swapped)[:-2], [-1]], 0)
        merged = tf.reshape(swapped, dynamic_shape)
        merged.set_shape(static_dims[:-2] + [merged_depth])
        return merged
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  out_map=True, scope=None):
    """
    Scaled dot-product attention with optional multiple heads
    :param query: [batch_size, query_len, dim]
    :param memory: [batch_size, seq_len, mem_dim], or None for self-attention
        computed from the queries alone
    :param mem_mask: bias added to the attention logits before the softmax,
        or None to skip masking
    :param hidden_size: attention space dimension
    :param ln: whether the linear mappings use layer normalization
    :param num_heads: attention head number
    :param cache: optional dict for cache-based decoding; in self-attention
        its 'k'/'v' entries are prepended to the new keys/values along axis 1
        and a new dict is returned, in memory attention its 'mk'/'mv' entries
        are reused when present and otherwise written into the given dict
    :param dropout: attention dropout, forwarded to util.valid_apply_dropout
    :param out_map: whether to apply the additional output mapping
    :param scope: variable scope name, "dot_attention" when omitted
    :return: a dict with 'weights' (attention weights before dropout),
        'output' (the attended values) and 'cache' (the cache, possibly None)
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if memory is None:
            # self-attention: one fused projection, split into q / k / v
            fused = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(fused, 3, -1)

            if cache is not None:
                # extend the previously decoded keys/values with the new step
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {'k': k, 'v': v}
        else:
            q = linear(query, hidden_size, ln=ln, scope="q_map")

            if cache is not None and 'mk' in cache and 'mv' in cache:
                # memory projections were already computed on an earlier call
                k = cache['mk']
                v = cache['mv']
            else:
                k = linear(memory, hidden_size, ln=ln, scope="k_map")
                v = linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        # [bs, len, d] => [bs, h, len, d/h]
        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        q = q * (hidden_size // num_heads) ** (-0.5)

        # q * k => attention weights
        logits = tf.matmul(q, k, transpose_b=True)
        if mem_mask is not None:
            logits = logits + mem_mask
        weights = tf.nn.softmax(logits)

        dropped_weights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        attended = combine_heads(tf.matmul(dropped_weights, v))
        if out_map:
            attended = linear(attended, hidden_size, ln=ln, scope="o_map")

        return {
            'weights': weights,
            'output': attended,
            'cache': cache
        }
def layer_norm(x, eps=None, scope=None):
    """Layer normalization over the last axis with learned scale and offset

    :param x: input tensor, normalized along its last axis
    :param eps: stabilizer added to the variance, dtype.epsilon() when omitted
    :param scope: variable scope name, "layer_norm" when omitted
    :return: the normalized tensor
    """
    eps = dtype.epsilon() if eps is None else eps

    with tf.variable_scope(scope or "layer_norm",
                           dtype=tf.as_dtype(dtype.floatx())):
        width = util.shape_list(x)[-1]

        scale = tf.get_variable(
            "scale", [width], initializer=tf.ones_initializer())
        offset = tf.get_variable(
            "offset", [width], initializer=tf.zeros_initializer())

        mean = tf.reduce_mean(x, -1, keep_dims=True)
        variance = tf.reduce_mean((x - mean) ** 2, -1, keep_dims=True)

        # multiply in the same order as scale * (x - mean) * rsqrt(...)
        scaled = scale * (x - mean)
        return scaled * tf.rsqrt(variance + eps) + offset
def rms_norm(x, eps=None, scope=None):
    """Root-mean-square normalization over the last axis with learned scale

    :param x: input tensor, normalized along its last axis
    :param eps: stabilizer added to the mean square, dtype.epsilon() when
        omitted
    :param scope: variable scope name, "rms_norm" when omitted
    :return: the normalized tensor
    """
    eps = dtype.epsilon() if eps is None else eps

    with tf.variable_scope(scope or "rms_norm",
                           dtype=tf.as_dtype(dtype.floatx())):
        width = util.shape_list(x)[-1]
        scale = tf.get_variable(
            "scale", [width], initializer=tf.ones_initializer())

        mean_square = tf.reduce_mean(x ** 2, -1, keep_dims=True)

        # multiply in the same order as scale * x * rsqrt(...)
        scaled = scale * x
        return scaled * tf.rsqrt(mean_square + eps)
def residual_fn(x, y, dropout=None):
    """Residual connection: add y, after dropout, onto the shortcut x

    :param x: shortcut tensor, left untouched
    :param y: branch output, passed through util.valid_apply_dropout
    :param dropout: dropout setting forwarded to util.valid_apply_dropout
    :return: x + dropout(y)
    """
    return x + util.valid_apply_dropout(y, dropout)
def ffn_layer(x, d, d_o, dropout=None, scope=None):
    """Position-wise feed-forward layer: linear -> ReLU -> dropout -> linear

    :param x: input tensor
    :param d: size of the hidden ("enlarge") projection
    :param d_o: size of the output projection
    :param dropout: dropout setting applied to the hidden activation
    :param scope: variable scope name, "ffn_layer" when omitted
    :return: the output projection
    """
    with tf.variable_scope(scope or "ffn_layer",
                           dtype=tf.as_dtype(dtype.floatx())):
        activated = tf.nn.relu(linear(x, d, scope="enlarge"))
        activated = util.valid_apply_dropout(activated, dropout)
        return linear(activated, d_o, scope="output")
def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4,
                      time=None, name=None):
    """Add the sinusoidal Transformer positional embedding to x

    :param x: rank-3 tensor; axis 1 is used as the length and axis 2 as the
        channel count
    :param min_timescale: smallest timescale of the sinusoids
    :param max_timescale: largest timescale of the sinusoids
    :param time: optional decoding position; when given it replaces the
        0..length-1 range and is used as a single position
        (presumably a scalar float tensor -- TODO confirm against callers)
    :param name: optional name scope, "add_timing_signal" when omitted
    :return: x with the timing signal added
    """
    with tf.name_scope(name, default_name="add_timing_signal", values=[x]):
        length = tf.shape(x)[1]
        channels = tf.shape(x)[2]

        if time is None:
            position = dtype.tf_to_float(tf.range(length))
        else:
            # decoding position embedding
            position = tf.expand_dims(time, 0)

        num_timescales = channels // 2
        # Clamp the denominator to at least 1: with a single timescale
        # (channels of 2 or 3) the unclamped value is 0, which makes the
        # increment infinite and the resulting signal NaN.
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            tf.maximum(dtype.tf_to_float(num_timescales) - 1, 1.0)
        )
        inv_timescales = min_timescale * tf.exp(
            dtype.tf_to_float(tf.range(num_timescales)) * -log_timescale_increment
        )

        # outer product: [positions, 1] * [1, timescales]
        scaled_time = (tf.expand_dims(position, 1) *
                       tf.expand_dims(inv_timescales, 0))
        signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
        # an odd channel count leaves one channel without a sinusoid; pad it
        signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
        signal = tf.reshape(signal, [1, length, channels])

        return x + signal
def attention_bias(inputs, mode, inf=None, name=None):
    """ Build a bias/weight tensor used in the attention mechanism

    :param inputs: depends on mode -- a length for "causal" (used as both
        sides of a square matrix), a mask tensor for "masking" and "aan"
    :param mode: one of "causal", "masking" or "aan"
    :param inf: magnitude used for masked-out entries, dtype.inf() when
        omitted
    :param name: optional name scope, "attention_bias" when omitted
    :return: for "causal" a [1, 1, length, length] tensor that is -inf
        scaled above the diagonal; for "masking" (1 - mask) * -inf with two
        axes inserted at position 1; for "aan" the masked softmax weights
    :raises ValueError: if mode is not one of the supported values
    """
    if inf is None:
        inf = dtype.inf()

    with tf.name_scope(name, default_name="attention_bias", values=[inputs]):
        if mode == "causal":
            size = inputs
            # ones on and below the diagonal, zeros above
            allowed = tf.matrix_band_part(tf.ones([size, size]), -1, 0)
            bias = dtype.tf_to_float(- inf * (1.0 - allowed))
            return tf.reshape(bias, [1, 1, size, size])

        if mode == "masking":
            bias = (1.0 - inputs) * - inf
            return tf.expand_dims(tf.expand_dims(bias, 1), 1)

        if mode == "aan":
            size = tf.shape(inputs)[1]
            identity = tf.eye(size)
            # cumulative sum of the identity over rows, with a leading axis
            cumulative = tf.expand_dims(tf.cumsum(identity, axis=0), 0)

            pair_mask = tf.expand_dims(inputs, 1) * tf.expand_dims(inputs, 2)
            pair_mask = pair_mask * dtype.tf_to_float(cumulative)

            averaged = tf.nn.softmax(pair_mask + (1.0 - pair_mask) * - inf)
            # zero out whatever weight is left on masked positions
            return averaged * pair_mask

        raise ValueError("Unknown mode %s" % mode)