Grad #579

Answered by jheek
Marcode10 asked this question in Q&A
Oct 30, 2020 · 1 comment · 4 replies

By default, jax.grad only computes the gradient w.r.t. the first argument, so you get:

def loss_fn(model, x):
  ...

model_grad = jax.grad(loss_fn)(model, batch)

But you can also do the following:

model_grad, input_grad = jax.grad(loss_fn, argnums=(0, 1))(model, batch)
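Here is a minimal self-contained sketch of the two patterns above. The loss function and the arrays are hypothetical stand-ins (a simple quadratic loss, with `w` playing the role of the model parameters and `x` the batch); the point is only to show that `jax.grad` differentiates w.r.t. the first argument by default, and w.r.t. multiple arguments when given `argnums`:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x):
    # Hypothetical loss: w stands in for model parameters, x for a batch.
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Default (argnums=0): gradient w.r.t. the first argument only.
w_grad = jax.grad(loss_fn)(w, x)

# argnums=(0, 1): a tuple of gradients, one per listed argument.
w_grad2, x_grad = jax.grad(loss_fn, argnums=(0, 1))(w, x)
```

Note that `jax.grad` transforms the function and the result is then called on the arguments; the arguments are not passed to `jax.grad` itself.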
