rnn_util.py · 164 lines (139 loc) · 6.97 KB
import tensorflow as tf
from tensorflow.python.ops import variable_scope as vs
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple, LayerNormBasicLSTMCell
from tensorflow.python.ops import array_ops


def _like_rnncell(cell):
    """Checks that a given object is an RNNCell by using duck typing."""
    conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
                  hasattr(cell, "zero_state"), callable(cell)]
    return all(conditions)


class StatefulLayerNormBasicLSTMCell(LayerNormBasicLSTMCell):
    """LayerNormBasicLSTMCell that also emits its full state at every step."""

    def __init__(self, *args, **kwargs):
        super(StatefulLayerNormBasicLSTMCell, self).__init__(*args, **kwargs)

    @property
    def output_size(self):
        # Widen output_size to (state_size, output_size) so dynamic_rnn
        # allocates room for the per-step state as well as the output.
        return (self.state_size,
                super(StatefulLayerNormBasicLSTMCell, self).output_size)

    def call(self, inputs, state):
        output, next_state = super(StatefulLayerNormBasicLSTMCell, self).call(
            inputs, state)
        emit_output = (next_state, output)
        return emit_output, next_state


class StatefulLSTMCell(LSTMCell):
    """LSTMCell that also emits its full state at every step."""

    def __init__(self, *args, **kwargs):
        super(StatefulLSTMCell, self).__init__(*args, **kwargs)

    @property
    def output_size(self):
        return (self.state_size, super(StatefulLSTMCell, self).output_size)

    def call(self, inputs, state):
        output, next_state = super(StatefulLSTMCell, self).call(inputs, state)
        emit_output = (next_state, output)
        return emit_output, next_state
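

# Illustration (the sizes and names below are assumptions for this sketch, not
# values from the original module): because output_size is widened to
# (state_size, output_size), tf.nn.dynamic_rnn returns the per-step state
# alongside the usual per-step output when given one of the stateful cells:
#
#   cell = StatefulLSTMCell(32)
#   x = tf.placeholder(tf.float32, [4, 10, 16])   # [batch, time, depth]
#   (step_states, step_outputs), final_state = tf.nn.dynamic_rnn(
#       cell, x, dtype=tf.float32)
#   # step_states.c, step_states.h and step_outputs are all [4, 10, 32].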


def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
    """Bidirectional dynamic RNN for the stateful cells defined above.

    Unlike tf.nn.bidirectional_dynamic_rnn, this expects each cell to emit
    (state, output) at every step and re-packs the reversed backward per-step
    states into an LSTMStateTuple aligned with the forward direction.
    """
    if not _like_rnncell(cell_fw):
        raise TypeError("cell_fw must be an instance of RNNCell")
    if not _like_rnncell(cell_bw):
        raise TypeError("cell_bw must be an instance of RNNCell")

    with vs.variable_scope(scope or "bidirectional_rnn"):
        # Forward direction
        with vs.variable_scope("fw") as fw_scope:
            output_fw, output_state_fw = tf.nn.dynamic_rnn(
                cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
                initial_state=initial_state_fw, dtype=dtype,
                parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                time_major=time_major, scope=fw_scope)

        # Backward direction
        if not time_major:
            time_dim = 1
            batch_dim = 0
        else:
            time_dim = 0
            batch_dim = 1

        def _reverse(input_, seq_lengths, seq_dim, batch_dim):
            if seq_lengths is not None:
                return array_ops.reverse_sequence(
                    input=input_, seq_lengths=seq_lengths,
                    seq_dim=seq_dim, batch_dim=batch_dim)
            else:
                return array_ops.reverse(input_, axis=[seq_dim])

        with vs.variable_scope("bw") as bw_scope:
            inputs_reverse = _reverse(
                inputs, seq_lengths=sequence_length,
                seq_dim=time_dim, batch_dim=batch_dim)
            tmp, output_state_bw = tf.nn.dynamic_rnn(
                cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
                initial_state=initial_state_bw, dtype=dtype,
                parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                time_major=time_major, scope=bw_scope)

    # Reverse the backward results so they line up with the forward outputs.
    output_bw_states_rev, output_bw_output_rev = tmp
    output_bw_output = _reverse(
        output_bw_output_rev, seq_lengths=sequence_length,
        seq_dim=time_dim, batch_dim=batch_dim)
    output_bw_states_c = _reverse(
        output_bw_states_rev.c, seq_lengths=sequence_length,
        seq_dim=time_dim, batch_dim=batch_dim)
    output_bw_states = LSTMStateTuple(h=output_bw_output,
                                      c=output_bw_states_c)
    output_bw = (output_bw_states, output_bw_output)

    # Merge outputs
    outputs = (output_fw, output_bw)
    output_states = (output_state_fw, output_state_bw)

    return (outputs, output_states)


def stack_bidirectional_dynamic_rnn(cells_fw,
                                     cells_bw,
                                     inputs,
                                     initial_states_fw=None,
                                     initial_states_bw=None,
                                     dtype=None,
                                     sequence_length=None,
                                     parallel_iterations=None,
                                     time_major=False,
                                     scope=None):
    """Stacks bidirectional layers; each layer consumes the concatenated
    forward/backward hidden states of the previous layer, and the function
    returns the concatenated per-step hidden (h) and cell (c) states of the
    last layer.
    """
    if not cells_fw:
        raise ValueError("Must specify at least one fw cell for BidirectionalRNN.")
    if not cells_bw:
        raise ValueError("Must specify at least one bw cell for BidirectionalRNN.")
    if not isinstance(cells_fw, list):
        raise ValueError("cells_fw must be a list of RNNCells (one per layer).")
    if not isinstance(cells_bw, list):
        raise ValueError("cells_bw must be a list of RNNCells (one per layer).")
    if len(cells_fw) != len(cells_bw):
        raise ValueError("Forward and Backward cells must have the same depth.")
    if (initial_states_fw is not None and
            (not isinstance(initial_states_fw, list) or
             len(initial_states_fw) != len(cells_fw))):
        raise ValueError(
            "initial_states_fw must be a list of state tensors (one per layer).")
    if (initial_states_bw is not None and
            (not isinstance(initial_states_bw, list) or
             len(initial_states_bw) != len(cells_bw))):
        raise ValueError(
            "initial_states_bw must be a list of state tensors (one per layer).")

    prev_layer_h = inputs
    prev_layer_c = inputs

    with vs.variable_scope(scope or "stack_bidirectional_rnn"):
        for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
            initial_state_fw = None
            initial_state_bw = None
            if initial_states_fw:
                initial_state_fw = initial_states_fw[i]
            if initial_states_bw:
                initial_state_bw = initial_states_bw[i]

            with vs.variable_scope("cell_%d" % i):
                outputs, (state_fw, state_bw) = bidirectional_dynamic_rnn(
                    cell_fw,
                    cell_bw,
                    prev_layer_h,
                    initial_state_fw=initial_state_fw,
                    initial_state_bw=initial_state_bw,
                    sequence_length=sequence_length,
                    parallel_iterations=parallel_iterations,
                    dtype=dtype,
                    time_major=time_major)
                # Concat the per-step fw/bw states to create the next layer's input.
                (output_fw_states, _), (output_bw_states, _) = outputs
                prev_layer_h = array_ops.concat(
                    [output_fw_states.h, output_bw_states.h], 2)
                prev_layer_c = array_ops.concat(
                    [output_fw_states.c, output_bw_states.c], 2)

    return prev_layer_h, prev_layer_c
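

# A minimal usage sketch (not part of the original file): the batch size, time
# steps, input depth, cell size, and variable names below are illustrative
# assumptions. It builds a two-layer stacked bidirectional RNN from the
# stateful cells and checks the returned shapes.
if __name__ == "__main__":
    batch_size, max_time, input_depth, num_units = 4, 10, 16, 32
    demo_inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_depth])
    demo_cells_fw = [StatefulLSTMCell(num_units) for _ in range(2)]
    demo_cells_bw = [StatefulLSTMCell(num_units) for _ in range(2)]
    last_h, last_c = stack_bidirectional_dynamic_rnn(
        demo_cells_fw, demo_cells_bw, demo_inputs, dtype=tf.float32)
    # Each returned tensor concatenates the forward and backward states per
    # time step, so its last dimension is 2 * num_units.
    print(last_h.get_shape())  # (4, 10, 64)
    print(last_c.get_shape())  # (4, 10, 64)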