From 68b334e78f78b8afbefe0280f291b8a531f8314d Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 19 Jul 2023 18:38:54 +0000 Subject: [PATCH] fix mypy issue --- flax/linen/recurrent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index c4300351da..a00dfdb540 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -818,6 +818,7 @@ def scan_fn( # return_carry is True we slice the carry history and select the last valid # carry for each sequence. Otherwise we just use the last carry. if slice_carry: + assert seq_lengths is not None _, (carries, outputs) = scan_output # seq_lengths[None] expands the shape of the mask to match the # number of dimensions of the carry.