How to use shape-dependent variables without @nn.compact in JAX/Flax? #928
marcvanzee asked this question in Q&A
Answered by marcvanzee on Jan 22, 2021
Answer selected by marcvanzee
When using `setup()` in a non-compact Module in Flax, one does not have access to shape information from the input directly. In order to use shape-dependent variables in those modules, we either pass in the necessary shape information explicitly as construction args, or we isolate any shape-inferred variables in a submodule that we construct from `setup()`. We do this in the VAE example and the WMT example. In both examples, search for `setup(` to see how it is done.