Hello JAX team, is there a way to get the differentiation rule (or batching rule) of some JAX code, similar to how we get the MLIR lowering with `mlir.lower_fun`?
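Basically, I am looking to do something like this:

```python
from jax.interpreters import batching, ad, mlir

def my_primitive_impl(x):
    return x ** 2

# `my_primitive_p` is a Primitive defined elsewhere. This part works today;
# lower_fun defaults to multiple_results=True, so pass False for a single output.
mlir.register_lowering(my_primitive_p, mlir.lower_fun(my_primitive_impl, multiple_results=False))

# Hypothetical analogues I'm asking about (not existing JAX helpers):
ad.primitive_jvps[my_primitive_p] = ad.lower_jvp(my_primitive_impl)
batching.primitive_batchers[my_primitive_p] = batching.lower_batcher(my_primitive_impl)
```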
Answered by dfm on Jul 11, 2024
There are ways to achieve this behavior, but it's not always totally straightforward. Can you say more about your motivation or use case for this, so we can work out the best approach?
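One possible approach, sketched under assumptions rather than offered as an official API, is to register the rules by hand and have them reuse `jax.jvp` and `jax.vmap` on the Python implementation, much as `mlir.lower_fun` traces it for lowering. The names `my_square`, `_my_jvp_rule`, and `_my_batch_rule` are hypothetical, and the sketch assumes a single-argument, elementwise primitive. Because the `Primitive` class has moved between JAX versions, the sketch tries `jax.extend.core` first and falls back to `jax.core`.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching, mlir

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical single-argument, elementwise primitive.
my_primitive_p = Primitive("my_primitive")


def my_primitive_impl(x):
    return x ** 2


def my_square(x):
    return my_primitive_p.bind(x)


my_primitive_p.def_impl(my_primitive_impl)
# Elementwise op: the output aval (shape/dtype) equals the input aval.
my_primitive_p.def_abstract_eval(lambda x: x)

# Lowering rule obtained by tracing the Python implementation.
mlir.register_lowering(
    my_primitive_p, mlir.lower_fun(my_primitive_impl, multiple_results=False)
)


def _my_jvp_rule(primals, tangents):
    # Reuse JAX's own JVP of the implementation. With a single input, JAX only
    # calls this rule when the tangent is not a symbolic zero.
    return jax.jvp(my_primitive_impl, tuple(primals), tuple(tangents))


def _my_batch_rule(batched_args, batch_dims):
    # vmap the implementation (not the primitive) to avoid re-entering this rule.
    out = jax.vmap(my_primitive_impl, in_axes=tuple(batch_dims))(*batched_args)
    return out, 0  # vmap puts the batch dimension at axis 0 by default


ad.primitive_jvps[my_primitive_p] = _my_jvp_rule
batching.primitive_batchers[my_primitive_p] = _my_batch_rule

print(jax.jit(my_square)(2.0))
print(jax.grad(my_square)(3.0))
print(jax.vmap(my_square)(jnp.arange(3.0)))
```

The trade-off is that the derived rules operate on the implementation rather than on the primitive itself, so the primitive disappears from the jaxpr once it is differentiated or batched.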
I must admit that I don't have a lot of experience with `custom_partitioning`, so I can't help too much with that part. But I can help with the specific question of how to use the existing `jax.vjp` implementation within a `custom_vjp`. We can update the `custom_vjp` example from the docs to do this:

Hope this helps!
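A minimal sketch of that idea, adapted from the `f(x, y) = sin(x) * y` example in the `jax.custom_vjp` documentation (an illustration, not necessarily the code from the reply above): the forward pass calls `jax.vjp` on a plain, undecorated implementation and stores the returned pullback as the residual, and the backward pass simply calls it. This relies on the function returned by `jax.vjp` being a `jax.tree_util.Partial`, which is a valid pytree and can therefore be saved as a residual; the name `f_impl` is illustrative.

```python
import jax
import jax.numpy as jnp


def f_impl(x, y):
    # Undecorated implementation. Calling jax.vjp on the custom_vjp-wrapped `f`
    # inside f_fwd would recurse back into f_fwd, so we differentiate this instead.
    return jnp.sin(x) * y


f = jax.custom_vjp(f_impl)


def f_fwd(x, y):
    # jax.vjp returns the primal output and a pullback; the pullback is a
    # jax.tree_util.Partial (a pytree), so it can serve as the residual.
    out, f_vjp = jax.vjp(f_impl, x, y)
    return out, f_vjp


def f_bwd(f_vjp, g):
    # The pullback returns one cotangent per primal input: (ct_x, ct_y).
    return f_vjp(g)


f.defvjp(f_fwd, f_bwd)

x, y = 1.0, 2.0
print(jax.grad(f, argnums=(0, 1))(x, y))
print(jnp.cos(x) * y, jnp.sin(x))  # analytic gradients, for comparison
```

In practice you would replace the body of `f_fwd` or `f_bwd` with whatever custom logic you need while still delegating the rest to JAX's built-in VJP.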