Jax gradients of stateful flax operations (batchnorm) #339
Unanswered · asked by virajmehta in General
Replies: 2 comments
-
Never mind, I figured it out, but I think this may be an area where the documentation could be improved. Thanks!
-
Hi Viraj, I converted this issue to a conversation -- do you mind sharing how you ended up solving your issue so that others can also benefit?
-
Hi,
I’d like to use batch norm while training normalizing flows, along with other things. It seems that your implementation of batch norm (and mine) needs state that is held outside of the main training loop (the batch statistics). However, I’m running into trouble: when I take JAX gradients, Flax throws
ValueError: Stateful operations are not allowed when the Collection is created outside of the current Jax transformation
The docs only show batch norm in an inference setting; I don’t see an example of it being used inside a training loop. Can you please advise on the correct workaround?
This also doubles as a documentation suggestion: if it wasn’t clear to me, it may be unclear to others. Thanks!