diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f8e5a5850..a9c7e77074 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,26 +4,20 @@ Changelog vNext ------ (Add your change to a random empty line to avoid merge conflicts) -- Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns. - - - - -- Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding -`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic` -to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389). - - -- Added `has_improved` field to EarlyStopping and changed the return signature of -`EarlyStopping.update` from returning a tuple to returning just the updated class. -See more details in [#3385](https://github.com/google/flax/pull/3385) - - -- Use new typed PRNG keys throughout flax: this essentially involved changing - uses of `jax.random.PRNGKey` to `jax.random.key`. - (See [JEP 9263](https://github.com/google/jax/pull/17297) for details). - If you notice dispatch performance regressions after this change, be sure - you update `jax` to version 0.4.16 or newer. +- +- +- +- +- +- - - - @@ -31,7 +25,22 @@ See more details in [#3385](https://github.com/google/flax/pull/3385) - - NOTE: Remember to bump version number to 0.8.0 -0.7.3 +0.7.5 +----- +- Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns. +- Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding +`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic` +to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389). +- Use new typed PRNG keys throughout flax: this essentially involved changing + uses of `jax.random.PRNGKey` to `jax.random.key`. + (See [JEP 9263](https://github.com/google/jax/pull/17297) for details). + If you notice dispatch performance regressions after this change, be sure + you update `jax` to version 0.4.16 or newer. +- Added `has_improved` field to EarlyStopping and changed the return signature of +`EarlyStopping.update` from returning a tuple to returning just the updated class. +See more details in [#3385](https://github.com/google/flax/pull/3385) + +0.7.4 ----- New features: - Add QK-normalization to MultiHeadDotProductAttention diff --git a/README.md b/README.md index b7ffe0f19c..0c249f9720 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,7 @@ To cite this repository: author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee}, title = {{F}lax: A neural network library and ecosystem for {JAX}}, url = {http://github.com/google/flax}, - version = {0.7.4}, + version = {0.7.5}, year = {2023}, } ``` diff --git a/pyproject.toml b/pyproject.toml index 5195879442..4e8c6f03e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "numpy>=1.22", "numpy>=1.23.2; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", - "jax>=0.4.11", + "jax>=0.4.19", "msgpack", "optax", "orbax-checkpoint",