custom_jvp crash on my function #7092

Answered by jakevdp
mariogeiger asked this question in Q&A
Thanks for the report and for the clear repro - I think you're running into the issue mentioned in this comment: https://github.com/google/jax/blob/97a5719fcb40af7231b5f803f965063538282f8e/jax/interpreters/ad.py#L198-L200 This is a real bug and needs to be fixed.

One thing to note: jax.jacobian is an alias for jax.jacrev, meaning that it uses reverse-mode autodiff, while your JVP implements forward-mode autodiff. JAX should be able to construct the backward rule automatically here, and would do so if not for the bug in the backward pass linked above.

Until that is fixed, there are a couple of possibilities:

  • you may be able to limit your use-case to forward mode autodiff; i.e. call jax.jacfwd instead of j…
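
To make the forward-mode option concrete, here is a minimal sketch. The function `f` and its JVP rule are hypothetical stand-ins, not the function from the original report; the point is only that jax.jacfwd calls the custom JVP rule directly and never needs the backward (transpose) pass.

```python
import jax
import jax.numpy as jnp


# Hypothetical example function (NOT the one from the question).
@jax.custom_jvp
def f(x):
    return jnp.sin(x) * x


@f.defjvp
def f_jvp(primals, tangents):
    (x,) = primals
    (t,) = tangents
    primal_out = f(x)
    # d/dx [sin(x) * x] = cos(x) * x + sin(x); the output tangent is linear in t.
    tangent_out = (jnp.cos(x) * x + jnp.sin(x)) * t
    return primal_out, tangent_out


x = jnp.array([0.5, 1.0, 2.0])

# Forward-mode Jacobian: uses the custom JVP rule as written above.
J_fwd = jax.jacfwd(f)(x)

# f acts elementwise, so the Jacobian should be diagonal.
expected = jnp.diag(jnp.cos(x) * x + jnp.sin(x))
print(jnp.allclose(J_fwd, expected, atol=1e-5))
```

Whether this works for your function depends on whether you only need Jacobians or JVPs; anything that requires jax.grad or jax.jacrev still goes through the reverse-mode path.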

Answer selected by mariogeiger