att_self.py
import tensorflow as tf

from arenets.attention import common
from arenets.attention.architectures.self_p_zhou import self_attention_by_peng_zhou, \
    calculate_sequential_attentive_weights_by_peng_zhou
from arenets.multi.architectures.base.base_single_mlp import BaseMultiInstanceSingleMLP


class AttSelfOverSentences(BaseMultiInstanceSingleMLP):
    """ Utilizes a sequence-based self-attention architecture
        over the per-sentence (context) embeddings of a bag.
    """

    def __init__(self, context_network):
        super(AttSelfOverSentences, self).__init__(context_network)
        self.__att_alphas = None

    @property
    def EmbeddingSize(self):
        return self.ContextNetwork.ContextEmbeddingSize

    def init_multiinstance_embedding(self, context_outputs):
        """ input:
                context_outputs: Tensor of shape [batches, sentences, embedding]
            output:
                Tensor of shape [batches, embedding]
        """
        # Group the flattened context outputs back into bags of sentences.
        context_outputs = tf.reshape(context_outputs, [self.Config.BatchSize,
                                                       self.Config.BagSize,
                                                       self.ContextNetwork.ContextEmbeddingSize])

        with tf.variable_scope("mi_{}".format(common.ATTENTION_SCOPE_NAME)):
            # Attention weights over the sentences of every bag.
            self.__att_alphas = self_attention_by_peng_zhou(context_outputs)

            # (B, T, D) -> (B, D)
            att_output = calculate_sequential_attentive_weights_by_peng_zhou(inputs=context_outputs,
                                                                             alphas=self.__att_alphas)

        return att_output

    # region 'iter' methods

    def iter_input_dependent_hidden_parameters(self):
        for name, value in super(AttSelfOverSentences, self).iter_input_dependent_hidden_parameters():
            yield name, value

        yield common.ATTENTION_WEIGHTS_LOG_PARAMETER, self.__att_alphas

    # endregion
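

# For reference, the (B, T, D) -> (B, D) attentive pooling performed above by
# self_attention_by_peng_zhou / calculate_sequential_attentive_weights_by_peng_zhou
# can be illustrated with the standalone sketch below. This is an assumption-based
# approximation of Peng Zhou-style additive self-attention, not the library's
# actual implementation; the name `_self_attention_pooling_sketch` and the
# scoring weight `att_w` are hypothetical.
def _self_attention_pooling_sketch(inputs, hidden_size):
    """ Scores each sentence embedding, softmax-normalizes the scores and
        returns the attention-weighted sum over the sentence axis:
        (B, T, D) -> (B, D), plus the (B, T) attention weights.
    """
    # Trainable scoring vector (hypothetical; for illustration only).
    w = tf.Variable(tf.random.normal([hidden_size]), name="att_w")
    # (B, T, D) . (D,) -> (B, T): one scalar score per sentence.
    scores = tf.tensordot(tf.tanh(inputs), w, axes=1)
    # Normalize scores into attention weights over the sentences of each bag.
    alphas = tf.nn.softmax(scores, axis=-1)
    # Attention-weighted sum over the sentence axis -> (B, D).
    pooled = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), axis=1)
    return pooled, alphas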