v0.8.2
What's Changed
- Add +1 to version after 0.8.1 release by @IvyZX in #3684
- fixed rng guide outputs by @chiamp in #3685
- enforce mask kwarg in norm layers by @chiamp in #3663
- added kwargs to self.param and self.variable by @chiamp in #3675
- added nnx normalization tests by @chiamp in #3689
- added NNX init_cache docstring example by @chiamp in #3688
- added nnx attention equivalence test by @chiamp in #3687
- Fix bug that assumed frozen-dict keys were strings. by @copybara-service in #3692
- added nnx rmsnorm by @chiamp in #3691
- updated nnx compute_stats by @chiamp in #3693
- fixed intercept_methods docstring by @chiamp in #3694
- [nnx] Add Sphinx Docs by @cgarciae in #3678
- Fix pointless docstring example of nn.checkpoint / nn.remat. by @levskaya in #3703
- added default params rng to .apply by @chiamp in #3698
- [nnx] add partial_init by @cgarciae in #3674
- make make_rng default to 'params' by @chiamp in #3699
- Add SimpleCell. by @carlosgmartin in #3697
- fix Module.module_paths docstring by @cgarciae in #3709
- Guarantee the latest JAX version on CI by @cgarciae in #3705
- Replace deprecated API `jax.tree_map` by @copybara-service in #3715
- Use `jax.tree_util.tree_map` instead of deprecated `jax.tree_map`. by @copybara-service in #3714
- [nnx] simplify readme by @cgarciae in #3707
- [nnx] add demo.ipynb by @cgarciae in #3680
- Fix Tabulate's compute_flops by @cgarciae in #3721
- [nnx] simplify TraceState by @cgarciae in #3724
- Add broadcast of `strides` and `kernel_dilation` to `nn.ConvTranspose` by @IvyZX in #3731
- [nnx] Fix State.sub by @cgarciae in #3704
- [nnx] always fold_in on fork + new ForkedKeys return type by @cgarciae in #3722
- [nnx] explicit Variables by @cgarciae in #3720
- Improves fingerprint definition for Modules in nn.jit. by @copybara-service in #3736
- Flax: avoid key reuse in tests by @copybara-service in #3740
- added Einsum layer by @chiamp in #3710
- nn.jit: automatic fingerprint definition for dataclass attributes by @cgarciae in #3737
- [NVIDIA] Use custom grad accumulation for FP8 params by @kaixih in #3623
- removed nnx dataclass by @chiamp in #3742
- [nnx] cleanup graph_utils by @cgarciae in #3728
- Fix doctest and unbreak head by @IvyZX in #3753
- [nnx] add pytree support by @cgarciae in #3732
- fixed intercept_methods docstring by @chiamp in #3752
- Add ConvLSTMCell to docs. by @carlosgmartin in #3712
- [nnx] remove flagslib by @cgarciae in #3733
- Fix tests after applying JAX key-reuse checker. See: by @copybara-service in #3748
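
The #3714 and #3715 entries move Flax off the deprecated `jax.tree_map` alias. Downstream code can usually make the same swap mechanically. The snippet below is a minimal sketch assuming a recent JAX install; the `params` dictionary is a hypothetical stand-in for real model variables, and the old call appears only in a comment because the alias may be absent in newer JAX releases.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree standing in for a model's parameters.
params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}

# Before (deprecated alias): jax.tree_map(lambda x: x * 2, params)
# After: call tree_map from jax.tree_util directly.
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)

# tree_map keeps the pytree structure and only transforms the leaves.
shapes = jax.tree_util.tree_map(lambda x: x.shape, doubled)
print(shapes)
```

Only the import path changes; the function arguments and the returned pytree structure are the same.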
Full Changelog: v0.8.1...v0.8.2