Feature request: functionality similar to torch.nn.utils.vector_to_parameters #2050
-
TLDR: Add jittable functionality similar to PyTorch's `torch.nn.utils.vector_to_parameters`.

I have been spending some time experimenting with Bayesian neural networks (BNNs) in JAX and Flax. A big part of BNNs is that one samples parameters from some distribution and then runs the NN with these sampled parameters. The raw parameters are usually sampled as a 1-D […]. I have some code that is able to do this while being jittable, through some hacky tricks avoiding […]
With the newest versions of JAX and Flax, jitting this function fails on the first line due to a […]. Supporting jittable functionality equivalent to PyTorch's `torch.nn.utils.vector_to_parameters` […]

Cheers,
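For reference, the behaviour being requested can be sketched in plain JAX. This is a minimal sketch, not an existing JAX or Flax API: `vector_to_params` is a hypothetical helper name, and unlike PyTorch's in-place `vector_to_parameters` it returns a new pytree, since JAX arrays are immutable. It should stay jittable because the slice bounds come from the template's shapes, which are static under `jax.jit`.

```python
import math

import jax
import jax.numpy as jnp


def vector_to_params(vec, template):
    """Hypothetical helper: unflatten a 1-D vector into a pytree shaped like `template`.

    Leaves are filled in jax.tree_util flattening order. Returns a new pytree
    instead of mutating parameters in place as torch's version does.
    """
    leaves, treedef = jax.tree_util.tree_flatten(template)
    total = sum(math.prod(leaf.shape) for leaf in leaves)
    # Shapes are static under jit, so this check runs at trace time.
    if vec.shape != (total,):
        raise ValueError(f"expected a vector of length {total}, got shape {vec.shape}")
    new_leaves = []
    offset = 0
    for leaf in leaves:
        size = math.prod(leaf.shape)  # Python int; 1 for scalar leaves
        # Static slice bounds keep this traceable by jax.jit.
        chunk = vec[offset:offset + size]
        new_leaves.append(chunk.reshape(leaf.shape).astype(leaf.dtype))
        offset += size
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


# Usage: a toy parameter pytree and a flat vector of matching length (9 = 3 + 2*3).
template = {"w": jnp.zeros((2, 3)), "b": jnp.zeros((3,))}
vec = jnp.arange(9.0)
params = jax.jit(vector_to_params)(vec, template)
# Dict keys are flattened in sorted order, so "b" takes vec[0:3] and "w" takes vec[3:9].
```

The flat vector is consumed in `jax.tree_util` leaf order (dict keys sorted), so whatever produces the vector has to use the same layout.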
-
This API, `jax.flatten_util.ravel_pytree`, should do exactly what you want: https://jax.readthedocs.io/en/latest/_autosummary/jax.flatten_util.ravel_pytree.html
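A minimal sketch of how `ravel_pytree` could cover the BNN use case from the question, assuming a diagonal Gaussian over the flattened weights. The parameter pytree and the `apply` function are hypothetical stand-ins for Flax's `params` and `model.apply`; `ravel_pytree` accepts any pytree, so a real Flax parameter dict should work the same way. The key point is that `unravel` is obtained once, outside `jax.jit`, and can then be called on traced vectors inside it.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical stand-in for a Flax `params` pytree; any pytree works.
params = {
    "dense1": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))},
    "dense2": {"kernel": jnp.zeros((8, 1)), "bias": jnp.zeros((1,))},
}

# Flatten once, outside jit: `unravel` closes over the static tree structure and shapes.
flat_params, unravel = ravel_pytree(params)


def apply(p, x):
    # Hypothetical two-layer MLP standing in for `model.apply`.
    h = jnp.tanh(x @ p["dense1"]["kernel"] + p["dense1"]["bias"])
    return h @ p["dense2"]["kernel"] + p["dense2"]["bias"]


@jax.jit
def predict_with_sample(key, mean, log_std, x):
    # Reparameterized draw of a flat 1-D parameter vector from a diagonal Gaussian.
    eps = jax.random.normal(key, mean.shape)
    sample = mean + jnp.exp(log_std) * eps
    # Map the flat vector back onto the parameter pytree inside jit.
    return apply(unravel(sample), x)


mean = flat_params
log_std = jnp.full_like(mean, -2.0)
x = jnp.ones((5, 4))
key = jax.random.PRNGKey(0)
out = predict_with_sample(key, mean, log_std, x)

# Several parameter samples at once by vmapping over a batch of keys.
keys = jax.random.split(key, 10)
samples = jax.vmap(predict_with_sample, in_axes=(0, None, None, None))(keys, mean, log_std, x)
```

Vmapping over keys, as in the last lines, gives one prediction per sampled parameter vector without a Python loop.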