
Lower 'python' JAX code to MLIR #22393

Closed · Answered by dfm
ASKabalan asked this question in Q&A
Jul 11, 2024 · 1 comment · 3 replies

I must admit that I don't have much experience with custom_partitioning, so I can't help much with that part. But I can help with the specific question of how to reuse the existing jax.vjp implementation inside a custom_vjp. We can adapt the custom_vjp example from the docs to do this:

import jax
import jax.numpy as jnp

def f_ref(x, y):
  return jnp.sin(x) * y

@jax.custom_vjp
def f(x, y):
  return f_ref(x, y)

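def f_fwd(x, y):
  # Forward pass: compute the primal output and save the inputs as residuals.
  return f(x, y), (x, y)  # The residuals should _be_ the primals

def f_bwd(res, g):
  # Backward pass: rebuild JAX's own VJP of f_ref from the saved primals
  # and apply it to the output cotangent g, giving (grad_x, grad_y).
  _, vjp_fun = jax.vjp(f_ref, *res)
  return vjp_fun(g)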

f.defvjp(f_fwd, f_bwd)
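As a quick sanity check, a minimal sketch (assuming a reasonably recent JAX release) can compare the gradients from the custom rule against plain autodiff of f_ref, and then lower the gradient computation to MLIR text with jax.jit(...).lower(...).as_text(), which ties back to the question title. The input values 0.5 and 2.0 are arbitrary, and the exact MLIR dialect and text that gets printed vary between JAX versions.

```python
import jax
import jax.numpy as jnp

def f_ref(x, y):
  return jnp.sin(x) * y

@jax.custom_vjp
def f(x, y):
  return f_ref(x, y)

def f_fwd(x, y):
  # Save the primal inputs as residuals.
  return f(x, y), (x, y)

def f_bwd(res, g):
  # Reuse JAX's built-in VJP of the reference implementation.
  _, vjp_fun = jax.vjp(f_ref, *res)
  return vjp_fun(g)

f.defvjp(f_fwd, f_bwd)

# Arbitrary example inputs.
x, y = jnp.float32(0.5), jnp.float32(2.0)

# Gradients w.r.t. both arguments: custom rule vs. plain autodiff of f_ref.
g_custom = jax.grad(f, argnums=(0, 1))(x, y)
g_ref = jax.grad(f_ref, argnums=(0, 1))(x, y)
print(g_custom, g_ref)

# Lower the jitted gradient computation and print the resulting MLIR module.
mlir_text = jax.jit(jax.grad(f, argnums=(0, 1))).lower(x, y).as_text()
print(mlir_text)
```

Both gradient tuples should agree (cos(x) * y and sin(x)), since f_bwd simply delegates to jax.vjp of the same reference function.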

Hope this helps!

Answer selected by ASKabalan