Skip to content

Commit

Permalink
Fix printing of saved_residual for jit by looking for pjit as the…
Browse files Browse the repository at this point in the history
… primitive instead of `xla_call` which was removed 2 years ago

PiperOrigin-RevId: 678479141
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Sep 25, 2024
1 parent cfb4e85 commit 1fe0c5d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
if v in res_vars:
if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]):
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
elif str(eqn.primitive) == 'xla_call':
elif str(eqn.primitive) == 'pjit':
results.append((v.aval,
f"output of jitted function '{eqn.params['name']}' "
f"from {src}"))
Expand Down
21 changes: 21 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5768,6 +5768,27 @@ def f(x, y):
self.assertStartsWith(res[4][1], "named 'z'")
self.assertEqual(res[5][0].shape, ())

def test_saved_residuals_utility_jit(self):
@jax.jit
def f(x, y):
x1, x2 = x
z = checkpoint_name(jnp.sin(3.), 'z')
return z * ((x1 * x2) * y) * np.array([3.])

res = saved_residuals(f, (2., 3.), y=4.)
self.assertLen(res, 6)
self.assertEqual(res[0][0].shape, ())
self.assertEqual(res[0][1], "from the argument x[0]")
self.assertEqual(res[1][0].shape, ())
self.assertEqual(res[1][1], "from the argument x[1]")
self.assertEqual(res[2][0].shape, ())
self.assertEqual(res[2][1], "from the argument y")
self.assertEqual(res[3][0].shape, ())
self.assertStartsWith(res[3][1], "output of jitted function 'f'")
self.assertEqual(res[4][0].shape, ())
self.assertEqual(res[5][0].shape, (1,))
self.assertStartsWith(res[5][1], "output of jitted function 'f'")

@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
Expand Down

0 comments on commit 1fe0c5d

Please sign in to comment.