-
Computing all of the partial derivatives in a custom gradient when only some of them are needed can be costly. PyTorch lets users optimize for this by exposing a boolean tuple, `needs_input_grad` (True, False, ...), that indicates whether each input needs its partial derivative computed. Is there a way to support this in JAX?
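For concreteness, the PyTorch pattern in question looks roughly like this (a minimal sketch; `MyOp` is a made-up example op):

```python
import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * torch.sin(y)

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors
        grad_x = grad_y = None
        # ctx.needs_input_grad has one boolean per forward() input, so each
        # partial derivative is only computed when it is actually required.
        if ctx.needs_input_grad[0]:
            grad_x = grad_out * torch.sin(y)
        if ctx.needs_input_grad[1]:
            grad_y = grad_out * x * torch.cos(y)
        return grad_x, grad_y
```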
-
I think you may need to specify `argnums`. Let's say that you have a function

```python
def f(x, y, z):
    return x + y + z
```

Computing the gradient w.r.t. `x` and `z` only:

```python
# `0` is the argument position of `x`
# `2` is the argument position of `z`
jax.grad(f, argnums=(0, 2))(x, y, z)
# this returns a tuple containing two gradients
```
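For instance, with concrete values (a self-contained check of the snippet above):

```python
import jax

def f(x, y, z):
    return x + y + z

# Gradients only w.r.t. `x` and `z`; the partial w.r.t. `y` is never requested.
dx, dz = jax.grad(f, argnums=(0, 2))(1.0, 2.0, 3.0)
# dx == 1.0 and dz == 1.0, since f is a plain sum
```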
-
Oh, in this case, you can do something like this
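A minimal sketch of one way to get this, assuming the `symbolic_zeros` option of `jax.custom_vjp`'s `defvjp` (available in newer JAX releases), which hands the forward rule a per-input `.perturbed` flag, roughly the analogue of `needs_input_grad`; the op `my_op` here is made up for illustration:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_op(x, y):
    return x * jnp.sin(y)

def my_op_fwd(x, y):
    # With symbolic_zeros=True, `x` and `y` arrive wrapped: `.value` holds the
    # primal and `.perturbed` says whether that input's cotangent is needed.
    out = x.value * jnp.sin(y.value)
    residuals = (x.value, y.value, x.perturbed, y.perturbed)
    return out, residuals

def my_op_bwd(residuals, g):
    x, y, x_needed, y_needed = residuals
    # Compute only the partial derivatives that were actually requested and
    # return cheap zero cotangents for the rest.
    dx = g * jnp.sin(y) if x_needed else jnp.zeros_like(x)
    dy = g * x * jnp.cos(y) if y_needed else jnp.zeros_like(y)
    return dx, dy

my_op.defvjp(my_op_fwd, my_op_bwd, symbolic_zeros=True)

# Only d/dx is requested here, so `y.perturbed` is False inside my_op_fwd.
print(jax.grad(my_op, argnums=0)(jnp.float32(1.0), jnp.float32(2.0)))
```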