custom_jvp crash on my function #7092
In my code I have a function for which I want to define a custom JVP, because it is linear in its 3 arguments.

Edit0: same error without the constant tensor.

Here is my minimal code:

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def f(a, b):
        a = a.reshape(-1, 1)
        b = b.reshape(1, -1)
        return a * b

    f.defjvps(
        lambda t, _, a, b: f(t, b),
        lambda t, _, a, b: f(a, t),
    )

    x = jnp.ones((3,))
    jax.jacobian(f, 0)(x, x)

The error I get is:
Replies: 1 comment 1 reply
I opened a bug to track this in #7098.
Thanks for the report and for the clear repro. I think you're running into the issue mentioned in this comment: https://github.com/google/jax/blob/97a5719fcb40af7231b5f803f965063538282f8e/jax/interpreters/ad.py#L198-L200

This is a real bug and needs to be fixed.

One thing to note: the `jax.jacobian` function is an alias of `jax.jacrev`, meaning that it uses reverse-mode autodiff, while your JVP implements forward-mode autodiff. JAX should be able to construct the backward rule automatically here, except for that bug in the backward pass above.

Until that is fixed, there are a couple possibilities:

- Use `jax.jacfwd` instead of j…
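
As a rough sketch of the forward-mode suggestion, the repro from the question can be rerun with `jax.jacfwd`, which only needs the JVP rules already defined. The `expected` array is not from the thread. It is an assumption added here to check the result against the analytic Jacobian of an outer product, and behavior may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(a, b):
    # Outer product of two vectors; linear in each argument separately.
    a = a.reshape(-1, 1)
    b = b.reshape(1, -1)
    return a * b

# Each rule receives (tangent, primal_out, *primals).
f.defjvps(
    lambda t, _, a, b: f(t, b),
    lambda t, _, a, b: f(a, t),
)

x = jnp.ones((3,))

# Forward mode uses the JVP rules directly, so no transpose is needed.
jac = jax.jacfwd(f, 0)(x, x)

# Illustrative check (an assumption, not part of the original report):
# for an outer product, d(a_i * b_j) / da_k = delta_ik * b_j.
expected = jnp.einsum("ik,j->ijk", jnp.eye(3), x)
```

Here `jac` has the output shape `(3, 3)` followed by the input shape `(3,)`, so it can be compared elementwise with `expected`.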