import os

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer

# KERAS_ATTENTION_DEBUG: if set to 1, switches to debug mode.
# In debug mode, the class Attention is no longer a Keras layer.
# In practice this means we can access the internal values of each tensor.
# Without debug mode, Keras treats the object as a layer and we can only
# get the final output.
debug_flag = int(os.environ.get('KERAS_ATTENTION_DEBUG', 0))
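# Example (illustrative, not part of the original file): the flag is read once at
# import time, so it must be set before this module is imported, either in the
# shell (`export KERAS_ATTENTION_DEBUG=1`) or in Python:
#
#   import os
#   os.environ['KERAS_ATTENTION_DEBUG'] = '1'
#   from attention import Attention  # Attention is now a plain object, not a Layer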
class Attention(object if debug_flag else Layer):

    def __init__(self, units=128, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.units = units

    # noinspection PyAttributeOutsideInit
    def build(self, input_shape):
        input_dim = int(input_shape[-1])
        with K.name_scope(self.name if not debug_flag else 'attention'):
            self.attention_score_vec = Dense(input_dim, use_bias=False, name='attention_score_vec')
            self.h_t = Lambda(lambda x: x[:, -1, :], output_shape=(input_dim,), name='last_hidden_state')
            self.attention_score = Dot(axes=[1, 2], name='attention_score')
            self.attention_weight = Activation('softmax', name='attention_weight')
            self.context_vector = Dot(axes=[1, 1], name='context_vector')
            self.attention_output = Concatenate(name='attention_output')
            self.attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
        if not debug_flag:
            # In debug mode, the call to build() is done in call() instead.
            super(Attention, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.units

    def __call__(self, inputs, training=None, **kwargs):
        if debug_flag:
            return self.call(inputs, training, **kwargs)
        else:
            return super(Attention, self).__call__(inputs, training, **kwargs)

    # noinspection PyUnusedLocal
    def call(self, inputs, training=None, **kwargs):
        """
        Many-to-one attention mechanism for Keras.
        @param inputs: 3D tensor with shape (batch_size, time_steps, input_dim).
        @param training: not used in this layer.
        @return: 2D tensor with shape (batch_size, units).
        @author: felixhao28, philipperemy.
        """
        if debug_flag:
            self.build(inputs.shape)
        # Inside dense layer:
        # hidden_states dot W => score_first_part
        # (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
        # W is the trainable weight matrix of Luong's multiplicative-style attention score.
        score_first_part = self.attention_score_vec(inputs)
        # score_first_part dot last_hidden_state => attention_weights
        # (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps)
        h_t = self.h_t(inputs)
        score = self.attention_score([h_t, score_first_part])
        attention_weights = self.attention_weight(score)
        # (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
        context_vector = self.context_vector([inputs, attention_weights])
        pre_activation = self.attention_output([context_vector, h_t])
        attention_vector = self.attention_vector(pre_activation)
        return attention_vector

    def get_config(self):
        """
        Returns the config of the layer. This is used for saving and loading from a model.
        :return: python dictionary with specs to rebuild the layer
        """
        config = super(Attention, self).get_config()
        config.update({'units': self.units})
        return config
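# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example, under assumed shapes and layer sizes, of wiring the layer
# into a many-to-one model: an LSTM returns its full sequence of hidden states
# (return_sequences=True), Attention collapses it into a (batch_size, units)
# vector, and a Dense head produces the prediction. The random training data
# only exercises the layer; it is not meaningful.
if __name__ == '__main__':
    import numpy as np
    from tensorflow.keras.layers import Input, LSTM
    from tensorflow.keras.models import Model

    time_steps, input_dim = 10, 4
    model_input = Input(shape=(time_steps, input_dim))
    x = LSTM(64, return_sequences=True)(model_input)  # (batch_size, time_steps, 64)
    x = Attention(units=32)(x)                         # (batch_size, 32)
    model_output = Dense(1)(x)
    model = Model(model_input, model_output)
    model.compile(loss='mse', optimizer='adam')

    x_train = np.random.uniform(size=(256, time_steps, input_dim)).astype('float32')
    y_train = np.random.uniform(size=(256, 1)).astype('float32')
    model.fit(x_train, y_train, epochs=1, batch_size=32)
    # With get_config() returning the config, the model can also be saved and
    # reloaded via load_model(..., custom_objects={'Attention': Attention}).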