-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathAttentionUnit.py
54 lines (46 loc) · 2.42 KB
/
AttentionUnit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import tensorflow as tf
import pickle
class AttentionWrapper(object):
def __init__(self, hidden_size, input_size, hs, scope_name):
self.hs = tf.transpose(hs, [1,0,2])
self.hidden_size = hidden_size
self.input_size = input_size
self.scope_name = scope_name
self.params = {}
with tf.variable_scope(scope_name):
self.Wh = tf.get_variable('Wh', [input_size, hidden_size])
self.bh = tf.get_variable('bh', [hidden_size])
self.Ws = tf.get_variable('Ws', [input_size, hidden_size])
self.bs = tf.get_variable('bs', [hidden_size])
self.Wo = tf.get_variable('Wo', [2*input_size, hidden_size])
self.bo = tf.get_variable('bo', [hidden_size])
self.params.update({'Wh': self.Wh, 'Ws': self.Ws, 'Wo': self.Wo,
'bh': self.bh, 'bs': self.bs, 'bo': self.bo})
hs2d = tf.reshape(self.hs, [-1, input_size])
phi_hs2d = tf.tanh(tf.nn.xw_plus_b(hs2d, self.Wh, self.bh))
self.phi_hs = tf.reshape(phi_hs2d, tf.shape(self.hs))
def __call__(self, x, finished = None):
gamma_h = tf.tanh(tf.nn.xw_plus_b(x, self.Ws, self.bs))
'''
weights = tf.reduce_sum(self.phi_hs * gamma_h, reduction_indices=2, keep_dims=True)
context = tf.mod(tf.reduce_max(weights * 1000000 + self.hs, reduction_indices=0), 1000000)
context = tf.stop_gradient(context)
'''
weights = tf.reduce_sum(self.phi_hs * gamma_h, reduction_indices=2, keep_dims=True)
weights = tf.exp(weights - tf.reduce_max(weights, reduction_indices=0, keep_dims=True))
weights = tf.divide(weights, (1e-6 + tf.reduce_sum(weights, reduction_indices=0, keep_dims=True)))
context = tf.reduce_sum(self.hs * weights, reduction_indices=0)
out = tf.tanh(tf.nn.xw_plus_b(tf.concat([context, x], -1), self.Wo, self.bo))
if finished is not None:
out = tf.where(finished, tf.zeros_like(out), out)
return out
def save(self, path):
param_values = {}
for param in self.params:
param_values[param] = self.params[param].eval()
with open(path, 'wb') as f:
pickle.dump(param_values, f, True)
def load(self, path):
param_values = pickle.load(open(path, 'rb'))
for param in param_values:
self.params[param].load(param_values[param])