diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 39df07359c18..bd7482eb50cf 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -473,7 +473,7 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]: if v in res_vars: if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]): results.append((v.aval, f"named '{eqn.params['name']}' from {src}")) - elif str(eqn.primitive) == 'xla_call': + elif str(eqn.primitive) == 'pjit': results.append((v.aval, f"output of jitted function '{eqn.params['name']}' " f"from {src}")) diff --git a/tests/api_test.py b/tests/api_test.py index d1c77c75eb79..c73d5960f123 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5768,6 +5768,27 @@ def f(x, y): self.assertStartsWith(res[4][1], "named 'z'") self.assertEqual(res[5][0].shape, ()) + def test_saved_residuals_utility_jit(self): + @jax.jit + def f(x, y): + x1, x2 = x + z = checkpoint_name(jnp.sin(3.), 'z') + return z * ((x1 * x2) * y) * np.array([3.]) + + res = saved_residuals(f, (2., 3.), y=4.) + self.assertLen(res, 6) + self.assertEqual(res[0][0].shape, ()) + self.assertEqual(res[0][1], "from the argument x[0]") + self.assertEqual(res[1][0].shape, ()) + self.assertEqual(res[1][1], "from the argument x[1]") + self.assertEqual(res[2][0].shape, ()) + self.assertEqual(res[2][1], "from the argument y") + self.assertEqual(res[3][0].shape, ()) + self.assertStartsWith(res[3][1], "output of jitted function 'f'") + self.assertEqual(res[4][0].shape, ()) + self.assertEqual(res[5][0].shape, (1,)) + self.assertStartsWith(res[5][1], "output of jitted function 'f'") + @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [