From b3e3115391a9cf2373cf8f3ed1f68029ce956e60 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sat, 24 Aug 2024 22:39:28 -0700 Subject: [PATCH] improve `scan` error message on non-concrete `unroll` argument --- jax/_src/lax/control_flow/loops.py | 4 ++++ tests/lax_control_flow_test.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d9cf1f89c5c4..f7f09424a9e8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -296,6 +296,10 @@ def _create_jaxpr(init): raise NotImplementedError( f'Effects not supported in `scan`: {disallowed_effects}') + unroll = core.concrete_or_error( + None, unroll, + "The `unroll` argument to `scan` expects a concrete `int` or `bool` " + "value.") if isinstance(unroll, bool): unroll = max(length, 1) if unroll else 1 if unroll < 1: diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 829169b40778..fd83d269b41c 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2958,6 +2958,17 @@ def test_scan_length_concrete_error(self): "The `length` argument to `scan` expects a concrete `int` value.*"): f(3, 1.) + def test_scan_unroll_concrete_error(self): + f = jax.jit(lambda n, x: jax.lax.scan( + lambda c, z: (c, z), x, (), 10, unroll=n)) + + msg = ("The `unroll` argument to `scan` expects a concrete `int` or " + "`bool` value.*") + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(3, 1.) + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(True, 1.) + def test_cond_vmap_forwarding_doesnt_promote(self): def f(x, y): x, y = jax.lax.cond(