v0.9.0
What's Changed
- Add NNX surgery guide by @IvyZX in #4005
- Port gemma/transformer to NNX by @copybara-service in #4019
- upgrade python to 3.10 + use pyupgrade by @cgarciae in #4038
- [nnx] add Using Filters guide by @cgarciae in #4028
- v0.8.6 by @cgarciae in #4040
- allow imagenet training profiling to be disabled in config by @copybara-service in #4043
- [nnx] LoRAParam inherits from Param by @cgarciae in #3988
- [linen] allows multiple compact methods by @cgarciae in #3808
- Added support for NANOO fp8. by @wenchenvincent in #3993
- Add functools.wraps() annotation to flax.nn.jit. by @copybara-service in #4051
- Fix typo in nnx_basics doc by @rajasekharporeddy in #4047
- [nnx] fix Variable overloads and add shape/dtype properties by @cgarciae in #4049
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4039
- [nnx] stabilize unsafe_pytree by @cgarciae in #4030
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4055
- [NVIDIA] Rename fp8 custom dtype to fp32_max_grad by @kaixih in #3984
- [nnx] fix mnist_tutorial colab link by @cgarciae in #4063
- [nnx] fix Accuracy on eager mode by @cgarciae in #4065
- Update orbax_upgrade_guide.rst for async checkpointing usage examples by @kaushaladiti-2802 in #4036
- Re-enable some tests after Python 3.9 is dropped by @IvyZX in #4067
- Rename nnx.compat to nnx.bridge by @IvyZX in #4066
- [nnx] improve mnist tutorial by @cgarciae in #4070
- Modify Flax checkpointing in preparation for cl/650338576. by @copybara-service in #4072
- Remove some outdated backward-compatibility code. by @copybara-service in #4068
- [NVIDIA] Add a user guide for fp8 by @kaixih in #4076
- [nnx] add extract APIs by @cgarciae in #4078
- [example]: remove lm1b useless parallelism rules by @knightXun in #4077
- [nnx] improve filters guide by @cgarciae in #4059
- [nnx] add call by @cgarciae in #4004
- Ignore Orbax warning in deprecated flax.training.checkpoints.py to unbreak head doctest by @IvyZX in #4092
- fix mypy failures due to numpy update by @cgarciae in #4098
- [linen] generalize transform caching by @copybara-service in #4057
- [linen] fold rngs on jit to improve caching by @copybara-service in #4064
- Add shape-based lazy init to LinenToNNX (prev LinenWrapper) by @IvyZX in #4081
- [nnx] add reseed by @cgarciae in #4099
- [nnx] add split/merge_inputs by @cgarciae in #4084
- Perform shape checks for self.param AFTER unboxing by @danielwatson6 in #4079
- fix restore_checkpoint example in docstring by @copybara-service in #4101
- [numpy] Fix users of NumPy APIs that are removed in NumPy 2.0. by @copybara-service in #4104
- Set profile_duration_ms = None: periodic_actions has default values for both num_profile_steps and profile_duration_ms, and profiling stops only when both conditions are satisfied, so setting profile_duration_ms=None ensures that the passed num_profile_steps value is used. by @copybara-service in #4096
- [linen] add share_scope by @cgarciae in #4102
- Allow metadata pass-through in flax.struct.field by @cool-RR in #4056
- avoid mixing einsum_dot_general and einsum argument by specifying them explicitly in the caller. by @copybara-service in #4115
- Add logging to track deprecated codepaths. by @copybara-service in #4121
- [pmap no rank reduce cleanup]: When flipping the by @copybara-service in #4125
- Add NNXToLinen wrapper to nnx.bridge by @IvyZX in #4126
- Switch NNX to use Treescope instead of Penzai. by @copybara-service in #4132
- Add GroupNorm to NNX normalization layers by @treigerm in #4095
- [nnx] fix initializing propagation by @cgarciae in #4134
- add JAX-style NNX Transforms FLIP by @cgarciae in #4108
- Fix _ParentType annotation by @dcharatan in #4120
- add uv.lock file by @copybara-service in #4139
- use uv package manager by @cgarciae in #4136
- More testing and misc fixes on wrappers by @IvyZX in #4137
- Fix link to orbax documentation by @cool-RR in #4123
- [nnx] experimental transforms by @cgarciae in #3963
- [nnx] improve docs by @cgarciae in #4141
- remove repeated license headers by @cgarciae in #4148
- update Flax to version 0.9.0 by @copybara-service in #4147
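For the functools.wraps() entry (#4051): as a minimal, stdlib-only sketch, the hypothetical jit_like decorator below shows what functools.wraps() does for any function-wrapping transform. It illustrates standard-library behavior only and is not Flax's implementation of flax.nn.jit.

```python
import functools


def jit_like(fn):
    """Hypothetical stand-in for a transform that wraps `fn` (not Flax code)."""

    @functools.wraps(fn)  # copies __name__, __qualname__, __doc__, etc. and sets __wrapped__
    def wrapper(*args, **kwargs):
        # A real transform would trace/compile here; this sketch just calls through.
        return fn(*args, **kwargs)

    return wrapper


@jit_like
def apply_layer(x):
    """Scale the input by two."""
    return 2 * x


print(apply_layer.__name__)  # apply_layer
print(apply_layer.__doc__)   # Scale the input by two.
print(apply_layer(3))        # 6
```

Without @functools.wraps(fn), apply_layer.__name__ would be 'wrapper' and its docstring would be lost, which degrades tracebacks, documentation tools, and introspection.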
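For the profile_duration_ms entry (#4096): a minimal sketch of the stopping rule that entry describes, assuming that a criterion set to None is simply ignored. The function should_stop_profiling and its signature are hypothetical and are not the periodic_actions API; the numbers are illustrative values only.

```python
from typing import Optional


def should_stop_profiling(
    steps_done: int,
    elapsed_ms: float,
    num_profile_steps: Optional[int],
    profile_duration_ms: Optional[float],
) -> bool:
    """Hypothetical stop rule: every criterion that is set must be satisfied.

    A criterion set to None is treated as already satisfied, so it no longer
    delays the end of the profiling session.
    """
    steps_ok = num_profile_steps is None or steps_done >= num_profile_steps
    time_ok = profile_duration_ms is None or elapsed_ms >= profile_duration_ms
    return steps_ok and time_ok


# Both criteria set: reaching the step count alone is not enough.
print(should_stop_profiling(10, 500.0, num_profile_steps=10, profile_duration_ms=3000.0))  # False
# Duration disabled: the requested step count alone ends the session.
print(should_stop_profiling(10, 500.0, num_profile_steps=10, profile_duration_ms=None))  # True
```

Under this rule, a duration left in place can outlast a small num_profile_steps: the session keeps running until the time criterion is also met.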
New Contributors
- @wenchenvincent made their first contribution in #3993
- @rajasekharporeddy made their first contribution in #4047
- @kaushaladiti-2802 made their first contribution in #4036
- @knightXun made their first contribution in #4077
- @danielwatson6 made their first contribution in #4079
- @cool-RR made their first contribution in #4056
- @treigerm made their first contribution in #4095
- @dcharatan made their first contribution in #4120
Full Changelog: v0.8.5...v0.9.0