From ef82cb21aeed3363330b0c5454722664fc3222a4 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 19 Jul 2024 00:24:25 +0000 Subject: [PATCH] fix basic scan bug with attrs --- jax/_src/lax/control_flow/loops.py | 3 ++- tests/attrs_test.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index aa707386c5db..a2a6d71f55d6 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -268,7 +268,8 @@ def _create_jaxpr(init): if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves] + _, carry_avals_out, _ = split_list( + jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves]) return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 5c834f314270..4378a3c7526d 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -344,6 +344,21 @@ def jitted(): jax.jit(jitted)() # don't crash + def test_scan_carry(self): + class A: + ... + + a = A() + + jax_setattr(a, 'x', jnp.zeros(3)) + + def body(i, _): + x = jax_getattr(a, 'x') + x = x.at[i].set(x[i] + 1) + jax_setattr(a, 'x', x) + return i + 1, None + _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash + class AttrsJVPTest(jtu.JaxTestCase):