Skip to content

Commit

Permalink
improve scan error message on non-concrete unroll argument
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Aug 25, 2024
1 parent e3e0860 commit b3e3115
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ def _create_jaxpr(init):
raise NotImplementedError(
f'Effects not supported in `scan`: {disallowed_effects}')

unroll = core.concrete_or_error(
None, unroll,
"The `unroll` argument to `scan` expects a concrete `int` or `bool` "
"value.")
if isinstance(unroll, bool):
unroll = max(length, 1) if unroll else 1
if unroll < 1:
Expand Down
11 changes: 11 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2958,6 +2958,17 @@ def test_scan_length_concrete_error(self):
"The `length` argument to `scan` expects a concrete `int` value.*"):
f(3, 1.)

def test_scan_unroll_concrete_error(self):
f = jax.jit(lambda n, x: jax.lax.scan(
lambda c, z: (c, z), x, (), 10, unroll=n))

msg = ("The `unroll` argument to `scan` expects a concrete `int` or "
"`bool` value.*")
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(3, 1.)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(True, 1.)

def test_cond_vmap_forwarding_doesnt_promote(self):
def f(x, y):
x, y = jax.lax.cond(
Expand Down

0 comments on commit b3e3115

Please sign in to comment.