diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index d27b0efc7e5e..56accc273dbf 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -1451,7 +1451,9 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
     # above and it would be good to consolidate it.
     primal_name = getattr(fun, "__name__", str(fun))
     fwd_name = getattr(fwd, "__name__", str(fwd))
-    args = _resolve_kwargs(fwd, args, kwargs)
+    # Note: we use `fun` instead of `fwd` here for consistency with
+    # custom_vjp.__call__ above.
+    args = _resolve_kwargs(fun, args, kwargs)
     if nondiff_argnums:
       for i in nondiff_argnums: _check_for_tracers(args[i])
       nondiff_argnums_ = set(nondiff_argnums)
diff --git a/tests/api_test.py b/tests/api_test.py
index 15c4c8e7ae4f..fa2f389f494b 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -9762,6 +9762,23 @@ def f_bwd(res, g):
     x, y = 3.2, 1.0
     self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y))
 
+  def test_optimize_remat_kwargs(self):
+    @jax.custom_vjp
+    def f(x, y):
+      return jnp.sin(x) * y
+
+    def f_fwd(x, y, *, keyword=False):
+      del keyword
+      return f(x, y), (jnp.cos(x), jnp.sin(x), y)
+
+    def f_bwd(res, g):
+      cos_x, sin_x, y = res
+      return (cos_x * g * y, sin_x * g)
+
+    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
+    x, y = 3.2, 1.0
+    jax.grad(f)(x, y)  # Doesn't error
+
 
 def transpose_unary(f, x_example):
   def transposed(y):