From 50f492b226da5c481c0cc030f96f39f6f3132754 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 19 Jul 2023 16:03:17 +0000 Subject: [PATCH 1/2] fix carry slice logic --- flax/linen/recurrent.py | 8 +++++--- tests/linen/linen_recurrent_test.py | 8 ++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 1ec2bd803a..c4300351da 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -786,6 +786,8 @@ def __call__( else: carry = initial_carry + slice_carry = seq_lengths is not None and return_carry + def scan_fn( cell: RNNCellBase, carry: Carry, x: Array ) -> Union[Tuple[Carry, Array], Tuple[Carry, Tuple[Carry, Array]]]: @@ -794,7 +796,7 @@ def scan_fn( # so that we can select the last carry for each sequence later. # This uses more memory but is faster than using jnp.where at each # iteration. As a small optimization do this when we really need it. - if seq_lengths is not None and return_carry: + if slice_carry: return carry, (carry, y) else: return carry, y @@ -802,7 +804,7 @@ def scan_fn( scan = transforms.scan( scan_fn, in_axes=time_axis, - out_axes=time_axis if seq_lengths is None else (0, time_axis), + out_axes=(0, time_axis) if slice_carry else time_axis, unroll=self.unroll, variable_axes=self.variable_axes, variable_broadcast=self.variable_broadcast, @@ -815,7 +817,7 @@ def scan_fn( # Next we select the final carry. If a segmentation mask was provided and # return_carry is True we slice the carry history and select the last valid # carry for each sequence. Otherwise we just use the last carry. - if seq_lengths is not None and return_carry: + if slice_carry: _, (carries, outputs) = scan_output # seq_lengths[None] expands the shape of the mask to match the # number of dimensions of the carry. diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py index d59829b19e..5a09b9cd87 100644 --- a/tests/linen/linen_recurrent_test.py +++ b/tests/linen/linen_recurrent_test.py @@ -369,6 +369,14 @@ def test_flip_sequence_time_major_more_feature_dims(self): np.testing.assert_allclose(flipped[:4, 0], x[:4, 0][::-1]) np.testing.assert_allclose(flipped[:2, 1], x[:2, 1][::-1]) + def test_basic_seq_lengths(self): + + x = jnp.ones((2, 10, 6)) + lstm = nn.RNN(nn.LSTMCell(265)) + variables = lstm.init(jax.random.PRNGKey(0), x) + y = lstm.apply(variables, x, seq_lengths=jnp.array([5, 5])) + + class BidirectionalTest(absltest.TestCase): def test_bidirectional(self): From 68b334e78f78b8afbefe0280f291b8a531f8314d Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 19 Jul 2023 18:38:54 +0000 Subject: [PATCH 2/2] fix mypy issue --- flax/linen/recurrent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index c4300351da..a00dfdb540 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -818,6 +818,7 @@ def scan_fn( # return_carry is True we slice the carry history and select the last valid # carry for each sequence. Otherwise we just use the last carry. if slice_carry: + assert seq_lengths is not None _, (carries, outputs) = scan_output # seq_lengths[None] expands the shape of the mask to match the # number of dimensions of the carry.