Unclear how to add up two grads of nn.Module type #389
Unanswered
BoyuanJackChen asked this question in General
Replies: 1 comment
-
I believe you mean the grads are stored inside an `nn.Model`? You could then use a `tree_multimap` to sum them:

```python
import flax
import flax.nn as nn
import jax
import jax.numpy as jnp
from jax import random

_, params = NeRF_Model.init(random.PRNGKey(0), jnp.ones((10, 10)))
model = nn.Model(NeRF_Model, params)
optimizer = flax.optim.Adam(learning_rate=0.01).create(model)
del model
# optimizer.target contains the model

def dumb_loss(model, x):
    return jnp.sum(model(x))

grad_fn = jax.grad(dumb_loss)
x1 = random.uniform(random.PRNGKey(0), (10, 10))
x2 = random.uniform(random.PRNGKey(1), (10, 10))
grad1 = grad_fn(optimizer.target, x1)
grad2 = grad_fn(optimizer.target, x2)

# grad1 and grad2 are pytrees with the same structure as the model,
# so they can be summed leaf-by-leaf:
summed_grad = jax.tree_multimap(lambda x, y: x + y, grad1, grad2)
```

Runnable in a colab, see: |
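The leaf-wise summation above works on any pytree, not just Flax models. A minimal self-contained sketch using plain dicts as stand-ins for gradient pytrees (note that in recent JAX versions `jax.tree_multimap` has been folded into `jax.tree_util.tree_map`, which accepts multiple trees):

```python
import jax
import jax.numpy as jnp

# Two gradient-like pytrees with identical structure (toy values for illustration)
grad1 = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.array([1.0, 2.0])}}
grad2 = {"dense": {"kernel": jnp.ones((2, 2)) * 2, "bias": jnp.array([3.0, 4.0])}}

# tree_map walks both trees in lockstep and sums corresponding leaves
summed = jax.tree_util.tree_map(lambda a, b: a + b, grad1, grad2)
# summed["dense"]["bias"] is [4.0, 6.0]
```

Because the summation is structural, the same one-liner handles arbitrarily nested parameter trees without any knowledge of the model's layers.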
-
Description of the model to be implemented
I am trying to add up the grads for an optimizer. The grads have type `flax.nn.Module`, and I wonder how to sum them up.
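For context, summing grads this way is the usual basis for gradient accumulation: sum the per-batch gradients and divide by the batch count before the optimizer update. A hedged pure-JAX sketch (the quadratic loss and parameter shapes here are made up for illustration, not taken from the model in question):

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Toy quadratic loss; stands in for a real model's loss
    return jnp.sum((x @ params) ** 2)

grad_fn = jax.grad(loss)
params = jnp.ones((3, 2))
batches = [jnp.ones((4, 3)) * i for i in range(1, 4)]

# Accumulate grads across batches with a leaf-wise sum
total = grad_fn(params, batches[0])
for x in batches[1:]:
    g = grad_fn(params, x)
    total = jax.tree_util.tree_map(lambda a, b: a + b, total, g)

# Average before handing the result to an optimizer step
avg_grad = jax.tree_util.tree_map(lambda a: a / len(batches), total)
```

The accumulated gradient has the same pytree structure as `params`, so it can be passed to any optimizer update that expects a per-parameter gradient.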
Dataset the model could be trained on
Image Data
Specific points to consider
/
Reference implementations in other frameworks
/