
Gradients of functions which accept Pytrees with arbitrary leaf types (how to ignore Int)? #14035

Answered by jakevdp
femtomc asked this question in Q&A
Thanks for the question – unfortunately there's not a great solution for this at the moment. We encountered a similar problem in jax.experimental.sparse, in which we want autodiff to operate with respect to the data buffer, but not the indices buffer. Our solution there was to provide sparse-specific autodiff wrappers that know how to ignore the integer components in the sparse pytree: http://go/jax-github/blob/main/jax/experimental/sparse/ad.py
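As a rough illustration of the same idea outside of the sparse module, here is a minimal sketch that differentiates only with respect to the floating-point leaves of a pytree and closes over everything else. The helper name `grad_wrt_float_leaves` and the example `loss` are hypothetical, not part of JAX and not the code in `ad.py`; the sketch assumes a scalar-valued function of a single pytree argument.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def grad_wrt_float_leaves(fun):
    """Hypothetical helper: gradient of `fun` w.r.t. floating-point leaves only."""

    def wrapped(tree):
        leaves, treedef = tree_util.tree_flatten(tree)
        # Mark which leaves are inexact (float/complex) and therefore differentiable.
        is_float = [
            jnp.issubdtype(jnp.result_type(leaf), jnp.inexact) for leaf in leaves
        ]
        float_leaves = [leaf for leaf, f in zip(leaves, is_float) if f]

        def fun_of_floats(float_leaves):
            # Rebuild the full pytree: traced float leaves plus closed-over
            # non-float leaves (e.g. integer indices), in the original order.
            it = iter(float_leaves)
            merged = [next(it) if f else leaf for leaf, f in zip(leaves, is_float)]
            return fun(tree_util.tree_unflatten(treedef, merged))

        float_grads = jax.grad(fun_of_floats)(float_leaves)

        # Put gradients back in the original structure; non-float slots get None.
        it = iter(float_grads)
        out = [next(it) if f else None for f in is_float]
        return tree_util.tree_unflatten(treedef, out)

    return wrapped


# Example pytree with a float data buffer and an integer index buffer.
params = {"w": jnp.array([1.0, 2.0, 3.0]), "idx": jnp.array([0, 2])}


def loss(p):
    return jnp.sum(p["w"][p["idx"]] ** 2)


grads = grad_wrt_float_leaves(loss)(params)
print(grads)
```

Here `grads["w"]` holds the gradient for the float buffer and `grads["idx"]` is `None`. Separately, `jax.grad` accepts `allow_int=True`, which permits integer inputs and returns placeholder gradients of dtype `float0` for them; whether that is convenient depends on how the result is consumed downstream.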

We've been talking about ways to improve on this, for example by allowing pytrees to register their preferred autodiff behavior in the same way that register_vmappable allows user-defined vmap behavior (see how it's used in e.g. http://go/jax-gith…

Replies: 2 comments 3 replies

Answer selected by femtomc
Category
Q&A
2 participants