use nn.vjp to get gradients wrt inputs instead of params #2176

Answered by cgarciae
luweizheng asked this question in General
Two suggestions:

  1. If you don't want the Jacobian wrt t, capture t in the closure instead of passing it as an input to nn.vjp:
(u, bwd) = nn.vjp(lambda mdl, x: mdl(t, x), mlp, x)
  2. vjp_variables='params' by default; try setting it to False so that the parameters are not differentiated.

Answer selected by luweizheng