Jax equivalent to PyTorch "needs_input_grad" for custom vjp/gradient? #16355

Answered by anh-tong
amifalk asked this question in Q&A

Oh, in this case, you can do something like this:

from functools import partial
import jax
import jax.numpy as jnp

# Argument index 3 (needs_input_grad) is marked non-differentiable, so
# custom_vjp forwards it to the bwd rule as a static Python value, much like
# PyTorch's ctx.needs_input_grad (except here the caller supplies the flags).
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def f(x, y, z, needs_input_grad=(True, True, True)):
    return x + y + z


def f_fwd(x, y, z, needs_input_grad=(True, True, True)):
    res = (x, y, z)  # residuals saved for the backward pass
    return x + y + z, res

def f_bwd(needs_input_grad, res, g):
    x, y, z = res
    # Returning None for a cotangent tells JAX to treat it as zero.
    dx = dy = dz = None
    if needs_input_grad[0]:
        dx = g * jnp.ones_like(x)  # d(x + y + z)/dx = 1, scaled by the cotangent g
    if needs_input_grad[1]:
        dy = g * jnp.ones_like(y)
    if needs_input_grad[2]:
        dz = g * jnp.ones_like(z)

    return dx, dy, dz

f.defvjp(f_fwd, f_bwd)


x = y = z = jnp.array(1.)

# Request only dx; the flags let the bwd rule skip computing dy and dz.
print(jax.grad(f, argnums=0)(x, y, z, (True, False, False)))  # prints 1.0
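
The pattern above only pays off when the skipped cotangent is genuinely expensive to build. Here is a minimal sketch of that situation; matvec and its rules are hypothetical stand-ins for illustration, not part of the answer above:

from functools import partial
import jax
import jax.numpy as jnp

# Hypothetical y = A @ x, where the cotangent w.r.t. A (an outer product the
# size of A) may not be needed.
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def matvec(A, x, needs_input_grad=(True, True)):
    return A @ x

def matvec_fwd(A, x, needs_input_grad=(True, True)):
    return A @ x, (A, x)

def matvec_bwd(needs_input_grad, res, g):
    A, x = res
    # Only build the potentially large outer product when it is requested;
    # None again stands in for a zero cotangent.
    dA = jnp.outer(g, x) if needs_input_grad[0] else None
    dx = A.T @ g if needs_input_grad[1] else None
    return dA, dx

matvec.defvjp(matvec_fwd, matvec_bwd)

A = jnp.ones((3, 2))
x = jnp.ones(2)
# Differentiate w.r.t. x only, so the bwd rule never materializes dA.
print(jax.grad(lambda v: matvec(A, v, (False, True)).sum())(x))  # [3. 3.]

Because needs_input_grad is a static (non-differentiable) argument, the if branches are resolved in Python at trace time, so the skipped work never enters the computation at all.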

Answer selected by amifalk