v0.8.5
What's Changed
- v0.8.5 by @cgarciae in #3941
- [nnx] improve vmap axis size detection by @cgarciae in #3947
- Add direct penzai.treescope support for NNX objects. by @copybara-service in #3948
- [nnx] fix nnx_basics dependencies by @cgarciae in #3942
- Rename all the NNX tests to internal naming & build conventions. by @copybara-service in #3952
- updated rng guide by @chiamp in #3912
- upgraded haiku guide to include NNX by @chiamp in #3923
- parameterized NNX transforms tests by @chiamp in #3906
- Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes. by @copybara-service in #3957
- fix HEAD by @chiamp in #3960
- Minor grammar fixes to NNX documentation. by @mcsmart76 in #3953
- Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3928
- Adding Welford metric. by @copybara-service in #3959
- Modify Welford metric to return mean value. by @copybara-service in #3970
- [nnx] make State generic by @cgarciae in #3964
- updated NNX nn docstrings by @chiamp in #3972
- make flax work with upcoming JAX change to tree_map (being more careful about by @copybara-service in #3976
- updated nnx.module docstrings by @chiamp in #3966
- updated nnx.Conv and nnx.ConvTranspose by @chiamp in #3974
- updated nnx.graph docstrings by @chiamp in #3958
- Adds pmap and Pmap. static_broadcasted_argnums, donate_argnums, and global_arg_shapes are not yet supported. by @copybara-service in #3978
- Fixes for batch norm docs by @jkarwowski in #3982
- fix deprecation warning by @chiamp in #3981
- updated NNX rnglib docstring by @chiamp in #3980
- updated nnx.training by @chiamp in #3975
- updated nnx.variables docstrings by @chiamp in #3986
- [nnx] vectorize vmap split counts by @cgarciae in #3989
- added wrt option to nnx.Optimizer by @chiamp in #3983
- Added nnx.graph.iter_children by @chiamp in #3991
- [nnx] fix vmap by @copybara-service in #3995
- Fix head pytest breakage by @IvyZX in #4006
- Helper function for loading params from a linen module by @copybara-service in #4012
- Port gemma/layers to NNX by @copybara-service in #4013
- [nnx] fix grad by @cgarciae in #4007
- [nnx] add PathContains Filter by @cgarciae in #4011
- Support Python 3.9 by @copybara-service in #4018
- Port gemma/modules to NNX by @copybara-service in #4014
- Internal change to fix current head CI by @copybara-service in #4017
- Unpin the Orbax pip version. by @copybara-service in #4024
- Fix Gemma test to unbreak head by @IvyZX in #4025
- Fix pickling of exceptions by @sanderland in #4002
- Call user-defined variable transforms before determining axis size in nn.vmap. by @copybara-service in #4026
- CI: add test run against oldest supported jax version by @jakevdp in #3996
- Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful. by @copybara-service in #4029
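Two entries in this release (#3959 and #3970) add a Welford metric and change it to return the mean. Welford's online algorithm updates a running mean and variance one sample at a time, which avoids storing every observation. The sketch below is a minimal, generic illustration of the algorithm in plain Python. It is not Flax's metric class. The class name WelfordStats and its methods are hypothetical.

```python
import math


class WelfordStats:
    """Hypothetical running-statistics accumulator (not Flax's Welford metric).

    Implements Welford's online algorithm for the mean and population variance.
    """

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the current mean

    def update(self, x: float) -> None:
        # Shift the mean toward the new sample, then accumulate the product of
        # the deviations taken before and after the mean update.
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def variance(self) -> float:
        # Population variance; undefined (nan) before any sample is seen.
        return self.m2 / self.count if self.count else math.nan

    def compute(self) -> float:
        # Mirrors the release note's "return mean value" behavior (assumption).
        return self.mean


stats = WelfordStats()
for value in [2, 4, 4, 4, 5, 5, 7, 9]:
    stats.update(value)
print(stats.compute(), stats.variance())
```

For this sample list, the mean is 5 and the population variance is 4, up to floating-point rounding. The update uses the deviation both before and after the mean shift. This is what keeps the algorithm numerically stable compared with accumulating a raw sum of squares.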
New Contributors
- @mcsmart76 made their first contribution in #3953
- @jkarwowski made their first contribution in #3982
- @sanderland made their first contribution in #4002
Full Changelog: v0.8.4...v0.8.5