use nn.vjp to get gradients wrt inputs instead of params #2176
luweizheng asked this question in General
Hi all, I want to get gradients wrt the model input. There is a thread discussing how to get them with a pure JAX function, and I already know how to do it that way. Now I want to use the lifted version, `nn.vjp`:

```python
from typing import Sequence

import jax.numpy as jnp
import flax.linen as nn


class FFNN(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, t, x, train: bool = True):
        x = jnp.hstack((t, x))
        for idx, out_feat in enumerate(self.features):
            x = nn.Dense(features=out_feat)(x)
            if idx != len(self.features) - 1:
                x = nn.relu(x)
        return x


class FBSDENN(nn.Module):
    @nn.compact
    def __call__(self, t, x):
        mlp = FFNN(features=1 * [10] + [1])
        (u, bwd) = nn.vjp(lambda mdl, t, x: mdl(t, x), mlp, t, x)
        params_grad, t_grad, x_grad = bwd(jnp.ones(u.shape))
        return x, x_grad
```

There are two inputs, (t, x), to my model, and I want the gradient wrt x, but I got an error. I do not want the gradient wrt the params. Should I use …
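For reference, a minimal sketch of the pure-JAX version mentioned above (using the FFNN from the question; input shapes here are placeholder assumptions):

```python
import jax
import jax.numpy as jnp

t = jnp.ones((1,))  # placeholder inputs
x = jnp.ones((1,))

model = FFNN(features=1 * [10] + [1])
variables = model.init(jax.random.PRNGKey(0), t, x)

# Differentiate wrt x only: variables and t are closed over, so bwd
# returns a single-element tuple holding the cotangent wrt x.
u, bwd = jax.vjp(lambda x: model.apply(variables, t, x), x)
(x_grad,) = bwd(jnp.ones(u.shape))
```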
Answered by cgarciae · Jun 7, 2022
Two suggestions:

1. If you don't need the gradient wrt `t`, pass `t` as a capture:

```python
(u, bwd) = nn.vjp(lambda mdl, x: mdl(t, x), mlp, x)
```

2. `vjp_variables='params'` by default, try setting it to `False`.
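Applying suggestion 1 to the module above might look like the following; a minimal sketch, assuming the FFNN from the question. With the default `vjp_variables='params'`, `bwd` still returns the parameter cotangent first, so it is simply discarded here; per suggestion 2, setting `vjp_variables=False` should avoid computing it at all.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class FBSDENN(nn.Module):
    @nn.compact
    def __call__(self, t, x):
        mlp = FFNN(features=1 * [10] + [1])
        # t is closed over, so the vjp is taken wrt x only.
        u, bwd = nn.vjp(lambda mdl, x: mdl(t, x), mlp, x)
        # With vjp_variables='params' (the default), bwd returns the
        # parameter cotangent first, then the input cotangents.
        params_grad, x_grad = bwd(jnp.ones(u.shape))
        return x, x_grad


# Usage with placeholder shapes:
model = FBSDENN()
t, x = jnp.ones((1,)), jnp.ones((1,))
variables = model.init(jax.random.PRNGKey(0), t, x)
_, x_grad = model.apply(variables, t, x)
```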
Answer selected by luweizheng