Skip to content

Commit

Permalink
Remove unnecessary constraint on keyword-only arguments in `custom_vj…
Browse files Browse the repository at this point in the history
…p` with `optimize_remat=True`.

PiperOrigin-RevId: 660945559
  • Loading branch information
dfm authored and jax authors committed Aug 8, 2024
1 parent 93d4629 commit efb7721
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
4 changes: 3 additions & 1 deletion jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,9 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
# above and it would be good to consolidate it.
primal_name = getattr(fun, "__name__", str(fun))
fwd_name = getattr(fwd, "__name__", str(fwd))
args = _resolve_kwargs(fwd, args, kwargs)
# Note: we use `fun` instead of `fwd` here for consistency with
# custom_vjp.__call__ above.
args = _resolve_kwargs(fun, args, kwargs)
if nondiff_argnums:
for i in nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums_ = set(nondiff_argnums)
Expand Down
17 changes: 17 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9762,6 +9762,23 @@ def f_bwd(res, g):
x, y = 3.2, 1.0
self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y))

def test_optimize_remat_kwargs(self):
@jax.custom_vjp
def f(x, y):
return jnp.sin(x) * y

def f_fwd(x, y, *, keyword=False):
del keyword
return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
cos_x, sin_x, y = res
return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd, optimize_remat=True)
x, y = 3.2, 1.0
jax.grad(f)(x, y) # Doesn't error


def transpose_unary(f, x_example):
def transposed(y):
Expand Down

0 comments on commit efb7721

Please sign in to comment.