Skip to content

Commit

Permalink
Merge pull request #3213 from google:fix-rnn-issue
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 549406182
  • Loading branch information
Flax Authors committed Jul 19, 2023
2 parents 15d6857 + 68b334e commit b234192
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
9 changes: 6 additions & 3 deletions flax/linen/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,8 @@ def __call__(
else:
carry = initial_carry

slice_carry = seq_lengths is not None and return_carry

def scan_fn(
cell: RNNCellBase, carry: Carry, x: Array
) -> Union[Tuple[Carry, Array], Tuple[Carry, Tuple[Carry, Array]]]:
Expand All @@ -794,15 +796,15 @@ def scan_fn(
# so that we can select the last carry for each sequence later.
# This uses more memory but is faster than using jnp.where at each
# iteration. As a small optimization do this when we really need it.
if seq_lengths is not None and return_carry:
if slice_carry:
return carry, (carry, y)
else:
return carry, y

scan = transforms.scan(
scan_fn,
in_axes=time_axis,
out_axes=time_axis if seq_lengths is None else (0, time_axis),
out_axes=(0, time_axis) if slice_carry else time_axis,
unroll=self.unroll,
variable_axes=self.variable_axes,
variable_broadcast=self.variable_broadcast,
Expand All @@ -815,7 +817,8 @@ def scan_fn(
# Next we select the final carry. If a segmentation mask was provided and
# return_carry is True we slice the carry history and select the last valid
# carry for each sequence. Otherwise we just use the last carry.
if seq_lengths is not None and return_carry:
if slice_carry:
assert seq_lengths is not None
_, (carries, outputs) = scan_output
# seq_lengths[None] expands the shape of the mask to match the
# number of dimensions of the carry.
Expand Down
8 changes: 8 additions & 0 deletions tests/linen/linen_recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,14 @@ def test_flip_sequence_time_major_more_feature_dims(self):
np.testing.assert_allclose(flipped[:4, 0], x[:4, 0][::-1])
np.testing.assert_allclose(flipped[:2, 1], x[:2, 1][::-1])

def test_basic_seq_lengths(self):

x = jnp.ones((2, 10, 6))
lstm = nn.RNN(nn.LSTMCell(265))
variables = lstm.init(jax.random.PRNGKey(0), x)
y = lstm.apply(variables, x, seq_lengths=jnp.array([5, 5]))


class BidirectionalTest(absltest.TestCase):

def test_bidirectional(self):
Expand Down

0 comments on commit b234192

Please sign in to comment.