From d998986acf1cba9e4b40bfd01bfd336ae830ebb3 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Tue, 26 Sep 2023 14:39:28 -0700 Subject: [PATCH 01/27] updated docs --- docs/api_reference/flax.cursor.rst | 2 +- docs/examples_community_examples.rst | 4 ---- docs/guides/flax_on_pjit.ipynb | 4 ++-- docs/guides/flax_on_pjit.md | 4 ++-- docs/guides/regular_dict_upgrade_guide.rst | 10 ++++++---- docs/guides/transfer_learning.ipynb | 14 +++++--------- docs/guides/transfer_learning.md | 12 ++++-------- flax/core/meta.py | 2 +- flax/linen/stochastic.py | 17 +++++++++++++---- 9 files changed, 34 insertions(+), 35 deletions(-) diff --git a/docs/api_reference/flax.cursor.rst b/docs/api_reference/flax.cursor.rst index 56ace5bc14..073f06ee1c 100644 --- a/docs/api_reference/flax.cursor.rst +++ b/docs/api_reference/flax.cursor.rst @@ -12,7 +12,7 @@ To illustrate, consider the example below:: import dataclasses from typing import Any - @dataclasses.dataclass + @dataclasses.dataclass(frozen=True) class A: x: Any diff --git a/docs/examples_community_examples.rst b/docs/examples_community_examples.rst index 0da9245c3d..079568c9a7 100644 --- a/docs/examples_community_examples.rst +++ b/docs/examples_community_examples.rst @@ -56,10 +56,6 @@ Examples - `@vasudevgupta7 `__ - Question-Answering - https://arxiv.org/abs/2007.14062 - * - `Bayesian Networks with BlackJAX `__ - - `@rlouf `__ - - Bayesian Inference, SGMCMC - - https://arxiv.org/abs/1402.4102 * - `DCGAN `__ - `@bkkaggle `__ - Image Synthesis diff --git a/docs/guides/flax_on_pjit.ipynb b/docs/guides/flax_on_pjit.ipynb index 6ec633772a..d7705f34f4 100644 --- a/docs/guides/flax_on_pjit.ipynb +++ b/docs/guides/flax_on_pjit.ipynb @@ -190,7 +190,7 @@ "\n", "1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters.\n", "\n", - "2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n", + "2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n", "\n", " * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless." ] @@ -1281,7 +1281,7 @@ "\n", "* **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", - "* **logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n", + "* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. 
Use this if you want to experiment around and find the most optimal partition layout for your model.\n", "\n", "* **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful." ] diff --git a/docs/guides/flax_on_pjit.md b/docs/guides/flax_on_pjit.md index 3b53875522..81e7b4d623 100644 --- a/docs/guides/flax_on_pjit.md +++ b/docs/guides/flax_on_pjit.md @@ -119,7 +119,7 @@ To shard the parameters efficiently, apply the following APIs to annotate the pa 1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters. -2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known. +2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known. * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless. @@ -551,7 +551,7 @@ Choosing when to use a device or logical axis depends on how much you want to co * **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming. -* **logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model. +* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model. * **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful. diff --git a/docs/guides/regular_dict_upgrade_guide.rst b/docs/guides/regular_dict_upgrade_guide.rst index 8659ad5a3e..dda795f0de 100644 --- a/docs/guides/regular_dict_upgrade_guide.rst +++ b/docs/guides/regular_dict_upgrade_guide.rst @@ -120,9 +120,11 @@ Alternatively, the environment variable ``flax_return_frozendict`` (found `here `__) can be directly modified in the Flax source code. -Migration plan +Migration status -------------- -Currently ``flax_return_frozendict`` is set to True, meaning Flax will default to returning ``FrozenDicts``. 
-In the future this flag will be flipped to False, and Flax will instead default to returning regular dicts. -Eventually this feature flag will be removed once the migration is complete. \ No newline at end of file +As of July 19th, 2023, ``flax_return_frozendict`` is set to ``False`` (see +`#3193 `__), meaning Flax will default to +returning regular dicts from version `0.7.1 `__ +onward. This flag can be flipped to ``True`` temporarily to have Flax return +``Frozendicts``. However this feature flag will eventually be removed in the future. \ No newline at end of file diff --git a/docs/guides/transfer_learning.ipynb b/docs/guides/transfer_learning.ipynb index d2bd152304..78b0b8ff7f 100644 --- a/docs/guides/transfer_learning.ipynb +++ b/docs/guides/transfer_learning.ipynb @@ -165,7 +165,7 @@ "metadata": {}, "source": [ "## Transfering the parameters\n", - "Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location. This can be done by unfreezing `params`, updating the `backbone` parameters, and freezing the `params` again:" + "Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location (i.e. the `backbone`):" ] }, { @@ -174,11 +174,7 @@ "metadata": {}, "outputs": [], "source": [ - "import flax\n", - "\n", - "params = flax.core.unfreeze(params)\n", - "params['backbone'] = vision_model_vars['params']\n", - "params = flax.core.freeze(params)" + "params['backbone'] = vision_model_vars['params']" ] }, { @@ -247,13 +243,13 @@ "import optax\n", "\n", "partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n", - "param_partitions = flax.core.freeze(traverse_util.path_aware_map(\n", - " lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params))\n", + "param_partitions = traverse_util.path_aware_map(\n", + " lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)\n", "tx = optax.multi_transform(partition_optimizers, param_partitions)\n", "\n", "# visualize a subset of the param_partitions structure\n", "flat = list(traverse_util.flatten_dict(param_partitions).items())\n", - "flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))" + "traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))" ] }, { diff --git a/docs/guides/transfer_learning.md b/docs/guides/transfer_learning.md index d467139267..8a563d498d 100644 --- a/docs/guides/transfer_learning.md +++ b/docs/guides/transfer_learning.md @@ -111,14 +111,10 @@ params = variables['params'] ``` ## Transfering the parameters -Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location. This can be done by unfreezing `params`, updating the `backbone` parameters, and freezing the `params` again: +Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location (i.e. the `backbone`): ```{code-cell} ipython3 -import flax - -params = flax.core.unfreeze(params) params['backbone'] = vision_model_vars['params'] -params = flax.core.freeze(params) ``` **Note:** if the model contains other variable collections such as `batch_stats`, these have to be transfered as well. 
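As an illustrative aside, a minimal sketch of transferring such an extra collection could mirror the parameter transfer above (this assumes both the combined model and the pretrained backbone actually track a `batch_stats` collection, which depends on the chosen backbone):

```python
# Hypothetical sketch: if both the combined model and the pretrained backbone
# track a `batch_stats` collection, transfer it just like the parameters above.
if 'batch_stats' in variables and 'batch_stats' in vision_model_vars:
  batch_stats = variables['batch_stats']
  batch_stats['backbone'] = vision_model_vars['batch_stats']
```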
@@ -153,13 +149,13 @@ from flax import traverse_util import optax partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()} -param_partitions = flax.core.freeze(traverse_util.path_aware_map( - lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)) +param_partitions = traverse_util.path_aware_map( + lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params) tx = optax.multi_transform(partition_optimizers, param_partitions) # visualize a subset of the param_partitions structure flat = list(traverse_util.flatten_dict(param_partitions).items()) -flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))) +traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])) ``` To implement [differential learning rates](https://blog.slavv.com/differential-learning-rates-59eff5209a4f), the `optax.set_to_zero` can be replaced with any other optimizer, different optimizers and partitioning schemes can be selected depending on the task. For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation. diff --git a/flax/core/meta.py b/flax/core/meta.py index b21f2a2ee6..ec030450c2 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -206,7 +206,7 @@ def __call__(self, x): mlp = MLP(4096) x = jnp.ones((8 * 1024, 1024)) # use eval_shape to get the Partitioned instances for the variables. - # this way we can determinte the PartitionSpecs for the init variables + # this way we can determine the PartitionSpecs for the init variables # before we call the init fn. var_spec = nn.get_partition_spec( jax.eval_shape(mlp.init, random.key(0), x)) diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py index 9845f3c3d0..cf90a7c5c5 100644 --- a/flax/linen/stochastic.py +++ b/flax/linen/stochastic.py @@ -32,10 +32,19 @@ class Dropout(Module): """Create a dropout layer. Note: When using :meth:`Module.apply() `, make sure - to include an RNG seed named `'dropout'`. For example:: - - model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': - dropout_rng})` + to include an RNG seed named `'dropout'`. Dropout isn't necessary for + variable initialization. Example:: + + class MLP(nn.Module): + @nn.compact + def __call__(self, x, train): + x = nn.Dense(4)(x) + x = nn.Dropout(0.5, deterministic=not train)(x) + return x + model = MLP() + x = jnp.ones((1, 3)) + variables = model.init(jax.random.key(0), x, train=False) # don't use dropout + model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout Attributes: rate: the dropout probability. (_not_ the keep rate!) From d0216d31c4d822330a754a8bd98d97064277abe3 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 9 Oct 2023 20:29:11 +0000 Subject: [PATCH 02/27] Add Flax F.A.Q. --- docs/faq.rst | 38 ++++++++++++++++++++++++++++++++++++++ docs/index.rst | 2 ++ 2 files changed, 40 insertions(+) create mode 100644 docs/faq.rst diff --git a/docs/faq.rst b/docs/faq.rst new file mode 100644 index 0000000000..57760315e4 --- /dev/null +++ b/docs/faq.rst @@ -0,0 +1,38 @@ +Frequently Asked Questions (FAQ) +================================ + +This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in `GitHub Discussions `__. 
+ +Where to search for an answer to a Flax-related question +******************************************************** + +There are a number of official Flax resources to search for information: + +- `Flax Documentation on ReadTheDocs `__ (this site): Use the `search bar `__ or the table of contents on the left-hand side. +- `google/flax GitHub Discussions `__: Search for an existing topic or start a new one. If you can't find what you're looking for, feel free to ask the Flax team or community a question. +- `google/flax GitHub Issues `__: Use the search bar to look for an existing issue or a feature request, or start a new one. + +How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`) +*********************************************************************************************** + +To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use :meth:`flax.linen.Module.perturb`. You define a zero-value :class:`flax.linen.Module` "perturbation" parameter – :code:`perturb(...)` – in the forward pass with the same shape as the intermediate activation, define the loss function with :code:`'perturbations'` as an added standalone argument, perform a JAX derivative operation with :code:`jax.grad` on the perturbation argument. + +For full examples and detailed documentation, go to: + +- The :meth:`flax.linen.Module.perturb` API docs +- The `Extracting gradients of intermediate values `_ guide +- `Flax GitHub Discussions #1152 `__ + +Is Flax Linen :code:`remat_scan()` the same as :code:`scan(remat(...))`? +************************************************************************ + +Flax :code:`remat_scan()` (:meth:`flax.linen.remat_scan()`) and :code:`scan(remat(...))` (:meth:`flax.linen.scan` over :meth:`flax.linen.remat`) are not the same, and :code:`remat_scan()` is limited in cases it supports. Namely, :code:`remat_scan()` treats the inputs and outputs as carries (hidden states that are carried through the training loop). You are recommended to use :code:`scan(remat(...))`, as typically you would need the extra parameters, such as ``in_axes`` (for input array axes) or ``out_axes`` (output array axes), which :meth:`flax.linen.remat_scan` does not expose. + +What are the recommended training loop libraries? +************************************************* + +Consider using CLU (Common Loop Utils) `google/CommonLoopUtils `__. To get started, go to this `CLU Synopsis Colab `__. You can find answers to common questions about CLU with Flax on `google/flax GitHub Discussions `__. + +Check out the official `google/flax Examples `__ for examples of using the training loop with (CLU) metrics. For example, this is `Flax ImageNet's train.py `__. + +For computer vision research, consider `google-research/scenic `__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go the `README page on GitHub `__. 
\ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index cc4793ba6f..d1f0bc8cbb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -318,7 +318,9 @@ Notable examples in Flax include: guides/index examples glossary + faq developer_notes/index philosophy contributing api_reference/index + \ No newline at end of file From 7b1d8a61edf8db7a81db0dbdf0e3ab6c18c77028 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 9 Oct 2023 21:15:11 +0000 Subject: [PATCH 03/27] Update Flax F.A.Q. --- docs/faq.rst | 10 +++++----- docs/index.rst | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/docs/faq.rst b/docs/faq.rst index 57760315e4..f6b0d30b16 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -3,8 +3,8 @@ Frequently Asked Questions (FAQ) This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in `GitHub Discussions `__. -Where to search for an answer to a Flax-related question -******************************************************** +Where to search for an answer to a Flax-related question? +********************************************************* There are a number of official Flax resources to search for information: @@ -12,8 +12,8 @@ There are a number of official Flax resources to search for information: - `google/flax GitHub Discussions `__: Search for an existing topic or start a new one. If you can't find what you're looking for, feel free to ask the Flax team or community a question. - `google/flax GitHub Issues `__: Use the search bar to look for an existing issue or a feature request, or start a new one. -How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`) -*********************************************************************************************** +How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`)? +************************************************************************************************ To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use :meth:`flax.linen.Module.perturb`. You define a zero-value :class:`flax.linen.Module` "perturbation" parameter – :code:`perturb(...)` – in the forward pass with the same shape as the intermediate activation, define the loss function with :code:`'perturbations'` as an added standalone argument, perform a JAX derivative operation with :code:`jax.grad` on the perturbation argument. @@ -35,4 +35,4 @@ Consider using CLU (Common Loop Utils) `google/CommonLoopUtils `__ for examples of using the training loop with (CLU) metrics. For example, this is `Flax ImageNet's train.py `__. -For computer vision research, consider `google-research/scenic `__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go the `README page on GitHub `__. \ No newline at end of file +For computer vision research, consider `google-research/scenic `__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the `README page on GitHub `__. 
\ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index d1f0bc8cbb..fd79ddb607 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -322,5 +322,4 @@ Notable examples in Flax include: developer_notes/index philosophy contributing - api_reference/index - \ No newline at end of file + api_reference/index \ No newline at end of file From 35326d4889e3d0aeb8837696b65152cca8a948e6 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Mon, 9 Oct 2023 17:54:15 -0700 Subject: [PATCH 04/27] added rmsnorm to api docs --- docs/api_reference/flax.linen/layers.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/api_reference/flax.linen/layers.rst b/docs/api_reference/flax.linen/layers.rst index 87bfff0fc9..792fa90ac1 100644 --- a/docs/api_reference/flax.linen/layers.rst +++ b/docs/api_reference/flax.linen/layers.rst @@ -52,6 +52,10 @@ Normalization :module: flax.linen :class: GroupNorm +.. flax_module:: + :module: flax.linen + :class: RMSNorm + .. flax_module:: :module: flax.linen :class: SpectralNorm From b76f4878b56d990299b52bebd9a570b29599d8a7 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Wed, 11 Oct 2023 13:32:40 -0700 Subject: [PATCH 05/27] Copybara import of the project: -- f6a222c710f87efbc3c1132af5deb853fb207ad9 by Marcus Chiam : split inputs_kv arg in attention layer COPYBARA_INTEGRATE_REVIEW=https://github.com/google/flax/pull/3379 from chiamp:attention f6a222c710f87efbc3c1132af5deb853fb207ad9 PiperOrigin-RevId: 572671273 --- examples/wmt/models.py | 2 +- flax/linen/attention.py | 90 +++++++++++++++++++++++++-- pyproject.toml | 6 ++ tests/linen/linen_attention_test.py | 95 +++++++++++++++++++++++++---- 4 files changed, 173 insertions(+), 20 deletions(-) diff --git a/examples/wmt/models.py b/examples/wmt/models.py index 0f2fd2f962..6ed08ccd23 100644 --- a/examples/wmt/models.py +++ b/examples/wmt/models.py @@ -299,7 +299,7 @@ def __call__( broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, - )(y, encoded, encoder_decoder_mask) + )(y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=config.dropout_rate)( y, deterministic=config.deterministic diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 851434f15a..575620efcd 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,7 +15,8 @@ """Attention core modules for Flax.""" import functools -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union, overload +import warnings from flax.linen import initializers from flax.linen.dtypes import promote_dtype @@ -248,11 +249,37 @@ class MultiHeadDotProductAttention(Module): qkv_dot_general_cls: Any = None out_dot_general_cls: Any = None + @overload + def __call__( + self, + inputs_q: Array, + inputs_k: Optional[Array] = None, + inputs_v: Optional[Array] = None, + *, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): + ... + + @overload + def __call__( + self, + inputs_q: Array, + *, + inputs_kv: Array = None, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): + ... 
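+  # A sketch of the call patterns these overloads describe, assuming a
+  # hypothetical instance `attn = MultiHeadDotProductAttention(num_heads=8)`;
+  # the decorated `__call__` below is the actual implementation:
+  #   attn(q)                # self-attention: `inputs_k`/`inputs_v` default to `q`
+  #   attn(q, k)             # `inputs_v` defaults to `k`
+  #   attn(q, k, v, mask=m)  # fully explicit cross-attention
+  #   attn(q, inputs_kv=kv)  # legacy keyword, to be deprecated soon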
+ @compact def __call__( self, inputs_q: Array, - inputs_kv: Array, + inputs_k: Optional[Array] = None, + inputs_v: Optional[Array] = None, + *, + inputs_kv: Optional[Array] = None, mask: Optional[Array] = None, deterministic: Optional[bool] = None, ): @@ -261,9 +288,19 @@ def __call__( Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. + If both inputs_k and inputs_v are None, they will both copy the value of + inputs_q (self attention). + If only inputs_v is None, it will copy the value of inputs_k. + Args: inputs_q: input queries of shape `[batch_sizes..., length, features]`. - inputs_kv: key/values of shape `[batch_sizes..., length, features]`. + inputs_k: key of shape `[batch_sizes..., length, features]`. If None, + inputs_k will copy the value of inputs_q. + inputs_v: values of shape `[batch_sizes..., length, features]`. If None, + inputs_v will copy the value of inputs_k. + inputs_kv: key/values of shape `[batch_sizes..., length, features]`. If + None, inputs_kv will copy the value of inputs_q. This arg will be + deprecated soon. Use inputs_k and inputs_v instead. mask: attention mask of shape `[batch_sizes..., num_heads, query_length, key/value_length]`. Attention weights are masked out if their corresponding mask value is `False`. @@ -273,6 +310,42 @@ def __call__( Returns: output of shape `[batch_sizes..., length, features]`. """ + if inputs_kv is not None: + if inputs_k is not None or inputs_v is not None: + raise ValueError('If either `inputs_k` or `inputs_v` is not None, ' + '`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` ' + 'and `inputs_v` must be None. We recommend using `inputs_k` and ' + '`inputs_v` args, since `inputs_kv` will be deprecated soon. See ' + 'https://github.com/google/flax/discussions/3389 for more ' + 'information.') + inputs_k = inputs_v = inputs_kv + warnings.warn('The inputs_kv arg will be deprecated soon. ' + 'Use inputs_k and inputs_v instead. See ' + 'https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning) + else: + if inputs_k is None: + if inputs_v is not None: + raise ValueError('`inputs_k` cannot be None if `inputs_v` is not None. ' + 'To have both `inputs_k` and `inputs_v` be the same value, pass in the ' + 'value to `inputs_k` and leave `inputs_v` as None.') + inputs_k = inputs_q + if inputs_v is None: + inputs_v = inputs_k + elif inputs_v.shape[-1] == inputs_v.shape[-2]: + warnings.warn(f"You are passing an array of shape {inputs_v.shape} " + "to the `inputs_v` arg, when you may have intended " + "to pass it to the `mask` arg. As of Flax version " + "0.7.4, the function signature of " + "MultiHeadDotProductAttention's `__call__` method " + "has changed to `__call__(inputs_q, inputs_k=None, " + "inputs_v=None, *, inputs_kv=None, mask=None, " + "deterministic=None)`. Use the kwarg `mask` instead. 
" + "See https://github.com/google/flax/discussions/3389 " + "and read the docstring for more information.", + DeprecationWarning) + features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( @@ -298,8 +371,8 @@ def __call__( # dimensions are then [batch..., length, n_heads, n_features_per_head] query, key, value = ( dense(name='query')(inputs_q), - dense(name='key')(inputs_kv), - dense(name='value')(inputs_kv), + dense(name='key')(inputs_k), + dense(name='value')(inputs_v), ) if self.normalize_qk: @@ -429,8 +502,13 @@ def __call__( # type: ignore Returns: output of shape `[batch_sizes..., length, features]`. """ + warnings.warn('SelfAttention will be deprecated soon. Use ' + '`MultiHeadDotProductAttention.__call__(inputs_q)` instead. ' + 'See https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning) return super().__call__( - inputs_q, inputs_q, mask, deterministic=deterministic + inputs_q, mask=mask, deterministic=deterministic ) diff --git a/pyproject.toml b/pyproject.toml index fd3bb166e3..41d1ea4cde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,12 @@ filterwarnings = [ "ignore:.*module 'sre_constants' is deprecated.*:DeprecationWarning", # DeprecationWarning: jax.random.KeyArray is deprecated. "ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning", + # DeprecationWarning: SelfAttention will be deprecated soon. + "ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning", + # DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead. + "ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning", + # DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed + "ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning" # DeprecationWarning: ml_dtypes.float8_e4m3b11 is deprecated. 
"ignore:.*ml_dtypes.float8_e4m3b11 is deprecated.*:DeprecationWarning", ] diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 2556b292c2..aba47516e4 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized +from flax import errors from flax import linen as nn from flax import jax_utils from flax.core import pop @@ -67,7 +68,6 @@ def test_dtype_infer(self): def test_multihead_encoder_decoder_attention(self): rng = random.key(0) q = jnp.ones((4, 2, 3, 5)) - kv = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, @@ -75,7 +75,7 @@ def test_multihead_encoder_decoder_attention(self): bias_init=initializers.zeros, deterministic=False, ) - y, _ = sa_module.init_with_output(rng, q, kv) + y, _ = sa_module.init_with_output(rng, q) self.assertEqual(y.shape, q.shape) def test_multihead_self_attention_w_dropout(self): @@ -91,7 +91,7 @@ def test_multihead_self_attention_w_dropout(self): ) rng1, rng2 = random.split(rng) rngs = {'params': rng1, 'dropout': rng2} - y, _ = sa_module.init_with_output(rngs, x, x) + y, _ = sa_module.init_with_output(rngs, x) self.assertEqual(y.shape, x.shape) def test_multihead_self_attention_w_dropout_disabled(self): @@ -108,11 +108,11 @@ def test_multihead_self_attention_w_dropout_disabled(self): rng1, rng2, rng3, rng4 = random.split(rng, 4) rngs1 = {'params': rng1, 'dropout': rng2} rngs2 = {'params': rng3, 'dropout': rng4} - y1, vs = sa_module0.init_with_output(rngs1, x, x) - y2, _ = sa_module0.init_with_output(rngs2, x, x) + y1, vs = sa_module0.init_with_output(rngs1, x) + y2, _ = sa_module0.init_with_output(rngs2, x) np.testing.assert_allclose(y1, y2) - y3 = sa_module0.apply(vs, x, x, rngs=rngs1) - y4 = sa_module0.apply(vs, x, x, rngs=rngs2) + y3 = sa_module0.apply(vs, x, rngs=rngs1) + y4 = sa_module0.apply(vs, x, rngs=rngs2) np.testing.assert_allclose(y3, y4) sa_module1 = nn.MultiHeadDotProductAttention( num_heads=8, @@ -121,8 +121,8 @@ def test_multihead_self_attention_w_dropout_disabled(self): bias_init=initializers.zeros, dropout_rate=0.0, ) - y5 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs1) - y6 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs2) + y5 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs1) + y6 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs2) np.testing.assert_allclose(y5, y6) sa_module2 = nn.MultiHeadDotProductAttention( num_heads=8, @@ -131,8 +131,8 @@ def test_multihead_self_attention_w_dropout_disabled(self): bias_init=initializers.zeros, dropout_rate=0.5, ) - y7 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs1) - y8 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs2) + y7 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs1) + y8 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs2) np.testing.assert_allclose(y7, y8) def test_causal_mask_1d(self): @@ -204,11 +204,11 @@ def test_autoregresive_receptive_field_1d(self): deterministic=False, ) - initial_vars = module.init(rng1, inputs, inputs) + initial_vars = module.init(rng1, inputs) causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) def model_loss(inputs, pos): - out = module.apply(initial_vars, inputs, inputs, causal_mask) + out = module.apply(initial_vars, inputs, mask=causal_mask) assert out.shape == input_shape assert len(out.shape) == 3 return out[0, pos, :].sum() @@ -234,6 +234,75 @@ def 
get_receptive_field_1d(pos): 'autoregressive self-attention.' ) + def test_multihead_self_attention_equality(self): + rng = random.key(0) + q = jnp.ones((4, 2, 3, 5)) + module_kwargs = {'num_heads': 8, + 'qkv_features': 16, + 'kernel_init': initializers.ones, + 'bias_init': initializers.zeros, + 'deterministic': False} + sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs) + sa_module1 = nn.SelfAttention(**module_kwargs) + y0, v0 = sa_module0.init_with_output(rng, q) + with self.assertWarnsRegex(DeprecationWarning, 'SelfAttention will be deprecated soon.'): + y1, v1 = sa_module1.init_with_output(rng, q) + self.assertTrue((y0 == y1).all()) + self.assertTrue(jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1))) + + def test_multihead_kv_args(self): + key1, key2 = random.split(random.key(0), 2) + query = random.uniform(key1, (3, 5)) + key_value = random.uniform(key1, (9, 5)) + module = nn.MultiHeadDotProductAttention( + num_heads=8, + qkv_features=16, + kernel_init=initializers.ones, + bias_init=initializers.zeros, + deterministic=False, + ) + y0, v0 = module.init_with_output(key2, query, inputs_k=key_value, inputs_v=key_value) + y1, v1 = module.init_with_output(key2, query, inputs_k=key_value) + with self.assertWarnsRegex(DeprecationWarning, 'The inputs_kv arg will be deprecated soon.'): + y2, v2 = module.init_with_output(key2, query, inputs_kv=key_value) + self.assertTrue((y0 == y1).all() and (y1 == y2).all()) + self.assertTrue( + jax.tree_util.tree_all( + jax.tree_map(lambda x, y, z: (x == y).all() and (y == z).all(), + v0, v1, v2))) + + with self.assertRaisesRegex(ValueError, '`inputs_k` cannot be None if `inputs_v` is not None.'): + y3, v3 = module.init_with_output(key2, query, inputs_v=key_value) + with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'): + y3, v3 = module.init_with_output(key2, query, inputs_kv=key_value, inputs_v=key_value) + with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'): + y3, v3 = module.init_with_output(key2, query, key_value, key_value, inputs_kv=key_value) + + def test_multihead_mask_warning(self): + rng = random.key(0) + rng1, rng2 = random.split(rng, num=2) + + length = 10 + dim = 1 + num_heads = 1 + input_shape = (1, length, dim) + query = key = random.normal(rng2, input_shape) + + module = nn.MultiHeadDotProductAttention( + num_heads=num_heads, + kernel_init=jax.nn.initializers.ones, + deterministic=False, + ) + + initial_vars = module.init(rng1, query, key) + causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) + + module.apply(initial_vars, query, key, mask=causal_mask) + with self.assertWarnsRegex(DeprecationWarning, + "the function signature of MultiHeadDotProductAttention's `__call__` method has changed"): + with self.assertRaises(errors.ScopeParamShapeError): + module.apply(initial_vars, query, key, causal_mask) + if __name__ == '__main__': absltest.main() From 1a0961e4ecddd7762773be68dd17ce5cd458645c Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Wed, 11 Oct 2023 14:37:27 -0700 Subject: [PATCH 06/27] fix HEAD --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 41d1ea4cde..2b60ce6375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,7 +152,7 @@ filterwarnings = [ # DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead. 
"ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning", # DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed - "ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning" + "ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning", # DeprecationWarning: ml_dtypes.float8_e4m3b11 is deprecated. "ignore:.*ml_dtypes.float8_e4m3b11 is deprecated.*:DeprecationWarning", ] From 59054a295d3893c553442b99df335d4ac9f47eee Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Fri, 6 Oct 2023 15:03:20 -0700 Subject: [PATCH 07/27] added weightnorm layer --- docs/api_reference/flax.linen/layers.rst | 5 + flax/linen/__init__.py | 1 + flax/linen/normalization.py | 166 ++++++++++++++++++++++- tests/linen/linen_test.py | 157 +++++++++++++++++++++ 4 files changed, 328 insertions(+), 1 deletion(-) diff --git a/docs/api_reference/flax.linen/layers.rst b/docs/api_reference/flax.linen/layers.rst index 792fa90ac1..d4f28b21c7 100644 --- a/docs/api_reference/flax.linen/layers.rst +++ b/docs/api_reference/flax.linen/layers.rst @@ -60,6 +60,10 @@ Normalization :module: flax.linen :class: SpectralNorm +.. flax_module:: + :module: flax.linen + :class: WeightNorm + Combinators ------------------------ @@ -136,6 +140,7 @@ Recurrent GroupNorm RMSNorm SpectralNorm + WeightNorm Sequential Dropout SelfAttention diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 6df28ca67b..8d45daed02 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -108,6 +108,7 @@ LayerNorm as LayerNorm, RMSNorm as RMSNorm, SpectralNorm as SpectralNorm, + WeightNorm as WeightNorm ) from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool) from .recurrent import ( diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 899351962f..8155ea8c1a 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -14,8 +14,8 @@ """Normalization modules for Flax.""" -import dataclasses import functools +from dataclasses import field from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union from flax.linen.dtypes import canonicalize_dtype @@ -849,3 +849,167 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): dtype = canonicalize_dtype(vs, u0, v0, sigma, dtype=self.dtype) return jnp.asarray(value_bar, dtype) + +class WeightNorm(Module): + """L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf). + + Weight normalization normalizes the weight params so that the l2-norm of + the matrix is equal to 1. This is implemented as a layer wrapper where + each wrapped layer will have its params l2-normalized before computing + its ``__call__`` output. 
+ + Example:: + + class Baz(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(2)(x) + + class Bar(nn.Module): + @nn.compact + def __call__(self, x): + x = Baz()(x) + x = nn.Dense(3)(x) + x = Baz()(x) + x = nn.Dense(3)(x) + return x + + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(3)(x) + # l2-normalize all params of the second Dense layer + x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x) + x = nn.Dense(5)(x) + # l2-normalize all kernels in the Bar submodule and all params in the Baz submodule + x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x) + return x + + # init + x = jnp.ones((1, 2)) + model = Foo() + variables = model.init(jax.random.key(0), x) + + variables + # { + # params: { + # ... + # WeightNorm_0: { + # Dense_1/bias/scale: Array([1., 1., 1., 1.], dtype=float32), + # Dense_1/kernel/scale: Array([1., 1., 1., 1.], dtype=float32), + # }, + # ... + # WeightNorm_1: { + # Bar_0/Baz_0/Dense_0/bias/scale: Array([1., 1.], dtype=float32), + # Bar_0/Baz_0/Dense_0/kernel/scale: Array([1., 1.], dtype=float32), + # Bar_0/Baz_1/Dense_0/bias/scale: Array([1., 1.], dtype=float32), + # Bar_0/Baz_1/Dense_0/kernel/scale: Array([1., 1.], dtype=float32), + # Bar_0/Dense_0/kernel/scale: Array([1., 1., 1.], dtype=float32), + # Bar_0/Dense_1/kernel/scale: Array([1., 1., 1.], dtype=float32), + # }, + # ... + # } + # } + + Attributes: + layer_instance: Module instance that is wrapped with WeightNorm + epsilon: A small float added to l2-normalization to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_scale: If True, creates a learnable variable ``scale`` that is + multiplied to the ``layer_instance`` variables after l2-normalization. + scale_init: Initialization function for the scaling function. + feature_axes: The feature axes dimension(s). The l2-norm is calculated by + reducing the ``layer_instance`` variables over the remaining (non-feature) + axes. Therefore a separate l2-norm value is calculated and a separate + scale (if ``use_scale=True``) is learned for each specified feature. By + default, the trailing dimension is treated as the feature axis. + variable_filter: An optional iterable that contains string items. The + WeightNorm layer will selectively apply l2-normalization to the + ``layer_instance`` variables whose key path (delimited by '/') has a + match with ``variable_filter``. For example, ``variable_filter={'kernel'}`` + will only apply l2-normalization to variables whose key path contains + 'kernel'. By default, ``variable_filter={'kernel'}``. + """ + + layer_instance: Module + epsilon: float = 1e-12 + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + use_scale: bool = True + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + feature_axes: Optional[Axes] = -1 + variable_filter: Optional[Iterable] = field(default_factory=lambda: {'kernel'}) + + @compact + def __call__(self, *args, **kwargs): + """Compute the l2-norm of the weights in ``self.layer_instance`` + and normalize the weights using this value before computing the + ``__call__`` output. + + Args: + *args: positional arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. + **kwargs: keyword arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. 
+ + Returns: + Output of the layer using l2-normalized weights. + """ + + def layer_forward(layer_instance): + return layer_instance(*args, **kwargs) + + return map_variables( + layer_forward, + trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( + self._l2_normalize, + vs, + ), + init=self.is_initializing(), + )(self.layer_instance) + + def _l2_normalize(self, path, vs): + """Compute the l2-norm and normalize the variables ``vs`` using this + value. This is intended to be a helper function used in this Module's + ``__call__`` method in conjunction with ``nn.transforms.map_variables`` + and ``jax.tree_util.tree_map_with_path``. + + Args: + path: dict key path, used for naming the ``scale`` variable + vs: variables to be l2-normalized + """ + value = jnp.asarray(vs) + str_path = self.layer_instance.name + '/' + '/'.join((dict_key.key for dict_key in path[1:])) + if self.variable_filter: + for variable_name in self.variable_filter: + if variable_name in str_path: + break + else: + return value + + if self.feature_axes is None: + feature_axes = () + reduction_axes = tuple(i for i in range(value.ndim)) + else: + feature_axes = _canonicalize_axes(value.ndim, self.feature_axes) + reduction_axes = tuple(i for i in range(value.ndim) if i not in feature_axes) + + feature_shape = [1] * value.ndim + reduced_feature_shape = [] + for ax in feature_axes: + feature_shape[ax] = value.shape[ax] + reduced_feature_shape.append(value.shape[ax]) + + value_bar = _l2_normalize(value, axis=reduction_axes, eps=self.epsilon) + + args = [vs] + if self.use_scale: + scale = self.param( + str_path + '/scale', self.scale_init, reduced_feature_shape, self.param_dtype + ).reshape(feature_shape) + value_bar *= scale + args.append(scale) + + dtype = canonicalize_dtype(*args, dtype=self.dtype) + return jnp.asarray(value_bar, dtype) \ No newline at end of file diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 4d8ea322ad..5a79e29ee7 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -455,6 +455,163 @@ def __call__(self, x, train): else: variables = model_cls.init(random.PRNGKey(0), x, train=False) + @parameterized.parameters( + {'feature_axes': -1, 'reduction_axes': 0, 'variable_filter': {'kernel'}}, + {'feature_axes': 0, 'reduction_axes': 1, 'variable_filter': {'kernel'}}, + {'feature_axes': (0, 1), 'reduction_axes': (), 'variable_filter': {'kernel'}}, + {'feature_axes': (), 'reduction_axes': (0, 1), 'variable_filter': {'kernel'}}, + {'feature_axes': None, 'reduction_axes': (0, 1), 'variable_filter': {'kernel'}}, + {'feature_axes': 0, 'reduction_axes': (), 'variable_filter': {'bias'}}, + {'feature_axes': (), 'reduction_axes': -1, 'variable_filter': {'bias'}} + ) + def test_manual_weight_norm(self, feature_axes, reduction_axes, variable_filter): + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + return nn.WeightNorm(nn.Dense(2, bias_init=nn.initializers.normal()), + feature_axes=feature_axes, + variable_filter=variable_filter)(x) + key1, key2 = jax.random.split(jax.random.key(1)) + x = jax.random.normal(key1, (1, 3)) + module = Foo() + v = module.init(key2, x) + v = jax.tree_map(lambda x: x + 0.5, v) + out = module.apply(v, x) + + kernel = v['params']['Dense_0']['kernel'] + if 'kernel' in variable_filter: + kernel /= jnp.sqrt(jnp.sum(kernel**2, axis=reduction_axes, keepdims=True)) + kernel_scale = jnp.expand_dims(v['params']['WeightNorm_0']['Dense_0/kernel/scale'], axis=reduction_axes) + else: + kernel_scale = 1 + bias = v['params']['Dense_0']['bias'] + if 'bias' 
in variable_filter: + bias /= jnp.sqrt(jnp.sum(bias**2, axis=reduction_axes, keepdims=True)) + bias_scale = jnp.expand_dims(v['params']['WeightNorm_0']['Dense_0/bias/scale'], axis=reduction_axes) + else: + bias_scale = 1 + manual_out = jnp.dot(x, kernel_scale * kernel) + (bias_scale * bias).reshape(1, -1) + + self.assertTrue(jnp.allclose(out, manual_out)) + + @parameterized.parameters( + {'variable_filters': ({}, None, {'kernel', 'bias'}, {'Bar'}), + 'key_paths': {'Bar_0/Baz_0/Dense_0/kernel/scale', + 'Bar_0/Baz_0/Dense_0/bias/scale', + 'Bar_0/Dense_0/kernel/scale', + 'Bar_0/Dense_0/bias/scale', + 'Bar_0/Baz_1/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/bias/scale', + 'Bar_0/Dense_1/kernel/scale', + 'Bar_0/Dense_1/bias/scale'}}, + {'variable_filters': ({'kernel'},), + 'key_paths': {'Bar_0/Baz_0/Dense_0/kernel/scale', + 'Bar_0/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/kernel/scale', + 'Bar_0/Dense_1/kernel/scale'}}, + {'variable_filters': ({'Baz', 'kernel'},), + 'key_paths': {'Bar_0/Baz_0/Dense_0/kernel/scale', + 'Bar_0/Baz_0/Dense_0/bias/scale', + 'Bar_0/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/bias/scale', + 'Bar_0/Dense_1/kernel/scale'}} + ) + def test_weight_norm_variable_filter(self, variable_filters, key_paths): + class Baz(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(2)(x) + class Bar(nn.Module): + @nn.compact + def __call__(self, x): + x = Baz()(x) + x = nn.Dense(3)(x) + x = Baz()(x) + x = nn.Dense(3)(x) + return x + + for variable_filter in variable_filters: + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + return nn.WeightNorm(Bar(), variable_filter=variable_filter)(x) + v = Foo().init(jax.random.key(0), jnp.ones((1, 4))) + self.assertEqual(key_paths, v['params']['WeightNorm_0'].keys()) + + @parameterized.parameters( + {'model_index': 0, 'key_paths': {'Dense_1/kernel/scale'}}, + {'model_index': 1, 'key_paths': {'Conv_0/kernel/scale'}}, + {'model_index': 2, 'key_paths': {'MultiHeadDotProductAttention_0/key/kernel/scale', + 'MultiHeadDotProductAttention_0/out/kernel/scale', + 'MultiHeadDotProductAttention_0/query/kernel/scale', + 'MultiHeadDotProductAttention_0/value/kernel/scale'}} + ) + def test_weight_norm_train( + self, model_index, key_paths + ): + class FooDense(nn.Module): + @nn.compact + def __call__(self, x,): + x = nn.Dense(8)(x) + x = nn.WeightNorm(nn.Dense(6))(x) + x = nn.Dense(4)(x) + return x + class FooConv(nn.Module): + @nn.compact + def __call__(self, x,): + x = nn.Dense(9)(x) + x = x.reshape((1, 3, 3)) + x = nn.WeightNorm(nn.Conv(2, kernel_size=(2, 2)))(x) + x = x.reshape(1, -1) + x = nn.Dense(4)(x) + return x + class FooAttention(nn.Module): + @nn.compact + def __call__(self, x): + a = nn.Dense(4)(x) + b = nn.Dense(4)(x) + x = nn.WeightNorm(nn.attention.MultiHeadDotProductAttention(4))(a, b) + x = nn.Dense(4)(x) + return x + + key1, key2, key3 = random.split(random.PRNGKey(0), 3) + x = random.normal(key1, (1, 4)) + y = random.normal(key2, (1, 4)) + + model_cls = (FooDense, FooConv, FooAttention)[model_index] + params = model_cls().init(key3, x)['params'] + self.assertEqual(key_paths, params['WeightNorm_0'].keys()) + + state = train_state.TrainState.create( + apply_fn=model_cls().apply, + params=params, + tx=optax.adam(1e-3), + ) + + @jax.jit + def train_step(state, batch): + def loss_fn(params): + logits = state.apply_fn( + {'params': params}, + x=batch['image'], + ) + loss = jnp.mean( + optax.l2_loss(predictions=logits, targets=batch['label']) + ) + return loss + + grad_fn = 
jax.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + return state, loss + + prev_loss = float('inf') + for _ in range(10): + state, loss = train_step(state, {'image': x, 'label': y}) + self.assertTrue(loss < prev_loss) + prev_loss = loss + class StochasticTest(absltest.TestCase): From 2d39eba32a82d5f4f0d29fb75a109765bc1eaacf Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Fri, 6 Oct 2023 15:05:20 -0700 Subject: [PATCH 08/27] fix spectralnorm layer --- flax/linen/normalization.py | 9 +- tests/linen/linen_test.py | 221 ++++++++++++++++++------------------ 2 files changed, 111 insertions(+), 119 deletions(-) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 899351962f..8634799d57 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -761,7 +761,6 @@ def layer_forward(layer_instance): trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( functools.partial( self._spectral_normalize, - layer_instance_name=self.layer_instance.name, update_stats=update_stats, ), vs, @@ -770,7 +769,7 @@ def layer_forward(layer_instance): mutable=True, )(self.layer_instance) - def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): + def _spectral_normalize(self, path, vs, update_stats): """Compute the largest singular value using power iteration and normalize the variables ``vs`` using this value. This is intended to be a helper function used in this Module's ``__call__`` method in conjunction with @@ -779,8 +778,6 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): Args: path: dict key path, used for naming the ``u`` and ``sigma`` variables vs: variables to be spectral normalized - layer_instance_name: name of the underlying ``self.layer_instance``, - used for naming the ``u`` and ``sigma`` variables update_stats: if True, update the ``u`` vector and ``sigma`` variables after computing their updated values using power iteration. 
This will help the power iteration method approximate the true singular value @@ -802,7 +799,7 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): value = jnp.reshape(value, (-1, value.shape[-1])) u_var_name = ( - layer_instance_name + self.layer_instance.name + '/' + '/'.join((dict_key.key for dict_key in path[1:])) + '/u' @@ -819,7 +816,7 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): ) u0 = u_var.value sigma_var_name = ( - layer_instance_name + self.layer_instance.name + '/' + '/'.join((dict_key.key for dict_key in path[1:])) + '/sigma' diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 4d8ea322ad..7e899df4ff 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -296,20 +296,35 @@ def __call__(self, x): (y1, y2), variables = model.init_with_output(key, x) np.testing.assert_allclose(y1, y2, rtol=1e-4) + @parameterized.parameters( + {'model_index': 0, 'key_paths': {'Dense_1/kernel/u', 'Dense_1/kernel/sigma'}}, + {'model_index': 1, 'key_paths': {'Conv_0/kernel/u', 'Conv_0/kernel/sigma'}}, + {'model_index': 2, 'key_paths': {'MultiHeadDotProductAttention_0/key/bias/u', + 'MultiHeadDotProductAttention_0/key/kernel/u', + 'MultiHeadDotProductAttention_0/out/kernel/u', + 'MultiHeadDotProductAttention_0/query/bias/u', + 'MultiHeadDotProductAttention_0/query/kernel/u', + 'MultiHeadDotProductAttention_0/value/bias/u', + 'MultiHeadDotProductAttention_0/value/kernel/u', + 'MultiHeadDotProductAttention_0/key/bias/sigma', + 'MultiHeadDotProductAttention_0/key/kernel/sigma', + 'MultiHeadDotProductAttention_0/out/kernel/sigma', + 'MultiHeadDotProductAttention_0/query/bias/sigma', + 'MultiHeadDotProductAttention_0/query/kernel/sigma', + 'MultiHeadDotProductAttention_0/value/bias/sigma', + 'MultiHeadDotProductAttention_0/value/kernel/sigma'}} + ) def test_spectral_norm_train( - self, + self, model_index, key_paths ): class FooDense(nn.Module): - @nn.compact def __call__(self, x, train): x = nn.Dense(8)(x) x = nn.SpectralNorm(nn.Dense(6))(x, update_stats=train) x = nn.Dense(4)(x) return x - class FooConv(nn.Module): - @nn.compact def __call__(self, x, train): x = nn.Dense(9)(x) @@ -320,9 +335,7 @@ def __call__(self, x, train): x = x.reshape(1, -1) x = nn.Dense(4)(x) return x - class FooAttention(nn.Module): - @nn.compact def __call__(self, x, train): a = nn.Dense(4)(x) @@ -337,123 +350,105 @@ def __call__(self, x, train): x = random.normal(key1, (1, 4)) y = random.normal(key2, (1, 4)) - for model_cls, var_paths in ( - (FooDense, ('Dense_1/kernel/',)), - (FooConv, ('Conv_0/kernel/',)), - ( - FooAttention, - ( - 'MultiHeadDotProductAttention_0/key/bias/', - 'MultiHeadDotProductAttention_0/key/kernel/', - 'MultiHeadDotProductAttention_0/out/kernel/', - 'MultiHeadDotProductAttention_0/query/bias/', - 'MultiHeadDotProductAttention_0/query/kernel/', - 'MultiHeadDotProductAttention_0/value/bias/', - 'MultiHeadDotProductAttention_0/value/kernel/', - ), - ), - ): - variables = model_cls().init(key3, x, train=False) - params, batch_stats = variables['params'], variables['batch_stats'] - for var_path in var_paths: - self.assertTrue(var_path + 'u' in batch_stats['SpectralNorm_0'].keys()) - self.assertTrue( - var_path + 'sigma' in batch_stats['SpectralNorm_0'].keys() + model_cls = (FooDense, FooConv, FooAttention)[model_index] + variables = model_cls().init(key3, x, train=False) + params, batch_stats = variables['params'], variables['batch_stats'] + self.assertEqual(key_paths, batch_stats['SpectralNorm_0'].keys()) + + 
class TrainState(train_state.TrainState): + batch_stats: Any + + state = TrainState.create( + apply_fn=model_cls().apply, + params=params, + batch_stats=batch_stats, + tx=optax.adam(1e-3), + ) + + @jax.jit + def train_step(state, batch): + def loss_fn(params): + logits, updates = state.apply_fn( + {'params': params, 'batch_stats': state.batch_stats}, + x=batch['image'], + train=True, + mutable=['batch_stats'], + ) + loss = jnp.mean( + optax.l2_loss(predictions=logits, targets=batch['label']) ) + return loss, updates - class TrainState(train_state.TrainState): - batch_stats: Any + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, updates), grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + state = state.replace(batch_stats=updates['batch_stats']) + return state, loss - state = TrainState.create( - apply_fn=model_cls().apply, - params=params, - batch_stats=batch_stats, - tx=optax.adam(1e-3), - ) + prev_loss = float('inf') + for _ in range(10): + state, loss = train_step(state, {'image': x, 'label': y}) + self.assertTrue(loss < prev_loss) + prev_loss = loss - @jax.jit - def train_step(state, batch): - def loss_fn(params): - logits, updates = state.apply_fn( - {'params': params, 'batch_stats': state.batch_stats}, - x=batch['image'], - train=True, - mutable=['batch_stats'], - ) - loss = jnp.mean( - optax.l2_loss(predictions=logits, targets=batch['label']) - ) - return loss, updates - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss, updates), grads = grad_fn(state.params) - state = state.apply_gradients(grads=grads) - state = state.replace(batch_stats=updates['batch_stats']) - return state, loss - - prev_loss = float('inf') - for _ in range(10): - state, loss = train_step(state, {'image': x, 'label': y}) - self.assertTrue(loss < prev_loss) - prev_loss = loss - - def test_spectral_norm_sigma(self): - for n_steps, update_stats, result in ( - (1, True, 4.0), - (3, True, 4.0), - (10, True, 4.0), - (1, False, 1.0), - ): - - class Foo(nn.Module): - - @nn.compact - def __call__(self, x, train): - x = nn.SpectralNorm(nn.Dense(8, use_bias=False), n_steps=n_steps)( - x, update_stats=train - ) - return x - - x = jnp.ones((1, 8)) - model_cls = Foo() - variables = model_cls.init(random.PRNGKey(0), x, train=False) - params, batch_stats = variables['params'], variables['batch_stats'] - params = jax.tree_map(lambda x: 4 * jnp.eye(*x.shape), params) - logits, updates = model_cls.apply( - {'params': params, 'batch_stats': batch_stats}, - x=x, - train=update_stats, - mutable=True, - ) - np.testing.assert_allclose( - updates['batch_stats']['SpectralNorm_0']['Dense_0/kernel/sigma'], - result, - atol=1e-3, - ) + @parameterized.parameters( + {'n_steps': 1, 'update_stats': True, 'result': 4.0}, + {'n_steps': 3, 'update_stats': True, 'result': 4.0}, + {'n_steps': 10, 'update_stats': True, 'result': 4.0}, + {'n_steps': 1, 'update_stats': False, 'result': 1.0} + ) + def test_spectral_norm_sigma(self, n_steps, update_stats, result): + class Foo(nn.Module): - def test_spectral_norm_3d_tensor(self): - for error_on_non_matrix in (True, False): + @nn.compact + def __call__(self, x, train): + x = nn.SpectralNorm(nn.Dense(8, use_bias=False), n_steps=n_steps)( + x, update_stats=train + ) + return x + + x = jnp.ones((1, 8)) + model_cls = Foo() + variables = model_cls.init(random.PRNGKey(0), x, train=False) + params, batch_stats = variables['params'], variables['batch_stats'] + params = jax.tree_map(lambda x: 4 * jnp.eye(*x.shape), params) + logits, updates = model_cls.apply( + 
{'params': params, 'batch_stats': batch_stats}, + x=x, + train=update_stats, + mutable=True, + ) + np.testing.assert_allclose( + updates['batch_stats']['SpectralNorm_0']['Dense_0/kernel/sigma'], + result, + atol=1e-3, + ) - class Foo(nn.Module): + @parameterized.parameters( + {'error_on_non_matrix': True}, + {'error_on_non_matrix': False} + ) + def test_spectral_norm_3d_tensor(self, error_on_non_matrix): + class Foo(nn.Module): - @nn.compact - def __call__(self, x, train): - x = nn.SpectralNorm( - nn.DenseGeneral((3, 4), use_bias=False), - error_on_non_matrix=error_on_non_matrix, - )(x, update_stats=train) - return x + @nn.compact + def __call__(self, x, train): + x = nn.SpectralNorm( + nn.DenseGeneral((3, 4), use_bias=False), + error_on_non_matrix=error_on_non_matrix, + )(x, update_stats=train) + return x - x = jnp.ones((1, 2)) - model_cls = Foo() + x = jnp.ones((1, 2)) + model_cls = Foo() - if error_on_non_matrix: - with self.assertRaisesRegex( - ValueError, 'Input is 3D but error_on_non_matrix is True' - ): - variables = model_cls.init(random.PRNGKey(0), x, train=False) - else: + if error_on_non_matrix: + with self.assertRaisesRegex( + ValueError, 'Input is 3D but error_on_non_matrix is True' + ): variables = model_cls.init(random.PRNGKey(0), x, train=False) + else: + variables = model_cls.init(random.PRNGKey(0), x, train=False) class StochasticTest(absltest.TestCase): From 5892e2bc724c5ceebacc75e5afba427b354f9b9b Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Thu, 12 Oct 2023 15:08:02 -0700 Subject: [PATCH 09/27] added attention refactor to changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43c0fa18d9..5032ab925c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,9 @@ vNext - - - -- +- Re-factored `MultiHeadDotProductAttention`'s call method signatur, by adding +`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic` +to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389). - - - From 1b452ebc1b657072b2e9a4019443662a19f7344f Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Sun, 1 Oct 2023 23:39:04 -0700 Subject: [PATCH 10/27] added dropout arg to multiheaddotproductattention --- flax/linen/attention.py | 13 ++++++++---- tests/linen/linen_attention_test.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 575620efcd..98919e51c8 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -34,7 +34,7 @@ import jax.numpy as jnp -PRNGKey = Any +PRNGKey = jax.Array Shape = Tuple[int, ...] Dtype = Any Array = Any @@ -258,6 +258,7 @@ def __call__( *, mask: Optional[Array] = None, deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None, ): ... @@ -269,6 +270,7 @@ def __call__( inputs_kv: Array = None, mask: Optional[Array] = None, deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None, ): ... @@ -282,6 +284,7 @@ def __call__( inputs_kv: Optional[Array] = None, mask: Optional[Array] = None, deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None ): """Applies multi-head dot product attention on the input data. @@ -306,6 +309,8 @@ def __call__( corresponding mask value is `False`. deterministic: if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. 
+ dropout_rng: optional rng key to pass to the attention layer's dropout + mask. Otherwise, self.make_rng('dropout') is used instead. Returns: output of shape `[batch_sizes..., length, features]`. @@ -434,14 +439,13 @@ def __call__( ), ) - dropout_rng = None if ( self.dropout_rate > 0.0 ): # Require `deterministic` only if using dropout. m_deterministic = merge_param( 'deterministic', self.deterministic, deterministic ) - if not m_deterministic: + if not m_deterministic and dropout_rng is None: dropout_rng = self.make_rng('dropout') else: m_deterministic = True @@ -485,6 +489,7 @@ def __call__( # type: ignore inputs_q: Array, mask: Optional[Array] = None, deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None ): """Applies multi-head dot product self-attention on the input data. @@ -508,7 +513,7 @@ def __call__( # type: ignore 'for more information.', DeprecationWarning) return super().__call__( - inputs_q, mask=mask, deterministic=deterministic + inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng ) diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index aba47516e4..9e26394040 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -94,6 +94,37 @@ def test_multihead_self_attention_w_dropout(self): y, _ = sa_module.init_with_output(rngs, x) self.assertEqual(y.shape, x.shape) + def test_multihead_self_attention_explicit_dropout(self): + class Foo(nn.Module): + attention_kwargs: dict + @nn.compact + def __call__(self, x, dropout_rng=None): + a = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, x, dropout_rng=dropout_rng) + b = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, x, dropout_rng=dropout_rng) + return a, b + + module = Foo( + dict( + num_heads=8, + qkv_features=16, + kernel_init=initializers.ones, + bias_init=initializers.zeros, + dropout_rate=0.5, + deterministic=False, + ) + ) + rng1, rng2, rng3 = random.split(random.key(0), 3) + x = jnp.ones((4, 2, 3, 5)) + rngs = {'params': rng1, 'dropout': rng2} + v = module.init(rngs, x) + a, b = module.apply(v, x, rngs=rngs) + self.assertTrue(not (a == b).all()) + a, b = module.apply(v, x, rngs=rngs, dropout_rng=rng3) + self.assertTrue((a == b).all()) + a, b = module.apply(v, x, dropout_rng=rng3) + self.assertTrue((a == b).all()) + self.assertTrue(a.shape == b.shape == x.shape) + def test_multihead_self_attention_w_dropout_disabled(self): rng = random.key(0) x = jnp.ones((4, 2, 3, 5)) From 3db72f59638bd24561dd9092498dd4ecb88dcbc5 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Thu, 12 Oct 2023 18:10:54 -0700 Subject: [PATCH 11/27] Remove import checks that are no longer valid. 
PiperOrigin-RevId: 573064287 --- flax/training/checkpoints.py | 61 ++++++++++++++---------------------- 1 file changed, 24 insertions(+), 37 deletions(-) diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 792671f3ec..2bc2260a01 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -40,24 +40,13 @@ from jax import monitoring from jax import process_index from jax import tree_util as jtu +from jax.experimental.array_serialization.serialization import get_tensorstore_spec +from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager from jax.experimental.multihost_utils import sync_global_devices -import numpy as np import orbax.checkpoint as ocp _READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec' _WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec' -_IMPORT_GDAM_SUCCESSFUL = False -try: - from jax.experimental.array_serialization.serialization import get_tensorstore_spec - from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager - - _IMPORT_GDAM_SUCCESSFUL = True -except ImportError: - logging.warning( - 'GlobalAsyncCheckpointManager is not imported correctly. ' - 'Checkpointing of GlobalDeviceArrays will not be available.' - 'To use the feature, install tensorstore.' - ) # Single-group reg-exps for int or float numerical substrings. @@ -262,7 +251,7 @@ def _restore_mpas( target: Optional[Any], ckpt_path: str, step: Optional[Union[int, float]], - gda_manager: Optional[Any], + gda_manager: Optional[GlobalAsyncCheckpointManager], allow_partial: bool = False, ): """Restore the multiprocess arrays given the target structure and type.""" @@ -740,7 +729,7 @@ def save_checkpoint_multiprocess( overwrite: bool = False, keep_every_n_steps: Optional[int] = None, async_manager: Optional[AsyncManager] = None, - gda_manager: Optional[Any] = None, + gda_manager: Optional[GlobalAsyncCheckpointManager] = None, orbax_checkpointer: Optional[ocp.Checkpointer] = None, ) -> str: """Save a checkpoint of the model in multi-process environment. @@ -768,15 +757,15 @@ def save_checkpoint_multiprocess( async_manager: if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly. - gda_manager: required if target contains a JAX GlobalDeviceArray. Type - should be GlobalAsyncCheckpointManager (needs Tensorstore to be imported - correctly). Will save the GDAs to a separate subdirectory with postfix - "_gda" asynchronously. Same as async_manager, this will block subsequent - saves. + gda_manager: required if target contains a JAX GlobalDeviceArray. Will save + the GDAs to a separate subdirectory with postfix "_gda" asynchronously. + Same as async_manager, this will block subsequent saves. orbax_checkpointer: if defined, the save will be done by Orbax In the - future, all Flax checkpointing features will be migrated to Orbax, - and starting to use an `orbax_checkpointer` is recommended. Please - check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers. + future, all Flax checkpointing features will be migrated to Orbax, and + starting to use an `orbax_checkpointer` is recommended. 
Please check out + the checkpointing guide + (https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints) + for how to use Orbax checkpointers. Returns: Filename of saved checkpoint. @@ -850,7 +839,7 @@ def save_checkpoint_multiprocess( target = serialization.to_state_dict(target) target, mpa_targets = _split_mp_arrays(target) target = serialization.msgpack_serialize(target) - has_mpa = mpa_targets and _IMPORT_GDAM_SUCCESSFUL + has_mpa = bool(mpa_targets) if not overwrite: _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore @@ -989,7 +978,7 @@ def restore_checkpoint( step: Optional[Union[int, float]] = None, prefix: str = 'checkpoint_', parallel: bool = True, - gda_manager: Optional[Any] = None, + gda_manager: Optional[GlobalAsyncCheckpointManager] = None, allow_partial_mpa_restoration: bool = False, orbax_checkpointer: Optional[ocp.Checkpointer] = None, orbax_transforms: Optional[Dict] = None, @@ -1014,9 +1003,8 @@ def restore_checkpoint( prefix: str: name prefix of checkpoint files. parallel: bool: whether to load seekable checkpoints in parallel, for speed. gda_manager: required if checkpoint contains a multiprocess array - (GlobalDeviceArray or jax Array from pjit). Type should be - GlobalAsyncCheckpointManager (needs Tensorstore to be imported correctly). - Will read the arrays from the separate subdirectory with postfix "_gda". + (GlobalDeviceArray or jax Array from pjit). Will read the arrays from the + separate subdirectory with postfix "_gda". allow_partial_mpa_restoration: If true, the given `target` doesn't have to contain all valid multiprocess arrays. As a result, the restored Pytree may have some MPAs not restored correctly. Use this if you cannot provide @@ -1126,15 +1114,14 @@ def read_chunk(i): checkpoint_contents = fp.read() state_dict = serialization.msgpack_restore(checkpoint_contents) - if _IMPORT_GDAM_SUCCESSFUL: - state_dict = _restore_mpas( - state_dict, - target, - ckpt_path, - step, - gda_manager, - allow_partial_mpa_restoration, - ) + state_dict = _restore_mpas( + state_dict, + target, + ckpt_path, + step, + gda_manager, + allow_partial_mpa_restoration, + ) if target is None: restored_checkpoint = state_dict From 60798cfeee2a3669c2e395b3edb894f749c861a5 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Thu, 12 Oct 2023 18:10:54 -0700 Subject: [PATCH 12/27] Remove import checks that are no longer valid. 
PiperOrigin-RevId: 573064287 --- .github/analytics/get_repo_metrics.py | 1 - .pre-commit-config.yaml | 4 + docs/_ext/codediff.py | 1 - docs/conf.py | 1 - .../imagenet/imagenet_fake_data_benchmark.py | 1 - examples/imagenet/train.py | 1 - .../linen_design_test/attention_simple.py | 6 +- examples/linen_design_test/autoencoder.py | 5 +- examples/linen_design_test/dense.py | 3 +- .../linen_design_test/linear_regression.py | 3 +- examples/linen_design_test/mlp_explicit.py | 6 +- examples/linen_design_test/mlp_inline.py | 6 +- examples/linen_design_test/mlp_lazy.py | 3 +- .../linen_design_test/tied_autoencoder.py | 45 ----------- examples/linen_design_test/weight_std.py | 62 --------------- examples/seq2seq/models.py | 4 +- examples/sst2/models_test.py | 1 - examples/wmt/bleu.py | 1 - flax/core/frozen_dict.py | 2 +- flax/core/meta.py | 2 +- flax/core/nn/attention.py | 3 - flax/core/scope.py | 1 - flax/cursor.py | 2 +- flax/linen/dtypes.py | 1 - flax/linen/fp8_ops.py | 1 - flax/linen/linear.py | 1 - flax/linen/recurrent.py | 2 +- flax/linen/spmd.py | 2 - flax/linen/stochastic.py | 2 +- flax/struct.py | 2 +- flax/training/checkpoints.py | 61 ++++++--------- flax/traverse_util.py | 2 +- pyproject.toml | 18 ++++- tests/core/core_lift_test.py | 3 +- tests/core/core_scope_test.py | 2 - tests/core/design/core_auto_encoder_test.py | 6 +- tests/core/design/core_big_resnets_test.py | 4 +- tests/core/design/core_flow_test.py | 2 +- tests/early_stopping_test.py | 3 - tests/linen/linen_attention_test.py | 1 - tests/linen/linen_dtypes_test.py | 3 - tests/linen/linen_linear_test.py | 2 - tests/linen/linen_meta_test.py | 1 - tests/linen/linen_module_test.py | 3 +- tests/linen/linen_recurrent_test.py | 3 - tests/linen/linen_test.py | 1 - tests/linen/linen_transforms_test.py | 4 +- tests/linen/toplevel_test.py | 75 ------------------- tests/struct_test.py | 1 - 49 files changed, 72 insertions(+), 298 deletions(-) delete mode 100644 examples/linen_design_test/tied_autoencoder.py delete mode 100644 examples/linen_design_test/weight_std.py delete mode 100644 tests/linen/toplevel_test.py diff --git a/.github/analytics/get_repo_metrics.py b/.github/analytics/get_repo_metrics.py index 600270a0df..d936c1c7bb 100644 --- a/.github/analytics/get_repo_metrics.py +++ b/.github/analytics/get_repo_metrics.py @@ -15,7 +15,6 @@ import json import os from datetime import datetime -from pathlib import Path from typing import Callable, List from absl import app, flags diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 803e5ffae2..287b939df8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,3 +34,7 @@ repos: --extra-keys, "metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", ] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.292 + hooks: + - id: ruff diff --git a/docs/_ext/codediff.py b/docs/_ext/codediff.py index eabc634fd8..8a8e6f3fb9 100644 --- a/docs/_ext/codediff.py +++ b/docs/_ext/codediff.py @@ -26,7 +26,6 @@ In order to highlight a line of code, append "#!" to it. 
""" -import itertools from typing import List, Tuple from docutils import nodes diff --git a/docs/conf.py b/docs/conf.py index 653e365796..ffaf16ea17 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -37,7 +37,6 @@ sys.path.append(os.path.abspath('./_ext')) # patch sphinx -import docs.conf_sphinx_patch # -- Project information ----------------------------------------------------- project = 'Flax' diff --git a/examples/imagenet/imagenet_fake_data_benchmark.py b/examples/imagenet/imagenet_fake_data_benchmark.py index 377a533acb..9e532adc90 100644 --- a/examples/imagenet/imagenet_fake_data_benchmark.py +++ b/examples/imagenet/imagenet_fake_data_benchmark.py @@ -22,7 +22,6 @@ import time from absl.testing import absltest -from absl.testing.flagsaver import flagsaver from flax.testing import Benchmark import jax import tensorflow_datasets as tfds diff --git a/examples/imagenet/train.py b/examples/imagenet/train.py index b83b97de68..a32b2283e6 100644 --- a/examples/imagenet/train.py +++ b/examples/imagenet/train.py @@ -25,7 +25,6 @@ from absl import logging from clu import metric_writers from clu import periodic_actions -import flax from flax import jax_utils from flax.training import checkpoints from flax.training import common_utils diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py index e77d43c1b3..6490d66329 100644 --- a/examples/linen_design_test/attention_simple.py +++ b/examples/linen_design_test/attention_simple.py @@ -14,15 +14,13 @@ import functools from pprint import pprint -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union -from flax.core import Scope -from flax.core.frozen_dict import freeze, unfreeze +from typing import Any, Callable, Optional, Sequence +from flax.core.frozen_dict import unfreeze from flax.linen import initializers from flax.linen import Module, compact, vmap from flax.linen.linear import PrecisionLike import jax from jax import lax, numpy as jnp, random -import numpy as np class Dense(Module): diff --git a/examples/linen_design_test/autoencoder.py b/examples/linen_design_test/autoencoder.py index 7c6a1fc9c6..cec1c1dea9 100644 --- a/examples/linen_design_test/autoencoder.py +++ b/examples/linen_design_test/autoencoder.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable, Tuple import jax -from jax import numpy as jnp, random, lax -import numpy as np +from jax import numpy as jnp, random from flax import linen as nn from flax.linen import Module, Dense, compact diff --git a/examples/linen_design_test/dense.py b/examples/linen_design_test/dense.py index 78088ccbbc..7a109a95fe 100644 --- a/examples/linen_design_test/dense.py +++ b/examples/linen_design_test/dense.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax -from jax import numpy as jnp, random, lax +from jax import lax from flax.linen import initializers from typing import Callable from flax.linen import Module, compact diff --git a/examples/linen_design_test/linear_regression.py b/examples/linen_design_test/linear_regression.py index 8bda1e1112..bd6a6812d8 100644 --- a/examples/linen_design_test/linear_regression.py +++ b/examples/linen_design_test/linear_regression.py @@ -13,8 +13,7 @@ # limitations under the License. 
import jax -from jax import numpy as jnp, random, lax, jit -from flax import linen as nn +from jax import numpy as jnp, jit from dense import Dense diff --git a/examples/linen_design_test/mlp_explicit.py b/examples/linen_design_test/mlp_explicit.py index 9953c4df4a..a2665018cb 100644 --- a/examples/linen_design_test/mlp_explicit.py +++ b/examples/linen_design_test/mlp_explicit.py @@ -13,14 +13,12 @@ # limitations under the License. from pprint import pprint -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Optional from flax.deprecated import nn -from flax.deprecated.nn import initializers from dense import Dense from flax.linen import Module import jax -from jax import lax, numpy as jnp, random -import numpy as np +from jax import numpy as jnp # Add `in_features` to the built-in Dense layer that normally works diff --git a/examples/linen_design_test/mlp_inline.py b/examples/linen_design_test/mlp_inline.py index b631d19d83..3759695adb 100644 --- a/examples/linen_design_test/mlp_inline.py +++ b/examples/linen_design_test/mlp_inline.py @@ -13,12 +13,10 @@ # limitations under the License. import jax -from jax import numpy as jnp, random, lax +from jax import numpy as jnp from flax import linen as nn -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable from flax.linen import Module, compact -import numpy as np -from pprint import pprint from dense import Dense diff --git a/examples/linen_design_test/mlp_lazy.py b/examples/linen_design_test/mlp_lazy.py index 7e246917bf..cff15d564b 100644 --- a/examples/linen_design_test/mlp_lazy.py +++ b/examples/linen_design_test/mlp_lazy.py @@ -13,10 +13,9 @@ # limitations under the License. import jax -from jax import numpy as jnp, random, lax +from jax import numpy as jnp from flax import linen as nn from flax.linen import Module -import numpy as np from pprint import pprint from dense import Dense diff --git a/examples/linen_design_test/tied_autoencoder.py b/examples/linen_design_test/tied_autoencoder.py deleted file mode 100644 index a7dadeef19..0000000000 --- a/examples/linen_design_test/tied_autoencoder.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -from jax import numpy as jnp, random, lax -from flax import linen as nn -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union -from flax.linen import Module, compact -import numpy as np -from dense import Dense - - -# TODO(avital, levskaya): resurrect this example once interactive api is restored. 
- -# class TiedAutoEncoder(Module): -# def setup(self): -# self.encoder = Dense(features=4, use_bias=False) - -# @property -# def decoder(self): -# return self.encoder.detached().attached(variables={ -# 'params': {"kernel": self.encoder.variables['params']['kernel'].T}}) - -# def __call__(self, x): -# z = self.encoder(x) -# x = self.decoder(z) -# return x - -# tae = TiedAutoEncoder(parent=None) -# tae = tae.initialized( -# {'params': random.key(42)}, -# jnp.ones((1, 16))) -# print("reconstruct", jnp.shape(tae(jnp.ones((1, 16))))) -# print("var shapes", jax.tree_util.tree_map(jnp.shape, tae.variables)) diff --git a/examples/linen_design_test/weight_std.py b/examples/linen_design_test/weight_std.py deleted file mode 100644 index c384a00b91..0000000000 --- a/examples/linen_design_test/weight_std.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -import jax -from jax import numpy as jnp, random, lax, jit -from flax import linen as nn -from flax.core.scope import Scope -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union -from flax.linen import Module, compact -import numpy as np -from dense import Dense -from flax.core.frozen_dict import freeze, unfreeze, FrozenDict - - -def standardize(x, axis, eps=1e-8): - x = x - jnp.mean(x, axis=axis, keepdims=True) - x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps) - return x - - -# TODO(avital, levskaya): resurrect this example once interactive api is restored. - -# A wrapper that calls through a simple module with standardized parameters. -# -# Note that StdWeight is /not/ a module, hence it doesn't add another layer -# of depth in the variable dict (i.e. this is a "transparent module") -# @dataclass -# class StdWeight: -# module: Module - -# def __call__(self, x): -# # TODO: Think about how this modifies other state -# if not 'params' in self.module.variables: -# # initialize parameters -# self.module(x) - -# param = self.module.variables['params'] -# # Make a copy because `param` is (and should be) frozen. We're only transforming -# # the parameters, not mutating them. -# std_param = param.copy(kernel=standardize(param['kernel'], axis=[0, 1])) -# return self.module.clone(parent=None).apply({'params': std_param}, x) - -# class MyModule(Module): -# def __call__(self, x): -# module = Dense(self, 3) -# std_module = StdWeight(module) -# return std_module(x) - -# m_variables = MyModule().init({'params': jax.random.key(10)}, jnp.ones((1, 4))) -# print(m_variables) diff --git a/examples/seq2seq/models.py b/examples/seq2seq/models.py index 5870a62e90..6a9d43c350 100644 --- a/examples/seq2seq/models.py +++ b/examples/seq2seq/models.py @@ -17,13 +17,11 @@ # See issue #620. 
# pytype: disable=wrong-keyword-args -import functools -from typing import Any, Tuple +from typing import Tuple from flax import linen as nn import jax import jax.numpy as jnp -import numpy as np Array = jax.Array PRNGKey = jax.Array diff --git a/examples/sst2/models_test.py b/examples/sst2/models_test.py index c1a42a0c02..bea495d1ec 100644 --- a/examples/sst2/models_test.py +++ b/examples/sst2/models_test.py @@ -17,7 +17,6 @@ from absl.testing import parameterized import models import jax -from jax import numpy as jnp import jax.test_util import numpy as np diff --git a/examples/wmt/bleu.py b/examples/wmt/bleu.py index e12911dd6e..ac69cc956d 100644 --- a/examples/wmt/bleu.py +++ b/examples/wmt/bleu.py @@ -44,7 +44,6 @@ import unicodedata import numpy as np -import six class UnicodeRegex: diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 8c5c5646f1..4e7f25fb23 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -15,7 +15,7 @@ """Frozen Dictionary.""" import collections -from typing import Any, Dict, Hashable, Optional, Mapping, Tuple, TypeVar, Union +from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union from types import MappingProxyType from flax import serialization diff --git a/flax/core/meta.py b/flax/core/meta.py index ec030450c2..b589f90bad 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -23,7 +23,7 @@ import abc import functools -from typing import Any, Callable, Dict, Generic, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, Union from flax import errors from flax import struct diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py index bea380a2ea..212682e4f0 100644 --- a/flax/core/nn/attention.py +++ b/flax/core/nn/attention.py @@ -17,10 +17,7 @@ from collections.abc import Iterable # pylint: disable=g-importing-member import functools from typing import Any, Callable, Union -import warnings -from . import stochastic -from flax import jax_utils from flax import struct from flax.core import Scope from flax.linen import initializers diff --git a/flax/core/scope.py b/flax/core/scope.py index 6059dcb9ad..f15f980f5d 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -45,7 +45,6 @@ from flax import traceback_util from flax.ids import uuid import jax -from jax import config as jax_config from jax import numpy as jnp from jax import random from jax import tree_util diff --git a/flax/cursor.py b/flax/cursor.py index 8f7be74c85..d919782501 100644 --- a/flax/cursor.py +++ b/flax/cursor.py @@ -13,7 +13,7 @@ # limitations under the License. import enum -from typing import Any, Callable, Dict, Generator, Generic, Mapping, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable +from typing import Any, Callable, Dict, Generator, Generic, Mapping, Optional, Protocol, TypeVar, runtime_checkable from flax.core import FrozenDict from flax.errors import CursorFindError, TraverseTreeError import dataclasses diff --git a/flax/linen/dtypes.py b/flax/linen/dtypes.py index 463b0f23c4..bef29a2f02 100644 --- a/flax/linen/dtypes.py +++ b/flax/linen/dtypes.py @@ -30,7 +30,6 @@ from typing import Any, Optional, List from jax import numpy as jnp -import jax Dtype = Any diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 575c1e1104..11e687cd9a 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable from functools import partial from flax.linen import initializers diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 4a23bd3c1c..d3362fe82e 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -34,7 +34,6 @@ import jax from jax import eval_shape from jax import lax -from jax import random from jax.core import ShapedArray import jax.numpy as jnp import numpy as np diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 7b39db5998..94151cfa0d 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -20,7 +20,7 @@ from abc import ABCMeta from functools import partial # pylint: disable=g-importing-member -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, TypeVar, Union from absl import logging from flax.core import lift from flax.core.frozen_dict import FrozenDict diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 4d0b19d5b4..59095d03b7 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -38,8 +38,6 @@ from flax import struct from flax.core import meta -from flax.core.lift import In as ScanIn # pylint: disable=unused-import -from flax.core.lift import Out as ScanOut # pylint: disable=unused-import # Real types and dummy aliases for documentation LogicalRules = Sequence[Tuple[str, Union[str, Tuple[str], None]]] diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py index cf90a7c5c5..5ed47654aa 100644 --- a/flax/linen/stochastic.py +++ b/flax/linen/stochastic.py @@ -14,7 +14,7 @@ """Stochastic modules.""" -from typing import Optional, Sequence, Union +from typing import Optional, Sequence import jax diff --git a/flax/struct.py b/flax/struct.py index f9c299be78..f3c88a7274 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -15,7 +15,7 @@ """Utilities for defining custom classes that can be used with jax transformations.""" import dataclasses -from typing import TypeVar, Callable, Tuple, Union, Any +from typing import TypeVar from . import serialization diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 792671f3ec..2bc2260a01 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -40,24 +40,13 @@ from jax import monitoring from jax import process_index from jax import tree_util as jtu +from jax.experimental.array_serialization.serialization import get_tensorstore_spec +from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager from jax.experimental.multihost_utils import sync_global_devices -import numpy as np import orbax.checkpoint as ocp _READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec' _WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec' -_IMPORT_GDAM_SUCCESSFUL = False -try: - from jax.experimental.array_serialization.serialization import get_tensorstore_spec - from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager - - _IMPORT_GDAM_SUCCESSFUL = True -except ImportError: - logging.warning( - 'GlobalAsyncCheckpointManager is not imported correctly. ' - 'Checkpointing of GlobalDeviceArrays will not be available.' - 'To use the feature, install tensorstore.' - ) # Single-group reg-exps for int or float numerical substrings. 
@@ -262,7 +251,7 @@ def _restore_mpas( target: Optional[Any], ckpt_path: str, step: Optional[Union[int, float]], - gda_manager: Optional[Any], + gda_manager: Optional[GlobalAsyncCheckpointManager], allow_partial: bool = False, ): """Restore the multiprocess arrays given the target structure and type.""" @@ -740,7 +729,7 @@ def save_checkpoint_multiprocess( overwrite: bool = False, keep_every_n_steps: Optional[int] = None, async_manager: Optional[AsyncManager] = None, - gda_manager: Optional[Any] = None, + gda_manager: Optional[GlobalAsyncCheckpointManager] = None, orbax_checkpointer: Optional[ocp.Checkpointer] = None, ) -> str: """Save a checkpoint of the model in multi-process environment. @@ -768,15 +757,15 @@ def save_checkpoint_multiprocess( async_manager: if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly. - gda_manager: required if target contains a JAX GlobalDeviceArray. Type - should be GlobalAsyncCheckpointManager (needs Tensorstore to be imported - correctly). Will save the GDAs to a separate subdirectory with postfix - "_gda" asynchronously. Same as async_manager, this will block subsequent - saves. + gda_manager: required if target contains a JAX GlobalDeviceArray. Will save + the GDAs to a separate subdirectory with postfix "_gda" asynchronously. + Same as async_manager, this will block subsequent saves. orbax_checkpointer: if defined, the save will be done by Orbax In the - future, all Flax checkpointing features will be migrated to Orbax, - and starting to use an `orbax_checkpointer` is recommended. Please - check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers. + future, all Flax checkpointing features will be migrated to Orbax, and + starting to use an `orbax_checkpointer` is recommended. Please check out + the checkpointing guide + (https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints) + for how to use Orbax checkpointers. Returns: Filename of saved checkpoint. @@ -850,7 +839,7 @@ def save_checkpoint_multiprocess( target = serialization.to_state_dict(target) target, mpa_targets = _split_mp_arrays(target) target = serialization.msgpack_serialize(target) - has_mpa = mpa_targets and _IMPORT_GDAM_SUCCESSFUL + has_mpa = bool(mpa_targets) if not overwrite: _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore @@ -989,7 +978,7 @@ def restore_checkpoint( step: Optional[Union[int, float]] = None, prefix: str = 'checkpoint_', parallel: bool = True, - gda_manager: Optional[Any] = None, + gda_manager: Optional[GlobalAsyncCheckpointManager] = None, allow_partial_mpa_restoration: bool = False, orbax_checkpointer: Optional[ocp.Checkpointer] = None, orbax_transforms: Optional[Dict] = None, @@ -1014,9 +1003,8 @@ def restore_checkpoint( prefix: str: name prefix of checkpoint files. parallel: bool: whether to load seekable checkpoints in parallel, for speed. gda_manager: required if checkpoint contains a multiprocess array - (GlobalDeviceArray or jax Array from pjit). Type should be - GlobalAsyncCheckpointManager (needs Tensorstore to be imported correctly). - Will read the arrays from the separate subdirectory with postfix "_gda". + (GlobalDeviceArray or jax Array from pjit). Will read the arrays from the + separate subdirectory with postfix "_gda". 
allow_partial_mpa_restoration: If true, the given `target` doesn't have to contain all valid multiprocess arrays. As a result, the restored Pytree may have some MPAs not restored correctly. Use this if you cannot provide @@ -1126,15 +1114,14 @@ def read_chunk(i): checkpoint_contents = fp.read() state_dict = serialization.msgpack_restore(checkpoint_contents) - if _IMPORT_GDAM_SUCCESSFUL: - state_dict = _restore_mpas( - state_dict, - target, - ckpt_path, - step, - gda_manager, - allow_partial_mpa_restoration, - ) + state_dict = _restore_mpas( + state_dict, + target, + ckpt_path, + step, + gda_manager, + allow_partial_mpa_restoration, + ) if target is None: restored_checkpoint = state_dict diff --git a/flax/traverse_util.py b/flax/traverse_util.py index e6d409ac1c..4ef0768c44 100644 --- a/flax/traverse_util.py +++ b/flax/traverse_util.py @@ -43,7 +43,7 @@ import abc import copy import dataclasses -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Tuple import warnings import jax diff --git a/pyproject.toml b/pyproject.toml index 2b60ce6375..a0633008b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,4 +167,20 @@ exclude_lines = [ pyink-indentation = 2 pyink-use-majority-quotes = true line-length = 80 -preview = true \ No newline at end of file +preview = true + +[tool.ruff] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["F401"] +ignore = [] +# Allow fix for all enabled rules (when `--fix`) is provided. +# Full list of rules: https://docs.astral.sh/ruff/rules/ +fixable = ["F401"] +unfixable = [] +# Exclude a variety of commonly ignored directories. +exclude = [ + "__init__.py", + "activation.py", + "partitioning.py", + "variables.py" +] diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index 74cc245c7f..51ce5911af 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -14,8 +14,7 @@ import operator from flax import errors -from flax.core import Scope, init, apply, lift, nn, FrozenDict, unfreeze, copy -from flax.configurations import temp_flip_flag +from flax.core import init, apply, lift, nn, FrozenDict, copy import jax from jax import random diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py index 86634a1c8a..6f8190d7f6 100644 --- a/tests/core/core_scope_test.py +++ b/tests/core/core_scope_test.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest from flax import errors from flax.core import Scope, scope, freeze, lazy_init, init, apply, nn from flax.core.scope import LazyRng from flax.configurations import temp_flip_flag import jax -from jax import config as jax_config from jax import random from jax import numpy as jnp diff --git a/tests/core/design/core_auto_encoder_test.py b/tests/core/design/core_auto_encoder_test.py index 7eff74f4a5..8c462da5b0 100644 --- a/tests/core/design/core_auto_encoder_test.py +++ b/tests/core/design/core_auto_encoder_test.py @@ -20,12 +20,10 @@ import jax from jax import numpy as jnp, random -from flax import struct -from jax.scipy.linalg import expm -from dataclasses import dataclass, InitVar -from typing import Any, Callable, Sequence, NamedTuple, Any +from dataclasses import dataclass +from typing import Callable def mlp(scope: Scope, x: Array, hidden: int, out: int): diff --git a/tests/core/design/core_big_resnets_test.py b/tests/core/design/core_big_resnets_test.py index 01d38362bf..d47f766afc 100644 --- a/tests/core/design/core_big_resnets_test.py +++ b/tests/core/design/core_big_resnets_test.py @@ -18,10 +18,10 @@ import numpy as np -from flax.core import Scope, Array, init, apply, unfreeze, lift, nn +from flax.core import Scope, Array, init, unfreeze, lift, nn import jax -from jax import lax, random, numpy as jnp +from jax import random, numpy as jnp default_norm = partial(nn.batch_norm) diff --git a/tests/core/design/core_flow_test.py b/tests/core/design/core_flow_test.py index 872674745b..d72b76a52a 100644 --- a/tests/core/design/core_flow_test.py +++ b/tests/core/design/core_flow_test.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Callable, Sequence, NamedTuple, Any +from typing import Any, Sequence, Any from absl.testing import absltest diff --git a/tests/early_stopping_test.py b/tests/early_stopping_test.py index 92b74fb9ed..ea6ce47ddc 100644 --- a/tests/early_stopping_test.py +++ b/tests/early_stopping_test.py @@ -14,13 +14,10 @@ """Tests for flax.training.early_stopping.""" -import copy -import os from absl.testing import absltest from flax.training import early_stopping import jax -from jax import test_util as jtu # Parse absl flags test_srcdir and test_tmpdir. 
jax.config.parse_flags_with_absl() diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index aba47516e4..15275c98e4 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -21,7 +21,6 @@ from flax import linen as nn from flax import jax_utils from flax.core import pop -from flax.configurations import temp_flip_flag import jax from jax import lax diff --git a/tests/linen/linen_dtypes_test.py b/tests/linen/linen_dtypes_test.py index 7233486c5e..f878960b92 100644 --- a/tests/linen/linen_dtypes_test.py +++ b/tests/linen/linen_dtypes_test.py @@ -14,12 +14,9 @@ """Tests for flax.linen.dtypes.""" -import functools -from multiprocessing.sharedctypes import Value from absl.testing import absltest -from flax import linen as nn from flax.linen import dtypes import jax diff --git a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py index 715e7a7c8e..899cb5ea71 100644 --- a/tests/linen/linen_linear_test.py +++ b/tests/linen/linen_linear_test.py @@ -15,8 +15,6 @@ """Tests for flax.linen.linear.""" import functools -from multiprocessing.sharedctypes import Value -from typing import Callable, Optional from absl.testing import absltest from absl.testing import parameterized diff --git a/tests/linen/linen_meta_test.py b/tests/linen/linen_meta_test.py index eaebda2b64..9a27ea328d 100644 --- a/tests/linen/linen_meta_test.py +++ b/tests/linen/linen_meta_test.py @@ -20,7 +20,6 @@ from jax import numpy as jnp from jax import random from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit from jax.sharding import Mesh from jax.sharding import PartitionSpec diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index e894b7fa4a..4f409b9788 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -42,8 +42,7 @@ from flax import errors from flax import linen as nn from flax import struct -from flax.configurations import temp_flip_flag -from flax.core import FrozenDict, Scope, freeze, tracers +from flax.core import FrozenDict, Scope, freeze from flax.linen import compact import jax from jax import random diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py index 74c900128b..a8ec60c4b0 100644 --- a/tests/linen/linen_recurrent_test.py +++ b/tests/linen/linen_recurrent_test.py @@ -19,10 +19,7 @@ import jax import jax.numpy as jnp import numpy as np -from flax import errors from flax import linen as nn -import pytest -import einops from flax.linen.recurrent import flip_sequences # Parse absl flags test_srcdir and test_tmpdir. 
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 5a79e29ee7..3049d04f32 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -24,7 +24,6 @@ import jax from jax import random -from jax.nn import initializers import jax.numpy as jnp import optax import numpy as np diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 6ef18b140b..a337444603 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -15,7 +15,7 @@ """Transforms tests.""" from functools import partial -from typing import Any, Tuple, Iterable, Callable, Sequence +from typing import Any, Callable, Sequence import operator import unittest @@ -24,11 +24,9 @@ from jax import random import jax.numpy as jnp import numpy as np -from flax import config from flax import errors from flax import linen as nn from flax.core import freeze, copy -from flax.configurations import temp_flip_flag # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() diff --git a/tests/linen/toplevel_test.py b/tests/linen/toplevel_test.py deleted file mode 100644 index caf6c29038..0000000000 --- a/tests/linen/toplevel_test.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -import jax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp - -import numpy as np -from typing import Any, Tuple - -from flax import linen as nn -from flax.core import Scope - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - - -class Dummy(nn.Module): - - @nn.compact - def __call__(self): - self.param('foo', lambda rng: 1) - - -class ModuleTopLevelTest(absltest.TestCase): - pass - # def test_toplevel_immutable(self): - # d = Dummy(parent=None) - # with self.assertRaisesRegex(BaseException, "orphaned module"): - # d() - - # def test_toplevel_initialized_requires_rng(self): - # with self.assertRaisesRegex(BaseException, "missing 1 required.*rngs"): - # d = Dummy(parent=None).initialized() - - # def test_toplevel_initialized_with_rng(self): - # d = Dummy(parent=None).initialized(rngs={'params': random.key(0)}) - # self.assertEqual(d.variables.param.foo, 1) - - # def test_toplevel_initialized_frozen(self): - # d = Dummy(parent=None).initialized(rngs={'params': random.key(0)}) - # with self.assertRaisesRegex(BaseException, "Can't set value"): - # d.variables.param.foo = 2 - - # def test_toplevel_initialized_has_new_scope(self): - # d = Dummy(parent=None) - # # initializing should make a copy and not have any effect - # # on `d` itself. - # d_initialized = d.initialized(rngs={'params': random.key(0)}) - # # ... make sure that indeed `d` has no scope. 
- # self.assertIsNone(d.scope) - - # def test_can_only_call_initialized_once(self): - # d = Dummy(parent=None) - # d = d.initialized(rngs={'params': random.key(0)}) - # with self.assertRaises(BaseException): - # d.initialized(rngs={'params': random.key(0)}) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/struct_test.py b/tests/struct_test.py index 122afd5d3d..32926aea99 100644 --- a/tests/struct_test.py +++ b/tests/struct_test.py @@ -15,7 +15,6 @@ """Tests for flax.struct.""" from typing import Any -import unittest from absl.testing import absltest From fac0cef308c81399ee7fe624d2c3de13b53fb71e Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Fri, 13 Oct 2023 15:01:41 -0700 Subject: [PATCH 13/27] Make the injectable `dot_general` optional and not pre-initialized so that method interceptor works. PiperOrigin-RevId: 573327937 --- flax/linen/attention.py | 4 ++-- .../experimental/layers_with_named_axes.py | 6 +++-- flax/linen/linear.py | 22 +++++++++++++------ tests/linen/linen_module_test.py | 4 ++-- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 575620efcd..efb34b510a 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -244,8 +244,8 @@ class MultiHeadDotProductAttention(Module): decode: bool = False normalize_qk: bool = False # Deprecated, will be removed. - qkv_dot_general: DotGeneralT = lax.dot_general - out_dot_general: DotGeneralT = lax.dot_general + qkv_dot_general: Optional[DotGeneralT] = None + out_dot_general: Optional[DotGeneralT] = None qkv_dot_general_cls: Any = None out_dot_general_cls: Any = None diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py index 25ee5e6276..24fde145ba 100644 --- a/flax/linen/experimental/layers_with_named_axes.py +++ b/flax/linen/experimental/layers_with_named_axes.py @@ -73,7 +73,7 @@ class Dense(nn.Module): ) kernel_axes: Tuple[str, ...] = () # Deprecated. Will be removed. - dot_general: DotGeneralT = lax.dot_general + dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @nn.compact @@ -98,8 +98,10 @@ def __call__(self, inputs: Array) -> Array: if self.dot_general_cls is not None: dot_general = self.dot_general_cls() - else: + elif self.dot_general is not None: dot_general = self.dot_general + else: + dot_general = lax.dot_general y = dot_general( inputs, kernel, diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 4a23bd3c1c..4fdf38a9d7 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -98,7 +98,7 @@ class DenseGeneral(Module): ) precision: PrecisionLike = None # Deprecated. Will be removed. - dot_general: DotGeneralT = lax.dot_general + dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @compact @@ -178,8 +178,10 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): if self.dot_general_cls is not None: dot_general = self.dot_general_cls() - else: + elif self.dot_general is not None: dot_general = self.dot_general + else: + dot_general = lax.dot_general out = dot_general( inputs, kernel, @@ -218,7 +220,7 @@ class Dense(Module): initializers.zeros_init() ) # Deprecated. Will be removed. 
- dot_general: DotGeneralT = lax.dot_general + dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @compact @@ -247,8 +249,10 @@ def __call__(self, inputs: Array) -> Array: if self.dot_general_cls is not None: dot_general = self.dot_general_cls() - else: + elif self.dot_general is not None: dot_general = self.dot_general + else: + dot_general = lax.dot_general y = dot_general( inputs, kernel, @@ -350,7 +354,7 @@ class _Conv(Module): initializers.zeros_init() ) # Deprecated. Will be removed. - conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated + conv_general_dilated: Optional[ConvGeneralDilatedT] = None conv_general_dilated_cls: Any = None @property @@ -466,8 +470,10 @@ def maybe_broadcast( # create the unshared convolution kernel. if self.conv_general_dilated_cls is not None: conv_general_dilated = self.conv_general_dilated_cls() - else: + elif self.conv_general_dilated is not None: conv_general_dilated = self.conv_general_dilated + else: + conv_general_dilated = lax.conv_general_dilated conv_output_shape = eval_shape( lambda lhs, rhs: conv_general_dilated( # pylint: disable=g-long-lambda lhs=lhs, @@ -517,8 +523,10 @@ def maybe_broadcast( if self.shared_weights: if self.conv_general_dilated_cls is not None: conv_general_dilated = self.conv_general_dilated_cls() - else: + elif self.conv_general_dilated is not None: conv_general_dilated = self.conv_general_dilated + else: + conv_general_dilated = lax.conv_general_dilated y = conv_general_dilated( inputs, kernel, diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index e894b7fa4a..b9ceba9262 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -821,7 +821,7 @@ def __call__(self, x): precision = None kernel_init = init bias_init = zeros - dot_general = dot_general + dot_general = None dot_general_cls = None ) Dense_1 = Dense( @@ -833,7 +833,7 @@ def __call__(self, x): precision = None kernel_init = init bias_init = zeros - dot_general = dot_general + dot_general = None dot_general_cls = None ) )""" From afa7e337131888768514c06a815e87b798e4ceef Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 16 Oct 2023 17:37:58 +0000 Subject: [PATCH 14/27] remove pdf target --- .readthedocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index e086a43677..fa87e6d31f 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -18,7 +18,7 @@ sphinx: formats: - htmlzip - epub - - pdf + # - pdf # Optionally set the version of Python and requirements required to build your docs python: From 766485a522dc422fdc812d5fc5d70e248f8dbd7a Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Mon, 16 Oct 2023 15:46:07 -0700 Subject: [PATCH 15/27] fixed GRU docstring --- flax/linen/recurrent.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 7b39db5998..30ba68c400 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -130,7 +130,7 @@ class LSTMCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): Attributes: features: number of output features. - gate_fn: activation function used for gates (default: sigmoid) + gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform @@ -406,8 +406,8 @@ class GRUCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): .. 
math:: \begin{array}{ll} - r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\ - z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\ + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array} @@ -415,7 +415,8 @@ class GRUCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): where x is the input and h, is the output of the previous time step. Attributes: - gate_fn: activation function used for gates (default: sigmoid) + features: number of output features. + gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform From 31eedf16d36cd96609d8497ae997c3ffcff63ea2 Mon Sep 17 00:00:00 2001 From: Flax Team Date: Tue, 17 Oct 2023 15:54:30 -0700 Subject: [PATCH 16/27] Add instrumentation for Flax checkpoint save and restore calls. PiperOrigin-RevId: 574289364 --- flax/training/checkpoints.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 2bc2260a01..1076fb2e9a 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -620,6 +620,7 @@ def save_checkpoint( Returns: Filename of saved checkpoint. """ + jax.monitoring.record_event('/jax/flax/checkpoint/save') start_time = time.time() # Make sure all saves are finished before the logic of checking and removing # outdated checkpoints happens. @@ -770,6 +771,7 @@ def save_checkpoint_multiprocess( Returns: Filename of saved checkpoint. """ + jax.monitoring.record_event('/jax/flax/checkpoint/save') start_time = time.time() # Make sure all saves are finished before the logic of checking and removing # outdated checkpoints happens. @@ -1022,6 +1024,7 @@ def restore_checkpoint( returned. This is to match the behavior of the case where a directory path is specified but the directory has not yet been created. """ + jax.monitoring.record_event('/jax/flax/checkpoint/restore') start_time = time.time() # Make sure any previous work is done before checking files. if orbax_checkpointer and isinstance( From a8bc3673fed5b6b86ac5ba412c2b8fa8cbb8b470 Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Tue, 17 Oct 2023 23:27:01 -0700 Subject: [PATCH 17/27] Replaces pjit with jit in spmd.py PiperOrigin-RevId: 574373680 --- flax/linen/spmd.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 59095d03b7..fea7c75c98 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for working with pjit and partitioned models. +"""Utilities for working with jit and partitioned models. This module introduces `axis_rules`, `logical_to_mesh_axes`, -`logical_to_mesh`, `with_logical_constraint` for appyling pjit -sharding constraints in terms of "logical named axes" rather than -pjit's default mesh axes. +`logical_to_mesh`, `with_logical_constraint` for appyling jit sharding +constraints in terms of "logical named axes" rather than jit's default mesh +axes. 
Additionally the `LogicallyPartitioned` metadata wrapper is defined as well as the initializer function wrapper `with_logical_partitioning` for @@ -206,7 +206,7 @@ def logical_to_mesh_sharding( def _global_mesh_defined() -> bool: - """Checks if global xmap/pjit mesh resource environment is defined.""" + """Checks if global xmap/jit mesh resource environment is defined.""" maps_env = maps.thread_resources.env return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison @@ -224,7 +224,7 @@ def _with_sharding_constraint( axis_resources: Optional[jax.sharding.PartitionSpec], mesh: Optional[jax.sharding.Mesh] = None, ): - """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside pjit.""" + """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit.""" if jax.devices()[0].platform == 'cpu' or ( not _global_mesh_defined() and mesh is None ): @@ -274,7 +274,7 @@ def with_logical_constraint( mesh: Optional[jax.sharding.Mesh] = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, ): - """Version of pjit's with_sharding_constraint that uses logical axis names.""" + """Version of jit's with_sharding_constraint that uses logical axis names.""" # If no axis binding is set, this is a no-op. if rules is None: rules = _axis_rules.rules From 47d1c51497fa8dc2fc80180c347366a1aad198a4 Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Thu, 19 Oct 2023 16:16:53 -0700 Subject: [PATCH 18/27] Ignore transient chex deprecationwarning. --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a0633008b1..5195879442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,8 @@ filterwarnings = [ "ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning", # DeprecationWarning: ml_dtypes.float8_e4m3b11 is deprecated. "ignore:.*ml_dtypes.float8_e4m3b11 is deprecated.*:DeprecationWarning", + # DeprecationWarning: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any]. (chex, recheck by Nov 2023) + "ignore:.*jax.core.Shape is deprecated.*:DeprecationWarning", ] [tool.coverage.report] From abab11fff54e229cb2691ebf71a7515abbd0a547 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 20 Oct 2023 05:10:58 -0700 Subject: [PATCH 19/27] flax: avoid deprecated jax.random.default_prng_impl() This function is being removed from JAX's public API (see https://github.com/google/jax/pull/18197), as part of the jax.random enhancements described in [JEP 9263](https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html) PiperOrigin-RevId: 575187159 --- flax/core/scope.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flax/core/scope.py b/flax/core/scope.py index f15f980f5d..60c523bfbf 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -954,9 +954,7 @@ def param( """ self.reserve(name, 'params') if self.has_variable('params', name): - abs_rng = jax.ShapeDtypeStruct( - random.default_prng_impl().key_shape, jnp.uint32 - ) + abs_rng = jax.eval_shape(lambda s: random.key_data(random.key(s)), 0) value = self.get_variable('params', name) # Validate that the shape of the init_fn output is the same as the shape # of the existing parameter. 
This is to make sure that the hparams set up @@ -1201,10 +1199,10 @@ def _is_valid_rng(rng: Array): return rng.shape == () # Handle old-style raw PRNG keys - if ( - rng.shape != random.default_prng_impl().key_shape - or rng.dtype != jnp.uint32 - ): + expected_rng = jax.eval_shape( + lambda s: jax.random.key_data(jax.random.key(s)), 0 + ) + if (rng.shape, rng.dtype) != (expected_rng.shape, expected_rng.dtype): return False return True From 39ad403e1a9a9dfd3961949d4631a1ee7f4bba28 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 23 Oct 2023 22:20:49 +0000 Subject: [PATCH 20/27] remove transformers dependency --- docs/conf.py | 1 + docs/requirements.txt | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index ffaf16ea17..23673e951b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,6 +133,7 @@ nb_execution_excludepatterns = [ 'getting_started.ipynb', # <-- times out 'optax_update_guide.ipynb', # <-- requires flax<=0.5.3 + 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs/requirements.txt b/docs/requirements.txt index cc9c9e1873..600735f0b0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -32,4 +32,3 @@ tensorflow_text>=2.11.0 # WMT example # notebooks einops -transformers[flax] From 578f7d442b8a0fd3d481b6f900c9c803a45423d0 Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Fri, 20 Oct 2023 16:53:20 -0700 Subject: [PATCH 21/27] Simplify abstract rng creation in param shape-check. --- flax/core/scope.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flax/core/scope.py b/flax/core/scope.py index 60c523bfbf..46819a76b7 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -954,14 +954,13 @@ def param( """ self.reserve(name, 'params') if self.has_variable('params', name): - abs_rng = jax.eval_shape(lambda s: random.key_data(random.key(s)), 0) value = self.get_variable('params', name) # Validate that the shape of the init_fn output is the same as the shape # of the existing parameter. This is to make sure that the hparams set up # in a Flax Module match the shapes coming in during apply, and if not, # catch it with an error message. # NOTE: We could consider moving this to `self.` - abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng) + abs_value = jax.eval_shape(lambda: init_fn(random.key(0), *init_args)) abs_value_flat = jax.tree_util.tree_leaves(abs_value) value_flat = jax.tree_util.tree_leaves(value) for val, abs_val in zip(value_flat, abs_value_flat): From 6c208fc63aa10e712b9616fcb59c505e75727973 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Mon, 23 Oct 2023 17:31:42 -0700 Subject: [PATCH 22/27] updated torch ConvTranspose --- docs/guides/convert_pytorch_to_flax.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/guides/convert_pytorch_to_flax.rst b/docs/guides/convert_pytorch_to_flax.rst index ff4e58acd3..29a6bbc9e4 100644 --- a/docs/guides/convert_pytorch_to_flax.rst +++ b/docs/guides/convert_pytorch_to_flax.rst @@ -266,6 +266,8 @@ while ``torch.nn.ConvTranspose2d`` computes a gradient based transposed convolut implementation of a gradient based transposed convolution is ``Jax``. However, there is a pending `pull request`_ that contains an implementation. +To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need + .. _`pull request`: https://github.com/google/jax/pull/5772 .. 
|nn.ConvTranspose| replace:: ``nn.ConvTranspose`` From e3633935540a5e7e321d1f10441826cb5cdafd04 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Mon, 23 Oct 2023 17:33:36 -0700 Subject: [PATCH 23/27] updated torch ConvTranspose --- docs/guides/convert_pytorch_to_flax.rst | 36 +++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/docs/guides/convert_pytorch_to_flax.rst b/docs/guides/convert_pytorch_to_flax.rst index 29a6bbc9e4..43283d448d 100644 --- a/docs/guides/convert_pytorch_to_flax.rst +++ b/docs/guides/convert_pytorch_to_flax.rst @@ -85,7 +85,7 @@ Convolutions and FC Layers We have to be careful, when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc). In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then reshaped to [N, C * H * W] before being fed to the fc layers. -When we port our weights from PyToch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. +When we port our weights from PyTorch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W]. Consider this PyTorch model: @@ -266,7 +266,39 @@ while ``torch.nn.ConvTranspose2d`` computes a gradient based transposed convolut implementation of a gradient based transposed convolution is ``Jax``. However, there is a pending `pull request`_ that contains an implementation. -To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need +To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``transpose_kernel`` arg in Flax's +``nn.ConvTranspose`` layer. + +.. testcode:: + + # padding is inverted + torch_padding = 0 + flax_padding = 1 - torch_padding + + t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=torch_padding) + + kernel = t_conv.weight.detach().cpu().numpy() + bias = t_conv.bias.detach().cpu().numpy() + + # [inC, outC, kH, kW] -> [kH, kW, outC, inC] + kernel = jnp.transpose(kernel, (2, 3, 1, 0)) + + key = random.key(0) + x = random.normal(key, (1, 6, 6, 3)) + + variables = {'params': {'kernel': kernel, 'bias': bias}} + # ConvTranspose expects the kernel to be [kH, kW, inC, outC], + # but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead + j_conv = nn.ConvTranspose(features=4, kernel_size=(2, 2), padding=flax_padding, transpose_kernel=True) + j_out = j_conv.apply(variables, x) + + # [N, H, W, C] -> [N, C, H, W] + t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) + t_out = t_conv(t_x) + # [N, C, H, W] -> [N, H, W, C] + t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) + np.testing.assert_almost_equal(j_out, t_out, decimal=6) + .. _`pull request`: https://github.com/google/jax/pull/5772 From 41273d18ffa1678ba228848eb098826623173d21 Mon Sep 17 00:00:00 2001 From: Flax Team Date: Mon, 23 Oct 2023 17:49:33 -0700 Subject: [PATCH 24/27] Supported computing mean and variance of BatchNorm using mask. 
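A minimal usage sketch of the new `mask` argument (illustrative only: the key
values and shapes below are made up, but the call pattern mirrors the test
added in this patch; the mask must be broadcastable to the input, and `True`
marks the positions that contribute to the batch statistics):

    import jax.numpy as jnp
    from jax import random
    import flax.linen as nn

    x = random.normal(random.key(0), (4, 3, 2))
    # Boolean mask broadcastable to x: exclude the last position of every
    # example from the mean/variance computation.
    mask = jnp.ones((4, 3, 1), dtype=bool).at[:, -1, :].set(False)

    bn = nn.BatchNorm(momentum=0.9, use_running_average=False)
    y, variables = bn.init_with_output(random.key(1), x, mask=mask)
    # The batch statistics (and the running averages stored in
    # variables['batch_stats']) are computed only over unmasked positions.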
PiperOrigin-RevId: 575979930 --- flax/linen/normalization.py | 95 ++++++++++------ tests/linen/linen_test.py | 221 +++++++++++++++++++++++------------- 2 files changed, 202 insertions(+), 114 deletions(-) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 99caf0038e..93338a32d7 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -14,13 +14,13 @@ """Normalization modules for Flax.""" +import dataclasses import functools -from dataclasses import field from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union -from flax.linen.dtypes import canonicalize_dtype -from flax.linen.module import Module, compact, merge_param # pylint: disable=g-multiple-import -from flax.linen.transforms import map_variables +from flax.linen import dtypes +from flax.linen import module +from flax.linen import transforms import jax from jax import lax from jax.nn import initializers @@ -32,6 +32,8 @@ Shape = Tuple[int, ...] Dtype = Any # this could be a real type? Axes = Union[int, Sequence[int]] +compact = module.compact +Module = module.Module def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: @@ -57,6 +59,7 @@ def _compute_stats( axis_index_groups: Any = None, use_mean: bool = True, use_fast_variance: bool = True, + mask: Optional[Array] = None, ): """Computes mean and variance statistics. @@ -82,6 +85,8 @@ def _compute_stats( variance without subtracting the mean. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. + mask: Binary array of shape broadcastable to `inputs` tensor, indicating + the positions for which the mean and variance should be computed. Returns: A pair ``(mean, var)``. @@ -94,8 +99,8 @@ def _compute_stats( x = jnp.asarray(x, dtype) axes = _canonicalize_axes(x.ndim, axes) - def maybe_distributed_mean(*xs): - mus = tuple(x.mean(axes) for x in xs) + def maybe_distributed_mean(*xs, mask=None): + mus = tuple(x.mean(axes, where=mask) for x in xs) if axis_name is None: return mus if len(xs) > 1 else mus[0] else: @@ -112,15 +117,17 @@ def maybe_distributed_mean(*xs): if use_mean: if use_fast_variance: - mu, mu2 = maybe_distributed_mean(x, _abs_sq(x)) + mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. var = jnp.maximum(0.0, mu2 - _abs_sq(mu)) else: - mu = maybe_distributed_mean(x) - var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes))) + mu = maybe_distributed_mean(x, mask=mask) + var = maybe_distributed_mean( + _abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask + ) else: - var = maybe_distributed_mean(_abs_sq(x)) + var = maybe_distributed_mean(_abs_sq(x), mask=mask) mu = jnp.zeros_like(var) return mu, var @@ -188,7 +195,7 @@ def _normalize( ).reshape(feature_shape) y += bias args.append(bias) - dtype = canonicalize_dtype(*args, dtype=dtype) + dtype = dtypes.canonicalize_dtype(*args, dtype=dtype) return jnp.asarray(y, dtype) @@ -281,7 +288,7 @@ class BatchNorm(Module): use_fast_variance: bool = True @compact - def __call__(self, x, use_running_average: Optional[bool] = None): + def __call__(self, x, use_running_average: Optional[bool] = None, mask=None): """Normalizes the input using batch statistics. NOTE: @@ -295,12 +302,14 @@ def __call__(self, x, use_running_average: Optional[bool] = None): x: the input to be normalized. 
use_running_average: if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input. + mask: Binary array of shape broadcastable to `inputs` tensor, indicating + the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ - use_running_average = merge_param( + use_running_average = module.merge_param( 'use_running_average', self.use_running_average, use_running_average ) feature_axes = _canonicalize_axes(x.ndim, self.axis) @@ -327,6 +336,7 @@ def __call__(self, x, use_running_average: Optional[bool] = None): axis_name=self.axis_name if not self.is_initializing() else None, axis_index_groups=self.axis_index_groups, use_fast_variance=self.use_fast_variance, + mask=mask, ) if not self.is_initializing(): @@ -644,8 +654,9 @@ def __call__(self, x): class SpectralNorm(Module): - """Spectral normalization. See: + """Spectral normalization. + See: - https://arxiv.org/abs/1802.05957 - https://arxiv.org/abs/1805.08318 - https://arxiv.org/abs/1809.11096 @@ -742,10 +753,10 @@ def __call__(self, *args, update_stats: bool, **kwargs): Args: *args: positional arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. - update_stats: if True, update the internal ``u`` vector and ``sigma`` value - after computing their updated values using power iteration. This will help - the power iteration method approximate the true singular value more - accurately over time. + update_stats: if True, update the internal ``u`` vector and ``sigma`` + value after computing their updated values using power iteration. This + will help the power iteration method approximate the true singular value + more accurately over time. **kwargs: keyword arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. @@ -756,7 +767,7 @@ def __call__(self, *args, update_stats: bool, **kwargs): def layer_forward(layer_instance): return layer_instance(*args, **kwargs) - return map_variables( + return transforms.map_variables( layer_forward, trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( functools.partial( @@ -786,7 +797,8 @@ def _spectral_normalize(self, path, vs, update_stats): value = jnp.asarray(vs) value_shape = value.shape - # Skip and return value if input is scalar, vector or if number of power iterations is less than 1 + # Skip and return value if input is scalar, vector or if number of power + # iterations is less than 1 if value.ndim <= 1 or self.n_steps < 1: return value # Handle higher-order tensors. @@ -844,9 +856,10 @@ def _spectral_normalize(self, path, vs, update_stats): u_var.value = u0 sigma_var.value = sigma - dtype = canonicalize_dtype(vs, u0, v0, sigma, dtype=self.dtype) + dtype = dtypes.canonicalize_dtype(vs, u0, v0, sigma, dtype=self.dtype) return jnp.asarray(value_bar, dtype) + class WeightNorm(Module): """L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf). @@ -878,7 +891,8 @@ def __call__(self, x): # l2-normalize all params of the second Dense layer x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x) x = nn.Dense(5)(x) - # l2-normalize all kernels in the Bar submodule and all params in the Baz submodule + # l2-normalize all kernels in the Bar submodule and all params in the + # Baz submodule x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x) return x @@ -923,10 +937,10 @@ def __call__(self, x): default, the trailing dimension is treated as the feature axis. 
variable_filter: An optional iterable that contains string items. The WeightNorm layer will selectively apply l2-normalization to the - ``layer_instance`` variables whose key path (delimited by '/') has a - match with ``variable_filter``. For example, ``variable_filter={'kernel'}`` - will only apply l2-normalization to variables whose key path contains - 'kernel'. By default, ``variable_filter={'kernel'}``. + ``layer_instance`` variables whose key path (delimited by '/') has a match + with ``variable_filter``. For example, ``variable_filter={'kernel'}`` will + only apply l2-normalization to variables whose key path contains 'kernel'. + By default, ``variable_filter={'kernel'}``. """ layer_instance: Module @@ -936,7 +950,9 @@ def __call__(self, x): use_scale: bool = True scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones feature_axes: Optional[Axes] = -1 - variable_filter: Optional[Iterable] = field(default_factory=lambda: {'kernel'}) + variable_filter: Optional[Iterable] = dataclasses.field( + default_factory=lambda: {'kernel'} + ) @compact def __call__(self, *args, **kwargs): @@ -957,11 +973,11 @@ def __call__(self, *args, **kwargs): def layer_forward(layer_instance): return layer_instance(*args, **kwargs) - return map_variables( + return transforms.map_variables( layer_forward, trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( - self._l2_normalize, - vs, + self._l2_normalize, + vs, ), init=self.is_initializing(), )(self.layer_instance) @@ -977,7 +993,11 @@ def _l2_normalize(self, path, vs): vs: variables to be l2-normalized """ value = jnp.asarray(vs) - str_path = self.layer_instance.name + '/' + '/'.join((dict_key.key for dict_key in path[1:])) + str_path = ( + self.layer_instance.name + + '/' + + '/'.join((dict_key.key for dict_key in path[1:])) + ) if self.variable_filter: for variable_name in self.variable_filter: if variable_name in str_path: @@ -990,7 +1010,9 @@ def _l2_normalize(self, path, vs): reduction_axes = tuple(i for i in range(value.ndim)) else: feature_axes = _canonicalize_axes(value.ndim, self.feature_axes) - reduction_axes = tuple(i for i in range(value.ndim) if i not in feature_axes) + reduction_axes = tuple( + i for i in range(value.ndim) if i not in feature_axes + ) feature_shape = [1] * value.ndim reduced_feature_shape = [] @@ -1003,10 +1025,13 @@ def _l2_normalize(self, path, vs): args = [vs] if self.use_scale: scale = self.param( - str_path + '/scale', self.scale_init, reduced_feature_shape, self.param_dtype + str_path + '/scale', + self.scale_init, + reduced_feature_shape, + self.param_dtype, ).reshape(feature_shape) value_bar *= scale args.append(scale) - dtype = canonicalize_dtype(*args, dtype=self.dtype) - return jnp.asarray(value_bar, dtype) \ No newline at end of file + dtype = dtypes.canonicalize_dtype(*args, dtype=self.dtype) + return jnp.asarray(value_bar, dtype) diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 1f5d929134..08f51857b5 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -15,18 +15,19 @@ """Tests for flax.linen.""" import copy -from absl.testing import absltest, parameterized from typing import Any +from absl.testing import absltest +from absl.testing import parameterized from flax import ids from flax import linen as nn from flax.training import train_state - import jax from jax import random import jax.numpy as jnp -import optax import numpy as np +import optax + # Parse absl flags test_srcdir and test_tmpdir. 
jax.config.parse_flags_with_absl() @@ -78,25 +79,6 @@ def test_avg_pool_no_batch(self, count_include_pad): ]).reshape((3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) - @parameterized.parameters( - {'count_include_pad': True}, {'count_include_pad': False} - ) - def test_avg_pool_padding_same(self, count_include_pad): - x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1)) - pool = lambda x: nn.avg_pool( - x, (2, 2), padding='SAME', count_include_pad=count_include_pad - ) - y = pool(x) - if count_include_pad: - expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape( - (1, 2, 2, 1) - ) - else: - expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape( - (1, 2, 2, 1) - ) - np.testing.assert_allclose(y, expected_y) - def test_max_pool(self): x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) pool = lambda x: nn.max_pool(x, (2, 2)) @@ -148,51 +130,74 @@ def test_pooling_no_batch_dims(self): class NormalizationTest(parameterized.TestCase): - def test_batch_norm(self): + @parameterized.parameters({'test_mask': True}, {'test_mask': False}) + def test_batch_norm(self, test_mask): rng = random.key(0) - key1, key2 = random.split(rng) + key1, key2, key3 = random.split(rng, 3) x = random.normal(key1, (4, 3, 2)) + if test_mask: + m = random.randint( + key2, (4, 3, 1), minval=0, maxval=2, dtype=jnp.int32 + ).astype(jnp.bool_) + x = jnp.where(m, x, jnp.ones_like(x) * jnp.nan) + else: + m = None model_cls = nn.BatchNorm(momentum=0.9, use_running_average=False) - y, initial_params = model_cls.init_with_output(key2, x) + y, initial_params = model_cls.init_with_output(key3, x, mask=m) - mean = y.mean((0, 1)) - var = y.var((0, 1)) + mean = y.mean((0, 1), where=m) + var = y.var((0, 1), where=m) np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) - - y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats']) + _, vars_out = model_cls.apply( + initial_params, x, mutable=['batch_stats'], mask=m + ) ema = vars_out['batch_stats'] np.testing.assert_allclose( - ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4 + ema['mean'], 0.1 * x.mean((0, 1), keepdims=False, where=m), atol=1e-4 ) np.testing.assert_allclose( - ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4 + ema['var'], + 0.9 + 0.1 * x.var((0, 1), keepdims=False, where=m), + rtol=1e-4, ) - def test_batch_norm_complex(self): + @parameterized.parameters({'test_mask': True}, {'test_mask': False}) + def test_batch_norm_complex(self, test_mask): rng = random.key(0) - key1, key2 = random.split(rng) + key1, key2, key3 = random.split(rng, 3) x = random.normal(key1, (4, 3, 2), dtype=jnp.complex64) + if test_mask: + m = random.randint( + key2, (4, 3, 1), minval=0, maxval=2, dtype=jnp.int32 + ).astype(jnp.bool_) + x = jnp.where(m, x, jnp.ones_like(x) * jnp.nan) + else: + m = None model_cls = nn.BatchNorm( momentum=0.9, use_running_average=False, dtype=jnp.complex64 ) - y, initial_params = model_cls.init_with_output(key2, x) + y, initial_params = model_cls.init_with_output(key3, x, mask=m) - mean = y.mean((0, 1)) - var = y.var((0, 1)) + mean = y.mean((0, 1), where=m) + var = y.var((0, 1), where=m) np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) self.assertEqual(mean.dtype, jnp.complex64) - y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats']) + _, vars_out = model_cls.apply( + initial_params, x, 
mutable=['batch_stats'], mask=m + ) ema = vars_out['batch_stats'] np.testing.assert_allclose( - ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4 + ema['mean'], 0.1 * x.mean((0, 1), keepdims=False, where=m), atol=1e-4 ) np.testing.assert_allclose( - ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4 + ema['var'], + 0.9 + 0.1 * x.var((0, 1), keepdims=False, where=m), + rtol=1e-4, ) @parameterized.parameters( @@ -292,26 +297,37 @@ def __call__(self, x): key = random.key(0) model = Foo() x = random.normal(random.key(1), (2, 4)) - (y1, y2), variables = model.init_with_output(key, x) + (y1, y2), _ = model.init_with_output(key, x) np.testing.assert_allclose(y1, y2, rtol=1e-4) @parameterized.parameters( - {'model_index': 0, 'key_paths': {'Dense_1/kernel/u', 'Dense_1/kernel/sigma'}}, - {'model_index': 1, 'key_paths': {'Conv_0/kernel/u', 'Conv_0/kernel/sigma'}}, - {'model_index': 2, 'key_paths': {'MultiHeadDotProductAttention_0/key/bias/u', - 'MultiHeadDotProductAttention_0/key/kernel/u', - 'MultiHeadDotProductAttention_0/out/kernel/u', - 'MultiHeadDotProductAttention_0/query/bias/u', - 'MultiHeadDotProductAttention_0/query/kernel/u', - 'MultiHeadDotProductAttention_0/value/bias/u', - 'MultiHeadDotProductAttention_0/value/kernel/u', - 'MultiHeadDotProductAttention_0/key/bias/sigma', - 'MultiHeadDotProductAttention_0/key/kernel/sigma', - 'MultiHeadDotProductAttention_0/out/kernel/sigma', - 'MultiHeadDotProductAttention_0/query/bias/sigma', - 'MultiHeadDotProductAttention_0/query/kernel/sigma', - 'MultiHeadDotProductAttention_0/value/bias/sigma', - 'MultiHeadDotProductAttention_0/value/kernel/sigma'}} + { + 'model_index': 0, + 'key_paths': {'Dense_1/kernel/u', 'Dense_1/kernel/sigma'}, + }, + { + 'model_index': 1, + 'key_paths': {'Conv_0/kernel/u', 'Conv_0/kernel/sigma'}, + }, + { + 'model_index': 2, + 'key_paths': { + 'MultiHeadDotProductAttention_0/key/bias/u', + 'MultiHeadDotProductAttention_0/key/kernel/u', + 'MultiHeadDotProductAttention_0/out/kernel/u', + 'MultiHeadDotProductAttention_0/query/bias/u', + 'MultiHeadDotProductAttention_0/query/kernel/u', + 'MultiHeadDotProductAttention_0/value/bias/u', + 'MultiHeadDotProductAttention_0/value/kernel/u', + 'MultiHeadDotProductAttention_0/key/bias/sigma', + 'MultiHeadDotProductAttention_0/key/kernel/sigma', + 'MultiHeadDotProductAttention_0/out/kernel/sigma', + 'MultiHeadDotProductAttention_0/query/bias/sigma', + 'MultiHeadDotProductAttention_0/query/kernel/sigma', + 'MultiHeadDotProductAttention_0/value/bias/sigma', + 'MultiHeadDotProductAttention_0/value/kernel/sigma', + }, + }, ) def test_spectral_norm_train( self, model_index, key_paths @@ -323,6 +339,7 @@ def __call__(self, x, train): x = nn.SpectralNorm(nn.Dense(6))(x, update_stats=train) x = nn.Dense(4)(x) return x + class FooConv(nn.Module): @nn.compact def __call__(self, x, train): @@ -334,6 +351,7 @@ def __call__(self, x, train): x = x.reshape(1, -1) x = nn.Dense(4)(x) return x + class FooAttention(nn.Module): @nn.compact def __call__(self, x, train): @@ -387,7 +405,7 @@ def loss_fn(params): prev_loss = float('inf') for _ in range(10): state, loss = train_step(state, {'image': x, 'label': y}) - self.assertTrue(loss < prev_loss) + self.assertLess(loss, prev_loss) prev_loss = loss @parameterized.parameters( @@ -411,7 +429,7 @@ def __call__(self, x, train): variables = model_cls.init(random.PRNGKey(0), x, train=False) params, batch_stats = variables['params'], variables['batch_stats'] params = jax.tree_map(lambda x: 4 * jnp.eye(*x.shape), params) - logits, updates = 
model_cls.apply( + _, updates = model_cls.apply( {'params': params, 'batch_stats': batch_stats}, x=x, train=update_stats, @@ -445,26 +463,42 @@ def __call__(self, x, train): with self.assertRaisesRegex( ValueError, 'Input is 3D but error_on_non_matrix is True' ): - variables = model_cls.init(random.PRNGKey(0), x, train=False) + _ = model_cls.init(random.PRNGKey(0), x, train=False) else: - variables = model_cls.init(random.PRNGKey(0), x, train=False) + _ = model_cls.init(random.PRNGKey(0), x, train=False) @parameterized.parameters( {'feature_axes': -1, 'reduction_axes': 0, 'variable_filter': {'kernel'}}, {'feature_axes': 0, 'reduction_axes': 1, 'variable_filter': {'kernel'}}, - {'feature_axes': (0, 1), 'reduction_axes': (), 'variable_filter': {'kernel'}}, - {'feature_axes': (), 'reduction_axes': (0, 1), 'variable_filter': {'kernel'}}, - {'feature_axes': None, 'reduction_axes': (0, 1), 'variable_filter': {'kernel'}}, + { + 'feature_axes': (0, 1), + 'reduction_axes': (), + 'variable_filter': {'kernel'}, + }, + { + 'feature_axes': (), + 'reduction_axes': (0, 1), + 'variable_filter': {'kernel'}, + }, + { + 'feature_axes': None, + 'reduction_axes': (0, 1), + 'variable_filter': {'kernel'}, + }, {'feature_axes': 0, 'reduction_axes': (), 'variable_filter': {'bias'}}, - {'feature_axes': (), 'reduction_axes': -1, 'variable_filter': {'bias'}} + {'feature_axes': (), 'reduction_axes': -1, 'variable_filter': {'bias'}}, ) - def test_manual_weight_norm(self, feature_axes, reduction_axes, variable_filter): + def test_manual_weight_norm( + self, feature_axes, reduction_axes, variable_filter + ): class Foo(nn.Module): + @nn.compact def __call__(self, x): return nn.WeightNorm(nn.Dense(2, bias_init=nn.initializers.normal()), feature_axes=feature_axes, variable_filter=variable_filter)(x) + key1, key2 = jax.random.split(jax.random.key(1)) x = jax.random.normal(key1, (1, 3)) module = Foo() @@ -474,17 +508,26 @@ def __call__(self, x): kernel = v['params']['Dense_0']['kernel'] if 'kernel' in variable_filter: - kernel /= jnp.sqrt(jnp.sum(kernel**2, axis=reduction_axes, keepdims=True)) - kernel_scale = jnp.expand_dims(v['params']['WeightNorm_0']['Dense_0/kernel/scale'], axis=reduction_axes) + kernel /= jnp.sqrt( + jnp.sum(kernel**2, axis=reduction_axes, keepdims=True) + ) + kernel_scale = jnp.expand_dims( + v['params']['WeightNorm_0']['Dense_0/kernel/scale'], + axis=reduction_axes, + ) else: kernel_scale = 1 bias = v['params']['Dense_0']['bias'] if 'bias' in variable_filter: bias /= jnp.sqrt(jnp.sum(bias**2, axis=reduction_axes, keepdims=True)) - bias_scale = jnp.expand_dims(v['params']['WeightNorm_0']['Dense_0/bias/scale'], axis=reduction_axes) + bias_scale = jnp.expand_dims( + v['params']['WeightNorm_0']['Dense_0/bias/scale'], axis=reduction_axes + ) else: bias_scale = 1 - manual_out = jnp.dot(x, kernel_scale * kernel) + (bias_scale * bias).reshape(1, -1) + manual_out = jnp.dot(x, kernel_scale * kernel) + ( + bias_scale * bias + ).reshape(1, -1) self.assertTrue(jnp.allclose(out, manual_out)) @@ -513,10 +556,13 @@ def __call__(self, x): ) def test_weight_norm_variable_filter(self, variable_filters, key_paths): class Baz(nn.Module): + @nn.compact def __call__(self, x): return nn.Dense(2)(x) + class Bar(nn.Module): + @nn.compact def __call__(self, x): x = Baz()(x) @@ -526,41 +572,58 @@ def __call__(self, x): return x for variable_filter in variable_filters: + class Foo(nn.Module): + @nn.compact def __call__(self, x): return nn.WeightNorm(Bar(), variable_filter=variable_filter)(x) + v = Foo().init(jax.random.key(0), 
jnp.ones((1, 4))) self.assertEqual(key_paths, v['params']['WeightNorm_0'].keys()) @parameterized.parameters( {'model_index': 0, 'key_paths': {'Dense_1/kernel/scale'}}, {'model_index': 1, 'key_paths': {'Conv_0/kernel/scale'}}, - {'model_index': 2, 'key_paths': {'MultiHeadDotProductAttention_0/key/kernel/scale', - 'MultiHeadDotProductAttention_0/out/kernel/scale', - 'MultiHeadDotProductAttention_0/query/kernel/scale', - 'MultiHeadDotProductAttention_0/value/kernel/scale'}} + { + 'model_index': 2, + 'key_paths': { + 'MultiHeadDotProductAttention_0/key/kernel/scale', + 'MultiHeadDotProductAttention_0/out/kernel/scale', + 'MultiHeadDotProductAttention_0/query/kernel/scale', + 'MultiHeadDotProductAttention_0/value/kernel/scale', + }, + }, ) - def test_weight_norm_train( - self, model_index, key_paths - ): + def test_weight_norm_train(self, model_index, key_paths): class FooDense(nn.Module): + @nn.compact - def __call__(self, x,): + def __call__( + self, + x, + ): x = nn.Dense(8)(x) x = nn.WeightNorm(nn.Dense(6))(x) x = nn.Dense(4)(x) return x + class FooConv(nn.Module): + @nn.compact - def __call__(self, x,): + def __call__( + self, + x, + ): x = nn.Dense(9)(x) x = x.reshape((1, 3, 3)) x = nn.WeightNorm(nn.Conv(2, kernel_size=(2, 2)))(x) x = x.reshape(1, -1) x = nn.Dense(4)(x) return x + class FooAttention(nn.Module): + @nn.compact def __call__(self, x): a = nn.Dense(4)(x) @@ -603,7 +666,7 @@ def loss_fn(params): prev_loss = float('inf') for _ in range(10): state, loss = train_step(state, {'image': x, 'label': y}) - self.assertTrue(loss < prev_loss) + self.assertLess(loss, prev_loss) prev_loss = loss From fd8fd76a4af5307a61f85bac98feab9b26d60db8 Mon Sep 17 00:00:00 2001 From: Flax Team Date: Tue, 24 Oct 2023 10:00:23 -0700 Subject: [PATCH 25/27] Make a number of functions visible. PiperOrigin-RevId: 576175489 --- flax/linen/normalization.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 93338a32d7..c07c262688 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -32,8 +32,13 @@ Shape = Tuple[int, ...] Dtype = Any # this could be a real type? 
Axes = Union[int, Sequence[int]] + +field = dataclasses.field +canonicalize_dtype = dtypes.canonicalize_dtype compact = module.compact Module = module.Module +merge_param = module.merge_param +map_variables = transforms.map_variables def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: From d24db613d2ea889c6725aae5b32ad35f37efb253 Mon Sep 17 00:00:00 2001 From: kaixih Date: Tue, 24 Oct 2023 16:00:13 -0700 Subject: [PATCH 26/27] Update the fp8 support --- flax/linen/__init__.py | 6 +- flax/linen/fp8_ops.py | 145 ++++++++++++++----------------- flax/training/train_state.py | 66 ++++++-------- tests/linen/linen_linear_test.py | 139 ----------------------------- tests/linen/linen_test.py | 130 +++++++++++++++++++++++++++ 5 files changed, 224 insertions(+), 262 deletions(-) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 8d45daed02..ac7e76995f 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -69,11 +69,7 @@ make_causal_mask as make_causal_mask, ) from .combinators import Sequential as Sequential -from .fp8_ops import ( - compute_scale as fp8_compute_scale, - quantize_dequantize as fp8_quantize_dequantize, - Fp8DenseGeneralOp as Fp8DenseGeneralOp, -) +from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp from .initializers import ( ones_init as ones_init, ones as ones, diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 11e687cd9a..8199d58cb7 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -15,44 +15,39 @@ from functools import partial from flax.linen import initializers -from flax.linen.module import Module +from flax.linen import module from jax import custom_vjp from jax import lax from jax import numpy as jnp from jax import random -# Type annotations -Array = jnp.ndarray -Dtype = jnp.dtype -PRNGKey = jnp.ndarray +OVERWRITE_WITH_GRADIENT = '_overwrite_with_gradient' -class FP8Helper: - FP8_COLLECTION_NAME: str = "fp8_params" def get_fp8_max(fp8_dtype, out_dtype): assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2) return jnp.finfo(fp8_dtype).max.astype(out_dtype) + def quantize(x, q_dtype, scale, compute_dtype): - # We need to explicitly cast the max value to compute_dtype, otherwise the jax - # dtype promotion will cast the scaled_x to fp32 in the following ops, which - # would violate the fp8-matmul pattern matching. + # Explicitly cast the max values to the compute dtype to avoid unnecessary + # casting to FP32 during the subsequent math operations." 
dtype_max = get_fp8_max(q_dtype, compute_dtype) - scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape) - clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max) - return clipped_x.astype(q_dtype) + def dequantize(x, dq_dtype, scale): return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape) + def quantize_dequantize(x, q_dtype, scale, compute_dtype): qx = quantize(x, q_dtype, scale, compute_dtype) return dequantize(qx, x.dtype, scale) + def compute_scale(amax, scale, fp8_max, margin=0): """Default function to convert amax to scaling factor.""" # This function copied from the TransformerEngine is used to compute its @@ -66,38 +61,44 @@ def compute_scale(amax, scale, fp8_max, margin=0): sf = jnp.where(exp < 0, 1.0 / sf, sf) return 1.0 / sf + def compute_scale_and_amax_history(x, q_dtype, scale, amax_history): dtype_max = get_fp8_max(q_dtype, jnp.float32) - amax_update = jnp.max(jnp.abs(x)).astype(scale.dtype) - new_amax_history = \ - jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update) - - amax_from_history = jnp.max(new_amax_history, axis=0) + new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update) + amax_from_history = jnp.max(new_history, axis=0) new_scale = compute_scale(amax_from_history, scale, dtype_max) - return new_scale, new_amax_history + return new_scale, new_history + def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): qx = quantize_dequantize(x, q_dtype, scale, compute_dtype) - new_scale, new_amax_history = compute_scale_and_amax_history( - x, q_dtype, scale, amax_history) - return qx, new_scale, new_amax_history + new_scale, new_history = compute_scale_and_amax_history( + x, q_dtype, scale, amax_history + ) + return qx, new_scale, new_history + @partial(custom_vjp, nondiff_argnums=(0,)) def in_qdq(compute_dtype, inp, scale, amax_history): qin, _, _ = qdq_and_return( - inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype) + inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype + ) return qin + def in_qdq_fwd(compute_dtype, inp, scale, amax_history): - qin, new_scale, new_amax_history = qdq_and_return( - inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype) - return qin, (new_scale, new_amax_history) + qin, new_scale, new_history = qdq_and_return( + inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype + ) + return qin, (new_scale, new_history) + def in_qdq_bwd(compute_dtype, res, g): - new_scale, new_amax_history = res + new_scale, new_history = res q_g = g - return q_g, new_scale, new_amax_history + return q_g, new_scale, new_history + in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd) @@ -106,34 +107,23 @@ def in_qdq_bwd(compute_dtype, res, g): def out_qdq(compute_dtype, out, scale, amax_history): return out + def out_qdq_fwd(compute_dtype, out, scale, amax_history): return out, (scale, amax_history) + def out_qdq_bwd(compute_dtype, res, g): scale, amax_history = res - q_g, new_scale, new_amax_history = qdq_and_return( - g, jnp.float8_e5m2, scale, amax_history, compute_dtype) - return q_g, new_scale, new_amax_history - -out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) - -def fp8_dot_general(lhs, rhs, dimension_numbers, precision, compute_dtype, - lhs_scale, lhs_amax_history, rhs_scale, rhs_amax_history, - dout_scale, dout_amax_history): - """Perform dot_general. 
""" + q_g, new_scale, new_history = qdq_and_return( + g, jnp.float8_e5m2, scale, amax_history, compute_dtype + ) + return q_g, new_scale, new_history - lhs_qdq = in_qdq(compute_dtype, lhs, lhs_scale, lhs_amax_history) - rhs_qdq = in_qdq(compute_dtype, rhs, rhs_scale, rhs_amax_history) - - output_qdq = lax.dot_general(lhs_qdq, rhs_qdq, dimension_numbers, precision) - - out = out_qdq(compute_dtype, output_qdq, dout_scale, dout_amax_history) - - return out +out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) -class Fp8DenseGeneralOp(Module): +class Fp8DotGeneralOp(module.Module): amax_history_length: int = 1024 def setup(self) -> None: @@ -151,47 +141,46 @@ def setup(self) -> None: ) self.input_amax_history = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'input_amax_history', - *amax_history_args) + OVERWRITE_WITH_GRADIENT, 'input_amax_history', *amax_history_args) self.kernel_amax_history = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'kernel_amax_history', - *amax_history_args) + OVERWRITE_WITH_GRADIENT, 'kernel_amax_history', *amax_history_args) self.output_grad_amax_history = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'output_grad_amax_history', - *amax_history_args) + OVERWRITE_WITH_GRADIENT, 'output_grad_amax_history', *amax_history_args) self.input_scale = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'input_scale', - *scale_args) + OVERWRITE_WITH_GRADIENT, 'input_scale', *scale_args) self.kernel_scale = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'kernel_scale', - *scale_args) + OVERWRITE_WITH_GRADIENT, 'kernel_scale', *scale_args) self.output_grad_scale = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'output_grad_scale', - *scale_args) + OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args) - def __call__(self, *args, **kwargs) -> Array: + def __call__(self, *args, **kwargs) -> jnp.ndarray: assert len(args) == 3 - inputs = args[0] - kernel = args[1] + x = args[0] + k = args[1] dimension_numbers = args[2] precision = kwargs['precision'] - comp_dtype = kernel.dtype - inputs = jnp.asarray(inputs, comp_dtype) - - out = fp8_dot_general(inputs, kernel, dimension_numbers, precision, - comp_dtype, self.input_scale.value, - self.input_amax_history.value, - self.kernel_scale.value, self.kernel_amax_history.value, - self.output_grad_scale.value, - self.output_grad_amax_history.value) - return out + + # Use the `k.dtype` since it aligns with the `dtype` of its layers, + # namely, the computation data type. + comp_dtype = k.dtype + x = jnp.asarray(x, comp_dtype) + + x_qdq = in_qdq( + comp_dtype, x, self.input_scale.value, self.input_amax_history.value + ) + k_qdq = in_qdq( + comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value + ) + y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision) + y = out_qdq( + comp_dtype, + y_qdq, + self.output_grad_scale.value, + self.output_grad_amax_history.value + ) + + return y diff --git a/flax/training/train_state.py b/flax/training/train_state.py index 81c7a1ca4f..9570ebe51b 100644 --- a/flax/training/train_state.py +++ b/flax/training/train_state.py @@ -16,6 +16,7 @@ from flax import core from flax import struct +from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT import optax @@ -71,8 +72,27 @@ def apply_gradients(self, *, grads, **kwargs): and `opt_state` updated by applying `grads`, and additional attributes replaced as specified by `kwargs`. 
""" - updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params) - new_params = optax.apply_updates(self.params, updates) + if OVERWRITE_WITH_GRADIENT in grads: + grads_with_opt = grads['params'] + params_with_opt = self.params['params'] + else: + grads_with_opt = grads + params_with_opt = self.params + + updates, new_opt_state = self.tx.update( + grads_with_opt, self.opt_state, params_with_opt + ) + new_params_with_opt = optax.apply_updates(params_with_opt, updates) + + # As implied by the OWG name, the gradients are used directly to update the + # parameters. + if OVERWRITE_WITH_GRADIENT in grads: + new_params = { + 'params': new_params_with_opt, + OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT] + } + else: + new_params = new_params_with_opt return self.replace( step=self.step + 1, params=new_params, @@ -83,45 +103,11 @@ def apply_gradients(self, *, grads, **kwargs): @classmethod def create(cls, *, apply_fn, params, tx, **kwargs): """Creates a new instance with `step=0` and initialized `opt_state`.""" - opt_state = tx.init(params) - return cls( - step=0, - apply_fn=apply_fn, - params=params, - tx=tx, - opt_state=opt_state, - **kwargs, - ) - -class Fp8TrainState(TrainState): - """Customized train state for Fp8.""" - - def apply_gradients(self, *, grads, **kwargs): - assert 'fp8_params' in grads - updates, new_opt_state = self.tx.update(grads['params'], self.opt_state, - self.params['params']) - new_non_fp8_params = optax.apply_updates(self.params['params'], updates) - - # self.param is structured as - # {'param': {'kernel:...,'}, 'fp8_params': {...}}. For the fp8 variables - # in the fp8-params collection, we will simply replace them with their - # grads, because their grads are actually new values defined in the - # custom_vjp functions. - new_params = {'params': new_non_fp8_params, - 'fp8_params': grads['fp8_params']} - - return self.replace( - step=self.step + 1, - params=new_params, - opt_state=new_opt_state, - **kwargs, + # We exclude OWG params when present because they do not need opt states. + params_with_opt = ( + params['params'] if OVERWRITE_WITH_GRADIENT in params else params ) - - @classmethod - def create(cls, *, apply_fn, params, tx, **kwargs): - assert 'fp8_params' in params - opt_state = tx.init(params['params']) - + opt_state = tx.init(params_with_opt) return cls( step=0, apply_fn=apply_fn, diff --git a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py index 899cb5ea71..bbfacd325e 100644 --- a/tests/linen/linen_linear_test.py +++ b/tests/linen/linen_linear_test.py @@ -19,10 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized -import optax - from flax import linen as nn -from flax.training import train_state import jax from jax import random @@ -1018,142 +1015,6 @@ def __call__(self, x): ) self.assertEqual(y.shape, (2, 8, 6)) - def test_fp8_dot_general_cls_injection(self): - # Used to cast the inputs to be representable in FP8, so that the difference - # of the results from the original gemm and fp8 gemm is small. 
- cast_to_representable = functools.partial(nn.fp8_quantize_dequantize, - scale=jnp.ones((1,)), - compute_dtype=jnp.float32) - - init_key, random_key = jax.random.split( - jax.random.PRNGKey(seed=123), 2) - - x = jax.random.uniform(random_key, (16, 32)) - x = cast_to_representable(x, jnp.float8_e4m3fn) - dy = jax.random.uniform(random_key, (16, 64)) - dy = cast_to_representable(dy, jnp.float8_e5m2) - def run(fp8_injection, expected_shapes): - p = nn.DenseGeneral(features=64, name='dense') - if fp8_injection: - p.dot_general_cls=nn.Fp8DenseGeneralOp - y, initial_vars = p.init_with_output(init_key, x) - var_shapes = jax.tree_util.tree_map(jnp.shape, initial_vars) - self.assertEqual(var_shapes, expected_shapes) - - def _train(variables, x): - y = p.apply(variables, x) - loss = y * dy - return jnp.mean(loss) - train_fn = jax.jit(jax.value_and_grad(_train, argnums=[0, 1])) - outputs, grads = train_fn(initial_vars, x) - return outputs, grads - - expected_shapes_original = { - 'params': {'kernel': (32, 64), 'bias': (64,)}, - } - expected_shapes_new = { - 'params': {'kernel': (32, 64), 'bias': (64,)}, - 'fp8_params': { - 'Fp8DenseGeneralOp_0': {'input_amax_history': (1024,), - 'kernel_amax_history': (1024,), - 'output_grad_amax_history': (1024,), - 'input_scale': (1,), - 'kernel_scale': (1,), - 'output_grad_scale': (1,), }}, - } - - output1a, output1b = run(False, expected_shapes_original) - output2a, output2b = run(True, expected_shapes_new) - dw1, dw2 = output1b[0]['params']['kernel'], output2b[0]['params']['kernel'] - dx1, dx2 = output1b[1], output2b[1] - - np.testing.assert_allclose(output1a, output2a, atol=1e-02) - np.testing.assert_allclose(dw1, dw2, atol=1e-04) - np.testing.assert_allclose(dx1, dx2, atol=1e-04) - - def test_fp8_with_train_state(self): - x = random.uniform(random.PRNGKey(1), (16, 16), dtype=jnp.float32) - dense = nn.DenseGeneral(features=32, use_bias=True, - dot_general_cls=nn.Fp8DenseGeneralOp) - key = random.PRNGKey(0) - variables = dense.init(key, x) - - opt = optax.adam(learning_rate=.1) - state = train_state.Fp8TrainState.create(params=variables, tx=opt, - apply_fn=dense.apply) - - def roll_and_update(amax_h, update): - return jnp.roll(amax_h, shift=-1, axis=0).at[0].set(update) - - def _train_loss(state, x, dy): - def loss_fn(vars): - y = state.apply_fn(vars, x) - loss = y * dy.astype(y.dtype) - return jnp.sum(loss) - - grad_fn = jax.grad(loss_fn) - grads = grad_fn(state.params) - - state = state.apply_gradients(grads=grads) - return state - - train_fn = jax.jit(_train_loss) - - amax_history_x = jnp.zeros((1024, )) - amax_history_k = jnp.zeros((1024, )) - amax_history_dy = jnp.zeros((1024, )) - scale_x = jnp.ones(()) - scale_k = jnp.ones(()) - scale_dy = jnp.ones(()) - fp8_e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32) - fp8_e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32) - for _ in range(5): - x = random.normal(random.PRNGKey(1), (16, 16), dtype=jnp.float32) - dy = random.normal(random.PRNGKey(1), (16, 32), dtype=jnp.float32) - - amax_history_x = roll_and_update(amax_history_x, jnp.max(jnp.abs(x))) - amax_history_k = roll_and_update( - amax_history_k, - jnp.max(jnp.abs(state.params['params']['kernel']))) - amax_history_dy = roll_and_update(amax_history_dy, jnp.max(jnp.abs(dy))) - - amax_from_history_x = jnp.max(amax_history_x, axis=0) - amax_from_history_k = jnp.max(amax_history_k, axis=0) - amax_from_history_dy = jnp.max(amax_history_dy, axis=0) - scale_x = nn.fp8_compute_scale(amax_from_history_x, scale_x, - fp8_e4m3_max) - scale_k = 
nn.fp8_compute_scale(amax_from_history_k, scale_k, fp8_e4m3_max) - scale_dy = nn.fp8_compute_scale(amax_from_history_dy, scale_dy, - fp8_e5m2_max) - - state = train_fn(state, x, dy) - - rtol, atol = 0.001, 0.001 - fp8_vars = state.params['fp8_params'] - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0']['input_amax_history'], - amax_history_x, rtol=rtol, atol=atol) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0']['kernel_amax_history'], - amax_history_k, rtol=rtol, atol=atol) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['output_grad_amax_history'], - amax_history_dy, rtol=rtol, atol=atol) - - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['input_scale'][0], - scale_x) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['kernel_scale'][0], - scale_k) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['output_grad_scale'][0], - scale_dy) - def test_non_final_axes(self): class Foo(nn.Module): diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 08f51857b5..135782e629 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -15,12 +15,14 @@ """Tests for flax.linen.""" import copy +import functools from typing import Any from absl.testing import absltest from absl.testing import parameterized from flax import ids from flax import linen as nn +from flax.linen import fp8_ops from flax.training import train_state import jax from jax import random @@ -868,5 +870,133 @@ def test_hashable(self): self.assertNotEqual(hash(id1), hash(id1dc)) +class Fp8Test(absltest.TestCase): + + def test_fp8_dot_general_injection(self): + # Used to cast the inputs to be representable in FP8, so that the difference + # of the results from the original gemm and fp8 gemm is small. 
+ cast_to_representable = functools.partial(fp8_ops.quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=jnp.float32) + + init_key, random_key = random.split(random.PRNGKey(seed=123), 2) + x = cast_to_representable( + random.uniform(random_key, (16, 32)), jnp.float8_e4m3fn) + dy = cast_to_representable( + random.uniform(random_key, (16, 64)), jnp.float8_e5m2) + + def run(fp8_injection, expected_shapes): + p = nn.DenseGeneral(features=64, name='dense') + if fp8_injection: + p.dot_general_cls=nn.Fp8DotGeneralOp + y, initial_vars = p.init_with_output(init_key, x) + var_shapes = jax.tree_util.tree_map(jnp.shape, initial_vars) + self.assertEqual(var_shapes, expected_shapes) + + def _train(variables, x): + y = p.apply(variables, x) + loss = y * dy + return jnp.mean(loss) + + train_fn = jax.jit(jax.value_and_grad(_train, argnums=[0, 1])) + outputs, grads = train_fn(initial_vars, x) + return outputs, grads + + expected_shapes_original = { + 'params': {'kernel': (32, 64), 'bias': (64,)}, + } + expected_shapes_new = { + 'params': {'kernel': (32, 64), 'bias': (64,)}, + fp8_ops.OVERWRITE_WITH_GRADIENT: { + 'Fp8DotGeneralOp_0': {'input_amax_history': (1024,), + 'kernel_amax_history': (1024,), + 'output_grad_amax_history': (1024,), + 'input_scale': (1,), + 'kernel_scale': (1,), + 'output_grad_scale': (1,), }}, + } + + output1a, output1b = run(False, expected_shapes_original) + output2a, output2b = run(True, expected_shapes_new) + dw1, dw2 = output1b[0]['params']['kernel'], output2b[0]['params']['kernel'] + dx1, dx2 = output1b[1], output2b[1] + + np.testing.assert_allclose(output1a, output2a, atol=1e-02) + np.testing.assert_allclose(dw1, dw2, atol=1e-04) + np.testing.assert_allclose(dx1, dx2, atol=1e-04) + + def test_fp8_train_state(self): + key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3) + x = random.uniform(random_key, (16, 16), dtype=jnp.float32) + dense = nn.DenseGeneral(features=32, use_bias=True, + dot_general_cls=nn.Fp8DotGeneralOp) + variables = dense.init(init_key, x) + opt = optax.adam(learning_rate=.1) + state = train_state.TrainState.create( + params=variables, tx=opt, apply_fn=dense.apply + ) + + def _roll_and_update(amax_h, update): + return jnp.roll(amax_h, shift=-1, axis=0).at[0].set(update) + + def _train_loss(state, x, dy): + def loss_fn(vars): + y = state.apply_fn(vars, x) + loss = y * dy.astype(y.dtype) + return jnp.sum(loss) + grad_fn = jax.grad(loss_fn) + grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + return state + + train_fn = jax.jit(_train_loss) + + scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,)) + scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,)) + scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,)) + e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32) + e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32) + + for _ in range(5): + key, random_key = random.split(key, 2) + x = random.normal(random_key, (16, 16), dtype=jnp.float32) + g = random.normal(random_key, (16, 32), dtype=jnp.float32) + k = state.params['params']['kernel'] + + # Manually compute the expected amax history and scaling factors. 
+ amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x))) + amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(k))) + amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g))) + amax_from_history_x = jnp.max(amax_history_x, axis=0) + amax_from_history_k = jnp.max(amax_history_k, axis=0) + amax_from_history_g = jnp.max(amax_history_g, axis=0) + scale_x = fp8_ops.compute_scale(amax_from_history_x, scale_x, e4m3_max) + scale_k = fp8_ops.compute_scale(amax_from_history_k, scale_k, e4m3_max) + scale_g = fp8_ops.compute_scale(amax_from_history_g, scale_g, e5m2_max) + + state = train_fn(state, x, g) + + rtol, atol = 0.001, 0.001 + fp8_vars = ( + state.params[fp8_ops.OVERWRITE_WITH_GRADIENT]['Fp8DotGeneralOp_0'] + ) + np.testing.assert_allclose( + fp8_vars['input_amax_history'], amax_history_x, rtol=rtol, atol=atol, + ) + np.testing.assert_allclose( + fp8_vars['kernel_amax_history'], amax_history_k, rtol=rtol, atol=atol, + ) + np.testing.assert_allclose( + fp8_vars['output_grad_amax_history'], + amax_history_g, + rtol=rtol, + atol=atol, + ) + + np.testing.assert_allclose(fp8_vars['input_scale'][0], scale_x) + np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k) + np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g) + + if __name__ == '__main__': absltest.main() From 5557649f2cb200818ab9e3a8c636aee3e69f2028 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Wed, 25 Oct 2023 18:55:20 -0700 Subject: [PATCH 27/27] Updated docstring for `axis_name` arg to say it's only used for `pmap` PiperOrigin-RevId: 576709933 --- flax/core/lift.py | 22 ++++++++++++---------- flax/linen/normalization.py | 23 +++++++++++++++++++---- flax/linen/transforms.py | 26 ++++++++++++++------------ flax/training/dynamic_scale.py | 14 +++++++++----- 4 files changed, 54 insertions(+), 31 deletions(-) diff --git a/flax/core/lift.py b/flax/core/lift.py index 4dba2655c2..42d7acbf18 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -641,18 +641,20 @@ def vmap( Args: fn: the function to be transformed. - variable_axes: the variable collections that are lifted into the - batching transformation. Use `None` to indicate a broadcasted - collection or an integer to map over an axis. - split_rngs: Split PRNG sequences will be different for each index - of the batch dimension. Unsplit PRNGs will be broadcasted. + variable_axes: the variable collections that are lifted into the batching + transformation. Use `None` to indicate a broadcasted collection or an + integer to map over an axis. + split_rngs: Split PRNG sequences will be different for each index of the + batch dimension. Unsplit PRNGs will be broadcasted. in_axes: Specifies the mapping of the input arguments (see `jax.vmap). out_axes: Specifies the mapping of the return value (see `jax.vmap). - axis_size: Specifies the size of the batch axis. This only needs - to be specified if it cannot be derived from the input arguments. - axis_name: Specifies a name for the batch axis. Can be used together - with parallel reduction primitives (e.g. `jax.lax.pmean`, - `jax.lax.ppermute`, etc.) + axis_size: Specifies the size of the batch axis. This only needs to be + specified if it cannot be derived from the input arguments. + axis_name: Specifies a name for the batch axis. Can be used together with + parallel reduction primitives (e.g. `jax.lax.pmean`, `jax.lax.ppermute`, + etc.). Note, this is only used for pmap and shmap. For SPMD jit, you do + not need to manually synchronize. 
Just make sure that the axes are + correctly annotated and XLA:SPMD will insert the necessary collectives. spmd_axis_name: Axis name added to any pjit sharding constraints appearing in `fn`. See also https://github.com/google/flax/blob/main/flax/linen/partitioning.py. diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index c07c262688..daf79f0981 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -83,7 +83,10 @@ def _compute_stats( axes: The axes in ``x`` to compute mean and variance statistics for. dtype: Optional dtype specifying the minimal precision. Statistics are always at least float32 for stability (default: dtype of x). - axis_name: Optional name for the pmapped axis to compute mean over. + axis_name: Optional name for the pmapped axis to compute mean over. Note, + this is only used for pmap and shard map. For SPMD jit, you do not need to + manually synchronize. Just make sure that the axes are correctly annotated + and XLA:SPMD will insert the necessary collectives. axis_index_groups: Optional axis indices. use_mean: If true, calculate the mean from the input and use it when computing the variance. If false, set the mean to zero and compute the @@ -269,6 +272,9 @@ class BatchNorm(Module): scale_init: initializer for scale, by default, one. axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). + Note, this is only used for pmap and shard map. For SPMD jit, you do not + need to manually synchronize. Just make sure that the axes are correctly + annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the @@ -390,7 +396,10 @@ class LayerNorm(Module): axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the - array being normalized is sharded across devices within a pmap. + array being normalized is sharded across devices within a pmap or shard + map. For SPMD jit, you do not need to manually synchronize. Just make sure + that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the @@ -481,7 +490,10 @@ class RMSNorm(Module): axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the - array being normalized is sharded across devices within a pmap. + array being normalized is sharded across devices within a pmap or shard + map. For SPMD jit, you do not need to manually synchronize. Just make sure + that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). 
For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the @@ -561,7 +573,10 @@ class GroupNorm(Module): axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the - array being normalized is sharded across devices within a pmap. + array being normalized is sharded across devices within a pmap or shard + map. For SPMD jit, you do not need to manually synchronize. Just make sure + that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index adaf7822fc..272ff45668 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -537,20 +537,22 @@ def vmap( RNG must also be shared. Args: - target: a ``Module`` or a function taking a ``Module`` - as its first argument. - variable_axes: the variable collections that are lifted into the - batching transformation. Use `None` to indicate a broadcasted - collection or an integer to map over an axis. - split_rngs: Split PRNG sequences will be different for each index - of the batch dimension. Unsplit PRNGs will be broadcasted. + target: a ``Module`` or a function taking a ``Module`` as its first + argument. + variable_axes: the variable collections that are lifted into the batching + transformation. Use `None` to indicate a broadcasted collection or an + integer to map over an axis. + split_rngs: Split PRNG sequences will be different for each index of the + batch dimension. Unsplit PRNGs will be broadcasted. in_axes: Specifies the mapping of the input arguments (see `jax.vmap`). out_axes: Specifies the mapping of the return value (see `jax.vmap`). - axis_size: Specifies the size of the batch axis. This only needs - to be specified if it cannot be derived from the input arguments. - axis_name: Specifies a name for the batch axis. Can be used together - with parallel reduction primitives (e.g. `jax.lax.pmean`, - `jax.lax.ppermute`, etc.) + axis_size: Specifies the size of the batch axis. This only needs to be + specified if it cannot be derived from the input arguments. + axis_name: Specifies a name for the batch axis. Can be used together with + parallel reduction primitives (e.g. `jax.lax.pmean`, `jax.lax.ppermute`, + etc.). Note, this is only used for pmap and shard map. For SPMD jit, you + do not need to manually synchronize. Just make sure that the axes are + correctly annotated and XLA:SPMD will insert the necessary collectives. methods: If `target` is a `Module`, the methods of `Module` to vmap over. spmd_axis_name: Axis name added to any pjit sharding constraints appearing in `fn`. See also diff --git a/flax/training/dynamic_scale.py b/flax/training/dynamic_scale.py index 3ed93004c8..b7b329a7f2 100644 --- a/flax/training/dynamic_scale.py +++ b/flax/training/dynamic_scale.py @@ -101,16 +101,20 @@ def value_and_grad( Args: fun: Function to be differentiated. Its arguments at positions specified by ``argnums`` should be arrays, scalars, or standard Python containers. - It should return a scalar (which includes arrays with shape ``()`` - but not arrays with shape ``(1,)`` etc.) 
+ It should return a scalar (which includes arrays with shape ``()`` but + not arrays with shape ``(1,)`` etc.) argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0). has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function - to be differentiated and the second element is auxiliary data. - Default False. + to be differentiated and the second element is auxiliary data. Default + False. axis_name: If an axis is given the gradients will be averaged across - replicas (default: None). + replicas (default: None). Note, this is only used for pmap and shard + map. For SPMD jit, you do not need to manually synchronize. Just make + sure that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. + Returns: A function that takes the same arguments as `fun` and returns a DynamicScaleResult
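
Note on the `axis_name` docstring updates above: they all describe the same behavior, so here is a minimal, self-contained sketch of the two situations (not part of the patch; the module choice, shapes, and variable names are illustrative, and it assumes the current public Flax/JAX APIs):

```python
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn

# BatchNorm that synchronizes batch statistics over a named axis 'batch'.
# The name only has meaning inside pmap (or shard_map).
bn = nn.BatchNorm(use_running_average=False, axis_name='batch')

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 16, 4))                       # one sub-batch per device
variables = bn.init(jax.random.PRNGKey(0), x[0])   # init from a single sub-batch

# pmap case: each device only sees its own sub-batch, so the collective over
# 'batch' is what makes the mean/variance global across devices.
p_apply = jax.pmap(
    lambda v, xi: bn.apply(v, xi, mutable=['batch_stats']),
    axis_name='batch',
)
y, updated_state = p_apply(
    jax.device_put_replicated(variables, jax.local_devices()), x
)

# SPMD jit case: omit axis_name entirely; annotate the shardings of inputs
# and parameters (e.g. via jax.device_put with a NamedSharding, or
# nn.with_partitioning) and XLA:SPMD inserts the required collectives.
```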