diff --git a/.github/analytics/get_repo_metrics.py b/.github/analytics/get_repo_metrics.py index 600270a0df..d936c1c7bb 100644 --- a/.github/analytics/get_repo_metrics.py +++ b/.github/analytics/get_repo_metrics.py @@ -15,7 +15,6 @@ import json import os from datetime import datetime -from pathlib import Path from typing import Callable, List from absl import app, flags diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 803e5ffae2..287b939df8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,3 +34,7 @@ repos: --extra-keys, "metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", ] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.292 + hooks: + - id: ruff diff --git a/.readthedocs.yml b/.readthedocs.yml index e086a43677..fa87e6d31f 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -18,7 +18,7 @@ sphinx: formats: - htmlzip - epub - - pdf + # - pdf # Optionally set the version of Python and requirements required to build your docs python: diff --git a/CHANGELOG.md b/CHANGELOG.md index 43c0fa18d9..5032ab925c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,9 @@ vNext - - - -- +- Re-factored `MultiHeadDotProductAttention`'s call method signatur, by adding +`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic` +to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389). - - - diff --git a/docs/_ext/codediff.py b/docs/_ext/codediff.py index eabc634fd8..8a8e6f3fb9 100644 --- a/docs/_ext/codediff.py +++ b/docs/_ext/codediff.py @@ -26,7 +26,6 @@ In order to highlight a line of code, append "#!" to it. """ -import itertools from typing import List, Tuple from docutils import nodes diff --git a/docs/api_reference/flax.cursor.rst b/docs/api_reference/flax.cursor.rst index 56ace5bc14..073f06ee1c 100644 --- a/docs/api_reference/flax.cursor.rst +++ b/docs/api_reference/flax.cursor.rst @@ -12,7 +12,7 @@ To illustrate, consider the example below:: import dataclasses from typing import Any - @dataclasses.dataclass + @dataclasses.dataclass(frozen=True) class A: x: Any diff --git a/docs/api_reference/flax.linen/layers.rst b/docs/api_reference/flax.linen/layers.rst index 87bfff0fc9..d4f28b21c7 100644 --- a/docs/api_reference/flax.linen/layers.rst +++ b/docs/api_reference/flax.linen/layers.rst @@ -52,10 +52,18 @@ Normalization :module: flax.linen :class: GroupNorm +.. flax_module:: + :module: flax.linen + :class: RMSNorm + .. flax_module:: :module: flax.linen :class: SpectralNorm +.. 
flax_module:: + :module: flax.linen + :class: WeightNorm + Combinators ------------------------ @@ -132,6 +140,7 @@ Recurrent GroupNorm RMSNorm SpectralNorm + WeightNorm Sequential Dropout SelfAttention diff --git a/docs/conf.py b/docs/conf.py index 6ca260ffef..262ef67b0d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -37,7 +37,6 @@ sys.path.append(os.path.abspath('./_ext')) # patch sphinx -import docs.conf_sphinx_patch # -- Project information ----------------------------------------------------- project = 'Flax' @@ -134,6 +133,7 @@ nb_execution_excludepatterns = [ 'getting_started.ipynb', # <-- times out 'optax_update_guide.ipynb', # <-- requires flax<=0.5.3 + 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs/examples_community_examples.rst b/docs/examples_community_examples.rst index 0da9245c3d..079568c9a7 100644 --- a/docs/examples_community_examples.rst +++ b/docs/examples_community_examples.rst @@ -56,10 +56,6 @@ Examples - `@vasudevgupta7 `__ - Question-Answering - https://arxiv.org/abs/2007.14062 - * - `Bayesian Networks with BlackJAX `__ - - `@rlouf `__ - - Bayesian Inference, SGMCMC - - https://arxiv.org/abs/1402.4102 * - `DCGAN `__ - `@bkkaggle `__ - Image Synthesis diff --git a/docs/faq.rst b/docs/faq.rst new file mode 100644 index 0000000000..f6b0d30b16 --- /dev/null +++ b/docs/faq.rst @@ -0,0 +1,38 @@ +Frequently Asked Questions (FAQ) +================================ + +This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in `GitHub Discussions `__. + +Where to search for an answer to a Flax-related question? +********************************************************* + +There are a number of official Flax resources to search for information: + +- `Flax Documentation on ReadTheDocs `__ (this site): Use the `search bar `__ or the table of contents on the left-hand side. +- `google/flax GitHub Discussions `__: Search for an existing topic or start a new one. If you can't find what you're looking for, feel free to ask the Flax team or community a question. +- `google/flax GitHub Issues `__: Use the search bar to look for an existing issue or a feature request, or start a new one. + +How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`)? +************************************************************************************************ + +To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use :meth:`flax.linen.Module.perturb`. You define a zero-value :class:`flax.linen.Module` "perturbation" parameter – :code:`perturb(...)` – in the forward pass with the same shape as the intermediate activation, define the loss function with :code:`'perturbations'` as an added standalone argument, perform a JAX derivative operation with :code:`jax.grad` on the perturbation argument. + +For full examples and detailed documentation, go to: + +- The :meth:`flax.linen.Module.perturb` API docs +- The `Extracting gradients of intermediate values `_ guide +- `Flax GitHub Discussions #1152 `__ + +Is Flax Linen :code:`remat_scan()` the same as :code:`scan(remat(...))`? 
+************************************************************************ + +Flax :code:`remat_scan()` (:meth:`flax.linen.remat_scan()`) and :code:`scan(remat(...))` (:meth:`flax.linen.scan` over :meth:`flax.linen.remat`) are not the same, and :code:`remat_scan()` is limited in cases it supports. Namely, :code:`remat_scan()` treats the inputs and outputs as carries (hidden states that are carried through the training loop). You are recommended to use :code:`scan(remat(...))`, as typically you would need the extra parameters, such as ``in_axes`` (for input array axes) or ``out_axes`` (output array axes), which :meth:`flax.linen.remat_scan` does not expose. + +What are the recommended training loop libraries? +************************************************* + +Consider using CLU (Common Loop Utils) `google/CommonLoopUtils `__. To get started, go to this `CLU Synopsis Colab `__. You can find answers to common questions about CLU with Flax on `google/flax GitHub Discussions `__. + +Check out the official `google/flax Examples `__ for examples of using the training loop with (CLU) metrics. For example, this is `Flax ImageNet's train.py `__. + +For computer vision research, consider `google-research/scenic `__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the `README page on GitHub `__. \ No newline at end of file diff --git a/docs/guides/convert_pytorch_to_flax.rst b/docs/guides/convert_pytorch_to_flax.rst index ff4e58acd3..43283d448d 100644 --- a/docs/guides/convert_pytorch_to_flax.rst +++ b/docs/guides/convert_pytorch_to_flax.rst @@ -85,7 +85,7 @@ Convolutions and FC Layers We have to be careful, when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc). In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then reshaped to [N, C * H * W] before being fed to the fc layers. -When we port our weights from PyToch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. +When we port our weights from PyTorch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W]. Consider this PyTorch model: @@ -266,6 +266,40 @@ while ``torch.nn.ConvTranspose2d`` computes a gradient based transposed convolut implementation of a gradient based transposed convolution is ``Jax``. However, there is a pending `pull request`_ that contains an implementation. +To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``transpose_kernel`` arg in Flax's +``nn.ConvTranspose`` layer. + +.. 
testcode:: + + # padding is inverted + torch_padding = 0 + flax_padding = 1 - torch_padding + + t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=torch_padding) + + kernel = t_conv.weight.detach().cpu().numpy() + bias = t_conv.bias.detach().cpu().numpy() + + # [inC, outC, kH, kW] -> [kH, kW, outC, inC] + kernel = jnp.transpose(kernel, (2, 3, 1, 0)) + + key = random.key(0) + x = random.normal(key, (1, 6, 6, 3)) + + variables = {'params': {'kernel': kernel, 'bias': bias}} + # ConvTranspose expects the kernel to be [kH, kW, inC, outC], + # but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead + j_conv = nn.ConvTranspose(features=4, kernel_size=(2, 2), padding=flax_padding, transpose_kernel=True) + j_out = j_conv.apply(variables, x) + + # [N, H, W, C] -> [N, C, H, W] + t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) + t_out = t_conv(t_x) + # [N, C, H, W] -> [N, H, W, C] + t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) + np.testing.assert_almost_equal(j_out, t_out, decimal=6) + + .. _`pull request`: https://github.com/google/jax/pull/5772 .. |nn.ConvTranspose| replace:: ``nn.ConvTranspose`` diff --git a/docs/guides/flax_on_pjit.ipynb b/docs/guides/flax_on_pjit.ipynb index 6ec633772a..d7705f34f4 100644 --- a/docs/guides/flax_on_pjit.ipynb +++ b/docs/guides/flax_on_pjit.ipynb @@ -190,7 +190,7 @@ "\n", "1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters.\n", "\n", - "2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n", + "2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n", "\n", " * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless." ] @@ -1281,7 +1281,7 @@ "\n", "* **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", - "* **logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n", + "* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n", "\n", "* **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. 
If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful." ] diff --git a/docs/guides/flax_on_pjit.md b/docs/guides/flax_on_pjit.md index 3b53875522..81e7b4d623 100644 --- a/docs/guides/flax_on_pjit.md +++ b/docs/guides/flax_on_pjit.md @@ -119,7 +119,7 @@ To shard the parameters efficiently, apply the following APIs to annotate the pa 1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters. -2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known. +2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known. * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless. @@ -551,7 +551,7 @@ Choosing when to use a device or logical axis depends on how much you want to co * **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming. -* **logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model. +* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model. * **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful. diff --git a/docs/guides/regular_dict_upgrade_guide.rst b/docs/guides/regular_dict_upgrade_guide.rst index 8659ad5a3e..dda795f0de 100644 --- a/docs/guides/regular_dict_upgrade_guide.rst +++ b/docs/guides/regular_dict_upgrade_guide.rst @@ -120,9 +120,11 @@ Alternatively, the environment variable ``flax_return_frozendict`` (found `here `__) can be directly modified in the Flax source code. -Migration plan +Migration status -------------- -Currently ``flax_return_frozendict`` is set to True, meaning Flax will default to returning ``FrozenDicts``. -In the future this flag will be flipped to False, and Flax will instead default to returning regular dicts. -Eventually this feature flag will be removed once the migration is complete. 
\ No newline at end of file +As of July 19th, 2023, ``flax_return_frozendict`` is set to ``False`` (see +`#3193 `__), meaning Flax will default to +returning regular dicts from version `0.7.1 `__ +onward. This flag can be flipped to ``True`` temporarily to have Flax return +``Frozendicts``. However this feature flag will eventually be removed in the future. \ No newline at end of file diff --git a/docs/guides/transfer_learning.ipynb b/docs/guides/transfer_learning.ipynb index d2bd152304..78b0b8ff7f 100644 --- a/docs/guides/transfer_learning.ipynb +++ b/docs/guides/transfer_learning.ipynb @@ -165,7 +165,7 @@ "metadata": {}, "source": [ "## Transfering the parameters\n", - "Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location. This can be done by unfreezing `params`, updating the `backbone` parameters, and freezing the `params` again:" + "Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location (i.e. the `backbone`):" ] }, { @@ -174,11 +174,7 @@ "metadata": {}, "outputs": [], "source": [ - "import flax\n", - "\n", - "params = flax.core.unfreeze(params)\n", - "params['backbone'] = vision_model_vars['params']\n", - "params = flax.core.freeze(params)" + "params['backbone'] = vision_model_vars['params']" ] }, { @@ -247,13 +243,13 @@ "import optax\n", "\n", "partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n", - "param_partitions = flax.core.freeze(traverse_util.path_aware_map(\n", - " lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params))\n", + "param_partitions = traverse_util.path_aware_map(\n", + " lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)\n", "tx = optax.multi_transform(partition_optimizers, param_partitions)\n", "\n", "# visualize a subset of the param_partitions structure\n", "flat = list(traverse_util.flatten_dict(param_partitions).items())\n", - "flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))" + "traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))" ] }, { diff --git a/docs/guides/transfer_learning.md b/docs/guides/transfer_learning.md index d467139267..8a563d498d 100644 --- a/docs/guides/transfer_learning.md +++ b/docs/guides/transfer_learning.md @@ -111,14 +111,10 @@ params = variables['params'] ``` ## Transfering the parameters -Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location. This can be done by unfreezing `params`, updating the `backbone` parameters, and freezing the `params` again: +Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location (i.e. the `backbone`): ```{code-cell} ipython3 -import flax - -params = flax.core.unfreeze(params) params['backbone'] = vision_model_vars['params'] -params = flax.core.freeze(params) ``` **Note:** if the model contains other variable collections such as `batch_stats`, these have to be transfered as well. 
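
The note above says collections such as `batch_stats` have to be moved over in the same way as `params`. A minimal sketch of what that could look like, assuming (hypothetically) that both the combined model and the pretrained vision model actually define a `batch_stats` collection with a matching `backbone` entry — the guide's own backbone may not use batch statistics:

```python
# Hypothetical: mirror the `params` transfer for the `batch_stats` collection.
# `variables` and `vision_model_vars` are the dicts created earlier in this guide.
if 'batch_stats' in vision_model_vars:
  variables['batch_stats']['backbone'] = vision_model_vars['batch_stats']
```
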
@@ -153,13 +149,13 @@ from flax import traverse_util import optax partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()} -param_partitions = flax.core.freeze(traverse_util.path_aware_map( - lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)) +param_partitions = traverse_util.path_aware_map( + lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params) tx = optax.multi_transform(partition_optimizers, param_partitions) # visualize a subset of the param_partitions structure flat = list(traverse_util.flatten_dict(param_partitions).items()) -flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))) +traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])) ``` To implement [differential learning rates](https://blog.slavv.com/differential-learning-rates-59eff5209a4f), the `optax.set_to_zero` can be replaced with any other optimizer, different optimizers and partitioning schemes can be selected depending on the task. For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation. diff --git a/docs/index.rst b/docs/index.rst index a041fb5de5..8f3a56fb3d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -318,8 +318,9 @@ Notable examples in Flax include: guides/index examples glossary + faq developer_notes/index philosophy contributing experimental - api_reference/index + api_reference/index \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index cc9c9e1873..600735f0b0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -32,4 +32,3 @@ tensorflow_text>=2.11.0 # WMT example # notebooks einops -transformers[flax] diff --git a/examples/imagenet/imagenet_fake_data_benchmark.py b/examples/imagenet/imagenet_fake_data_benchmark.py index 377a533acb..9e532adc90 100644 --- a/examples/imagenet/imagenet_fake_data_benchmark.py +++ b/examples/imagenet/imagenet_fake_data_benchmark.py @@ -22,7 +22,6 @@ import time from absl.testing import absltest -from absl.testing.flagsaver import flagsaver from flax.testing import Benchmark import jax import tensorflow_datasets as tfds diff --git a/examples/imagenet/train.py b/examples/imagenet/train.py index b83b97de68..a32b2283e6 100644 --- a/examples/imagenet/train.py +++ b/examples/imagenet/train.py @@ -25,7 +25,6 @@ from absl import logging from clu import metric_writers from clu import periodic_actions -import flax from flax import jax_utils from flax.training import checkpoints from flax.training import common_utils diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py index e77d43c1b3..6490d66329 100644 --- a/examples/linen_design_test/attention_simple.py +++ b/examples/linen_design_test/attention_simple.py @@ -14,15 +14,13 @@ import functools from pprint import pprint -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union -from flax.core import Scope -from flax.core.frozen_dict import freeze, unfreeze +from typing import Any, Callable, Optional, Sequence +from flax.core.frozen_dict import unfreeze from flax.linen import initializers from flax.linen import Module, compact, vmap from flax.linen.linear import PrecisionLike import jax from jax import lax, numpy as jnp, random -import numpy as np class Dense(Module): diff --git a/examples/linen_design_test/autoencoder.py b/examples/linen_design_test/autoencoder.py index 7c6a1fc9c6..cec1c1dea9 100644 
--- a/examples/linen_design_test/autoencoder.py +++ b/examples/linen_design_test/autoencoder.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable, Tuple import jax -from jax import numpy as jnp, random, lax -import numpy as np +from jax import numpy as jnp, random from flax import linen as nn from flax.linen import Module, Dense, compact diff --git a/examples/linen_design_test/dense.py b/examples/linen_design_test/dense.py index 78088ccbbc..7a109a95fe 100644 --- a/examples/linen_design_test/dense.py +++ b/examples/linen_design_test/dense.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax -from jax import numpy as jnp, random, lax +from jax import lax from flax.linen import initializers from typing import Callable from flax.linen import Module, compact diff --git a/examples/linen_design_test/linear_regression.py b/examples/linen_design_test/linear_regression.py index 8bda1e1112..bd6a6812d8 100644 --- a/examples/linen_design_test/linear_regression.py +++ b/examples/linen_design_test/linear_regression.py @@ -13,8 +13,7 @@ # limitations under the License. import jax -from jax import numpy as jnp, random, lax, jit -from flax import linen as nn +from jax import numpy as jnp, jit from dense import Dense diff --git a/examples/linen_design_test/mlp_explicit.py b/examples/linen_design_test/mlp_explicit.py index 9953c4df4a..a2665018cb 100644 --- a/examples/linen_design_test/mlp_explicit.py +++ b/examples/linen_design_test/mlp_explicit.py @@ -13,14 +13,12 @@ # limitations under the License. from pprint import pprint -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Optional from flax.deprecated import nn -from flax.deprecated.nn import initializers from dense import Dense from flax.linen import Module import jax -from jax import lax, numpy as jnp, random -import numpy as np +from jax import numpy as jnp # Add `in_features` to the built-in Dense layer that normally works diff --git a/examples/linen_design_test/mlp_inline.py b/examples/linen_design_test/mlp_inline.py index b631d19d83..3759695adb 100644 --- a/examples/linen_design_test/mlp_inline.py +++ b/examples/linen_design_test/mlp_inline.py @@ -13,12 +13,10 @@ # limitations under the License. import jax -from jax import numpy as jnp, random, lax +from jax import numpy as jnp from flax import linen as nn -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable from flax.linen import Module, compact -import numpy as np -from pprint import pprint from dense import Dense diff --git a/examples/linen_design_test/mlp_lazy.py b/examples/linen_design_test/mlp_lazy.py index 7e246917bf..cff15d564b 100644 --- a/examples/linen_design_test/mlp_lazy.py +++ b/examples/linen_design_test/mlp_lazy.py @@ -13,10 +13,9 @@ # limitations under the License. 
import jax -from jax import numpy as jnp, random, lax +from jax import numpy as jnp from flax import linen as nn from flax.linen import Module -import numpy as np from pprint import pprint from dense import Dense diff --git a/examples/linen_design_test/tied_autoencoder.py b/examples/linen_design_test/tied_autoencoder.py deleted file mode 100644 index a7dadeef19..0000000000 --- a/examples/linen_design_test/tied_autoencoder.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -from jax import numpy as jnp, random, lax -from flax import linen as nn -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union -from flax.linen import Module, compact -import numpy as np -from dense import Dense - - -# TODO(avital, levskaya): resurrect this example once interactive api is restored. - -# class TiedAutoEncoder(Module): -# def setup(self): -# self.encoder = Dense(features=4, use_bias=False) - -# @property -# def decoder(self): -# return self.encoder.detached().attached(variables={ -# 'params': {"kernel": self.encoder.variables['params']['kernel'].T}}) - -# def __call__(self, x): -# z = self.encoder(x) -# x = self.decoder(z) -# return x - -# tae = TiedAutoEncoder(parent=None) -# tae = tae.initialized( -# {'params': random.key(42)}, -# jnp.ones((1, 16))) -# print("reconstruct", jnp.shape(tae(jnp.ones((1, 16))))) -# print("var shapes", jax.tree_util.tree_map(jnp.shape, tae.variables)) diff --git a/examples/linen_design_test/weight_std.py b/examples/linen_design_test/weight_std.py deleted file mode 100644 index c384a00b91..0000000000 --- a/examples/linen_design_test/weight_std.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -import jax -from jax import numpy as jnp, random, lax, jit -from flax import linen as nn -from flax.core.scope import Scope -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union -from flax.linen import Module, compact -import numpy as np -from dense import Dense -from flax.core.frozen_dict import freeze, unfreeze, FrozenDict - - -def standardize(x, axis, eps=1e-8): - x = x - jnp.mean(x, axis=axis, keepdims=True) - x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps) - return x - - -# TODO(avital, levskaya): resurrect this example once interactive api is restored. - -# A wrapper that calls through a simple module with standardized parameters. 
-# -# Note that StdWeight is /not/ a module, hence it doesn't add another layer -# of depth in the variable dict (i.e. this is a "transparent module") -# @dataclass -# class StdWeight: -# module: Module - -# def __call__(self, x): -# # TODO: Think about how this modifies other state -# if not 'params' in self.module.variables: -# # initialize parameters -# self.module(x) - -# param = self.module.variables['params'] -# # Make a copy because `param` is (and should be) frozen. We're only transforming -# # the parameters, not mutating them. -# std_param = param.copy(kernel=standardize(param['kernel'], axis=[0, 1])) -# return self.module.clone(parent=None).apply({'params': std_param}, x) - -# class MyModule(Module): -# def __call__(self, x): -# module = Dense(self, 3) -# std_module = StdWeight(module) -# return std_module(x) - -# m_variables = MyModule().init({'params': jax.random.key(10)}, jnp.ones((1, 4))) -# print(m_variables) diff --git a/examples/seq2seq/models.py b/examples/seq2seq/models.py index 5870a62e90..6a9d43c350 100644 --- a/examples/seq2seq/models.py +++ b/examples/seq2seq/models.py @@ -17,13 +17,11 @@ # See issue #620. # pytype: disable=wrong-keyword-args -import functools -from typing import Any, Tuple +from typing import Tuple from flax import linen as nn import jax import jax.numpy as jnp -import numpy as np Array = jax.Array PRNGKey = jax.Array diff --git a/examples/sst2/models_test.py b/examples/sst2/models_test.py index c1a42a0c02..bea495d1ec 100644 --- a/examples/sst2/models_test.py +++ b/examples/sst2/models_test.py @@ -17,7 +17,6 @@ from absl.testing import parameterized import models import jax -from jax import numpy as jnp import jax.test_util import numpy as np diff --git a/examples/wmt/bleu.py b/examples/wmt/bleu.py index e12911dd6e..ac69cc956d 100644 --- a/examples/wmt/bleu.py +++ b/examples/wmt/bleu.py @@ -44,7 +44,6 @@ import unicodedata import numpy as np -import six class UnicodeRegex: diff --git a/examples/wmt/models.py b/examples/wmt/models.py index 0f2fd2f962..6ed08ccd23 100644 --- a/examples/wmt/models.py +++ b/examples/wmt/models.py @@ -299,7 +299,7 @@ def __call__( broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, - )(y, encoded, encoder_decoder_mask) + )(y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=config.dropout_rate)( y, deterministic=config.deterministic diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 8c5c5646f1..4e7f25fb23 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -15,7 +15,7 @@ """Frozen Dictionary.""" import collections -from typing import Any, Dict, Hashable, Optional, Mapping, Tuple, TypeVar, Union +from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union from types import MappingProxyType from flax import serialization diff --git a/flax/core/lift.py b/flax/core/lift.py index 4dba2655c2..42d7acbf18 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -641,18 +641,20 @@ def vmap( Args: fn: the function to be transformed. - variable_axes: the variable collections that are lifted into the - batching transformation. Use `None` to indicate a broadcasted - collection or an integer to map over an axis. - split_rngs: Split PRNG sequences will be different for each index - of the batch dimension. Unsplit PRNGs will be broadcasted. + variable_axes: the variable collections that are lifted into the batching + transformation. Use `None` to indicate a broadcasted collection or an + integer to map over an axis. 
+ split_rngs: Split PRNG sequences will be different for each index of the + batch dimension. Unsplit PRNGs will be broadcasted. in_axes: Specifies the mapping of the input arguments (see `jax.vmap). out_axes: Specifies the mapping of the return value (see `jax.vmap). - axis_size: Specifies the size of the batch axis. This only needs - to be specified if it cannot be derived from the input arguments. - axis_name: Specifies a name for the batch axis. Can be used together - with parallel reduction primitives (e.g. `jax.lax.pmean`, - `jax.lax.ppermute`, etc.) + axis_size: Specifies the size of the batch axis. This only needs to be + specified if it cannot be derived from the input arguments. + axis_name: Specifies a name for the batch axis. Can be used together with + parallel reduction primitives (e.g. `jax.lax.pmean`, `jax.lax.ppermute`, + etc.). Note, this is only used for pmap and shmap. For SPMD jit, you do + not need to manually synchronize. Just make sure that the axes are + correctly annotated and XLA:SPMD will insert the necessary collectives. spmd_axis_name: Axis name added to any pjit sharding constraints appearing in `fn`. See also https://github.com/google/flax/blob/main/flax/linen/partitioning.py. diff --git a/flax/core/meta.py b/flax/core/meta.py index b21f2a2ee6..b589f90bad 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -23,7 +23,7 @@ import abc import functools -from typing import Any, Callable, Dict, Generic, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, Union from flax import errors from flax import struct @@ -206,7 +206,7 @@ def __call__(self, x): mlp = MLP(4096) x = jnp.ones((8 * 1024, 1024)) # use eval_shape to get the Partitioned instances for the variables. - # this way we can determinte the PartitionSpecs for the init variables + # this way we can determine the PartitionSpecs for the init variables # before we call the init fn. var_spec = nn.get_partition_spec( jax.eval_shape(mlp.init, random.key(0), x)) diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py index bea380a2ea..212682e4f0 100644 --- a/flax/core/nn/attention.py +++ b/flax/core/nn/attention.py @@ -17,10 +17,7 @@ from collections.abc import Iterable # pylint: disable=g-importing-member import functools from typing import Any, Callable, Union -import warnings -from . import stochastic -from flax import jax_utils from flax import struct from flax.core import Scope from flax.linen import initializers diff --git a/flax/core/scope.py b/flax/core/scope.py index 6059dcb9ad..46819a76b7 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -45,7 +45,6 @@ from flax import traceback_util from flax.ids import uuid import jax -from jax import config as jax_config from jax import numpy as jnp from jax import random from jax import tree_util @@ -955,16 +954,13 @@ def param( """ self.reserve(name, 'params') if self.has_variable('params', name): - abs_rng = jax.ShapeDtypeStruct( - random.default_prng_impl().key_shape, jnp.uint32 - ) value = self.get_variable('params', name) # Validate that the shape of the init_fn output is the same as the shape # of the existing parameter. This is to make sure that the hparams set up # in a Flax Module match the shapes coming in during apply, and if not, # catch it with an error message. 
# NOTE: We could consider moving this to `self.` - abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng) + abs_value = jax.eval_shape(lambda: init_fn(random.key(0), *init_args)) abs_value_flat = jax.tree_util.tree_leaves(abs_value) value_flat = jax.tree_util.tree_leaves(value) for val, abs_val in zip(value_flat, abs_value_flat): @@ -1202,10 +1198,10 @@ def _is_valid_rng(rng: Array): return rng.shape == () # Handle old-style raw PRNG keys - if ( - rng.shape != random.default_prng_impl().key_shape - or rng.dtype != jnp.uint32 - ): + expected_rng = jax.eval_shape( + lambda s: jax.random.key_data(jax.random.key(s)), 0 + ) + if (rng.shape, rng.dtype) != (expected_rng.shape, expected_rng.dtype): return False return True diff --git a/flax/cursor.py b/flax/cursor.py index 8f7be74c85..d919782501 100644 --- a/flax/cursor.py +++ b/flax/cursor.py @@ -13,7 +13,7 @@ # limitations under the License. import enum -from typing import Any, Callable, Dict, Generator, Generic, Mapping, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable +from typing import Any, Callable, Dict, Generator, Generic, Mapping, Optional, Protocol, TypeVar, runtime_checkable from flax.core import FrozenDict from flax.errors import CursorFindError, TraverseTreeError import dataclasses diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 6df28ca67b..ac7e76995f 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -69,11 +69,7 @@ make_causal_mask as make_causal_mask, ) from .combinators import Sequential as Sequential -from .fp8_ops import ( - compute_scale as fp8_compute_scale, - quantize_dequantize as fp8_quantize_dequantize, - Fp8DenseGeneralOp as Fp8DenseGeneralOp, -) +from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp from .initializers import ( ones_init as ones_init, ones as ones, @@ -108,6 +104,7 @@ LayerNorm as LayerNorm, RMSNorm as RMSNorm, SpectralNorm as SpectralNorm, + WeightNorm as WeightNorm ) from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool) from .recurrent import ( diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 851434f15a..d6ca8c166d 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,7 +15,8 @@ """Attention core modules for Flax.""" import functools -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union, overload +import warnings from flax.linen import initializers from flax.linen.dtypes import promote_dtype @@ -33,7 +34,7 @@ import jax.numpy as jnp -PRNGKey = Any +PRNGKey = jax.Array Shape = Tuple[int, ...] Dtype = Any Array = Any @@ -243,36 +244,113 @@ class MultiHeadDotProductAttention(Module): decode: bool = False normalize_qk: bool = False # Deprecated, will be removed. - qkv_dot_general: DotGeneralT = lax.dot_general - out_dot_general: DotGeneralT = lax.dot_general + qkv_dot_general: Optional[DotGeneralT] = None + out_dot_general: Optional[DotGeneralT] = None qkv_dot_general_cls: Any = None out_dot_general_cls: Any = None + @overload + def __call__( + self, + inputs_q: Array, + inputs_k: Optional[Array] = None, + inputs_v: Optional[Array] = None, + *, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None, + ): + ... + + @overload + def __call__( + self, + inputs_q: Array, + *, + inputs_kv: Array = None, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None, + ): + ... 
+ @compact def __call__( self, inputs_q: Array, - inputs_kv: Array, + inputs_k: Optional[Array] = None, + inputs_v: Optional[Array] = None, + *, + inputs_kv: Optional[Array] = None, mask: Optional[Array] = None, deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None ): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. + If both inputs_k and inputs_v are None, they will both copy the value of + inputs_q (self attention). + If only inputs_v is None, it will copy the value of inputs_k. + Args: inputs_q: input queries of shape `[batch_sizes..., length, features]`. - inputs_kv: key/values of shape `[batch_sizes..., length, features]`. + inputs_k: key of shape `[batch_sizes..., length, features]`. If None, + inputs_k will copy the value of inputs_q. + inputs_v: values of shape `[batch_sizes..., length, features]`. If None, + inputs_v will copy the value of inputs_k. + inputs_kv: key/values of shape `[batch_sizes..., length, features]`. If + None, inputs_kv will copy the value of inputs_q. This arg will be + deprecated soon. Use inputs_k and inputs_v instead. mask: attention mask of shape `[batch_sizes..., num_heads, query_length, key/value_length]`. Attention weights are masked out if their corresponding mask value is `False`. deterministic: if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. + dropout_rng: optional rng key to pass to the attention layer's dropout + mask. Otherwise, self.make_rng('dropout') is used instead. Returns: output of shape `[batch_sizes..., length, features]`. """ + if inputs_kv is not None: + if inputs_k is not None or inputs_v is not None: + raise ValueError('If either `inputs_k` or `inputs_v` is not None, ' + '`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` ' + 'and `inputs_v` must be None. We recommend using `inputs_k` and ' + '`inputs_v` args, since `inputs_kv` will be deprecated soon. See ' + 'https://github.com/google/flax/discussions/3389 for more ' + 'information.') + inputs_k = inputs_v = inputs_kv + warnings.warn('The inputs_kv arg will be deprecated soon. ' + 'Use inputs_k and inputs_v instead. See ' + 'https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning) + else: + if inputs_k is None: + if inputs_v is not None: + raise ValueError('`inputs_k` cannot be None if `inputs_v` is not None. ' + 'To have both `inputs_k` and `inputs_v` be the same value, pass in the ' + 'value to `inputs_k` and leave `inputs_v` as None.') + inputs_k = inputs_q + if inputs_v is None: + inputs_v = inputs_k + elif inputs_v.shape[-1] == inputs_v.shape[-2]: + warnings.warn(f"You are passing an array of shape {inputs_v.shape} " + "to the `inputs_v` arg, when you may have intended " + "to pass it to the `mask` arg. As of Flax version " + "0.7.4, the function signature of " + "MultiHeadDotProductAttention's `__call__` method " + "has changed to `__call__(inputs_q, inputs_k=None, " + "inputs_v=None, *, inputs_kv=None, mask=None, " + "deterministic=None)`. Use the kwarg `mask` instead. 
" + "See https://github.com/google/flax/discussions/3389 " + "and read the docstring for more information.", + DeprecationWarning) + features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( @@ -298,8 +376,8 @@ def __call__( # dimensions are then [batch..., length, n_heads, n_features_per_head] query, key, value = ( dense(name='query')(inputs_q), - dense(name='key')(inputs_kv), - dense(name='value')(inputs_kv), + dense(name='key')(inputs_k), + dense(name='value')(inputs_v), ) if self.normalize_qk: @@ -361,14 +439,13 @@ def __call__( ), ) - dropout_rng = None if ( self.dropout_rate > 0.0 ): # Require `deterministic` only if using dropout. m_deterministic = merge_param( 'deterministic', self.deterministic, deterministic ) - if not m_deterministic: + if not m_deterministic and dropout_rng is None: dropout_rng = self.make_rng('dropout') else: m_deterministic = True @@ -412,6 +489,7 @@ def __call__( # type: ignore inputs_q: Array, mask: Optional[Array] = None, deterministic: Optional[bool] = None, + dropout_rng: Optional[PRNGKey] = None ): """Applies multi-head dot product self-attention on the input data. @@ -429,8 +507,13 @@ def __call__( # type: ignore Returns: output of shape `[batch_sizes..., length, features]`. """ + warnings.warn('SelfAttention will be deprecated soon. Use ' + '`MultiHeadDotProductAttention.__call__(inputs_q)` instead. ' + 'See https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning) return super().__call__( - inputs_q, inputs_q, mask, deterministic=deterministic + inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng ) diff --git a/flax/linen/dtypes.py b/flax/linen/dtypes.py index 463b0f23c4..bef29a2f02 100644 --- a/flax/linen/dtypes.py +++ b/flax/linen/dtypes.py @@ -30,7 +30,6 @@ from typing import Any, Optional, List from jax import numpy as jnp -import jax Dtype = Any diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py index 25ee5e6276..24fde145ba 100644 --- a/flax/linen/experimental/layers_with_named_axes.py +++ b/flax/linen/experimental/layers_with_named_axes.py @@ -73,7 +73,7 @@ class Dense(nn.Module): ) kernel_axes: Tuple[str, ...] = () # Deprecated. Will be removed. - dot_general: DotGeneralT = lax.dot_general + dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @nn.compact @@ -98,8 +98,10 @@ def __call__(self, inputs: Array) -> Array: if self.dot_general_cls is not None: dot_general = self.dot_general_cls() - else: + elif self.dot_general is not None: dot_general = self.dot_general + else: + dot_general = lax.dot_general y = dot_general( inputs, kernel, diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 575c1e1104..8199d58cb7 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -12,48 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable from functools import partial from flax.linen import initializers -from flax.linen.module import Module +from flax.linen import module from jax import custom_vjp from jax import lax from jax import numpy as jnp from jax import random -# Type annotations -Array = jnp.ndarray -Dtype = jnp.dtype -PRNGKey = jnp.ndarray +OVERWRITE_WITH_GRADIENT = '_overwrite_with_gradient' -class FP8Helper: - FP8_COLLECTION_NAME: str = "fp8_params" def get_fp8_max(fp8_dtype, out_dtype): assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2) return jnp.finfo(fp8_dtype).max.astype(out_dtype) + def quantize(x, q_dtype, scale, compute_dtype): - # We need to explicitly cast the max value to compute_dtype, otherwise the jax - # dtype promotion will cast the scaled_x to fp32 in the following ops, which - # would violate the fp8-matmul pattern matching. + # Explicitly cast the max values to the compute dtype to avoid unnecessary + # casting to FP32 during the subsequent math operations." dtype_max = get_fp8_max(q_dtype, compute_dtype) - scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape) - clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max) - return clipped_x.astype(q_dtype) + def dequantize(x, dq_dtype, scale): return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape) + def quantize_dequantize(x, q_dtype, scale, compute_dtype): qx = quantize(x, q_dtype, scale, compute_dtype) return dequantize(qx, x.dtype, scale) + def compute_scale(amax, scale, fp8_max, margin=0): """Default function to convert amax to scaling factor.""" # This function copied from the TransformerEngine is used to compute its @@ -67,38 +61,44 @@ def compute_scale(amax, scale, fp8_max, margin=0): sf = jnp.where(exp < 0, 1.0 / sf, sf) return 1.0 / sf + def compute_scale_and_amax_history(x, q_dtype, scale, amax_history): dtype_max = get_fp8_max(q_dtype, jnp.float32) - amax_update = jnp.max(jnp.abs(x)).astype(scale.dtype) - new_amax_history = \ - jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update) - - amax_from_history = jnp.max(new_amax_history, axis=0) + new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update) + amax_from_history = jnp.max(new_history, axis=0) new_scale = compute_scale(amax_from_history, scale, dtype_max) - return new_scale, new_amax_history + return new_scale, new_history + def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): qx = quantize_dequantize(x, q_dtype, scale, compute_dtype) - new_scale, new_amax_history = compute_scale_and_amax_history( - x, q_dtype, scale, amax_history) - return qx, new_scale, new_amax_history + new_scale, new_history = compute_scale_and_amax_history( + x, q_dtype, scale, amax_history + ) + return qx, new_scale, new_history + @partial(custom_vjp, nondiff_argnums=(0,)) def in_qdq(compute_dtype, inp, scale, amax_history): qin, _, _ = qdq_and_return( - inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype) + inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype + ) return qin + def in_qdq_fwd(compute_dtype, inp, scale, amax_history): - qin, new_scale, new_amax_history = qdq_and_return( - inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype) - return qin, (new_scale, new_amax_history) + qin, new_scale, new_history = qdq_and_return( + inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype + ) + return qin, (new_scale, new_history) + def in_qdq_bwd(compute_dtype, res, g): - new_scale, new_amax_history = res + new_scale, new_history = res q_g = g - return q_g, new_scale, 
new_amax_history + return q_g, new_scale, new_history + in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd) @@ -107,34 +107,23 @@ def in_qdq_bwd(compute_dtype, res, g): def out_qdq(compute_dtype, out, scale, amax_history): return out + def out_qdq_fwd(compute_dtype, out, scale, amax_history): return out, (scale, amax_history) + def out_qdq_bwd(compute_dtype, res, g): scale, amax_history = res - q_g, new_scale, new_amax_history = qdq_and_return( - g, jnp.float8_e5m2, scale, amax_history, compute_dtype) - return q_g, new_scale, new_amax_history - -out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) - -def fp8_dot_general(lhs, rhs, dimension_numbers, precision, compute_dtype, - lhs_scale, lhs_amax_history, rhs_scale, rhs_amax_history, - dout_scale, dout_amax_history): - """Perform dot_general. """ + q_g, new_scale, new_history = qdq_and_return( + g, jnp.float8_e5m2, scale, amax_history, compute_dtype + ) + return q_g, new_scale, new_history - lhs_qdq = in_qdq(compute_dtype, lhs, lhs_scale, lhs_amax_history) - rhs_qdq = in_qdq(compute_dtype, rhs, rhs_scale, rhs_amax_history) - - output_qdq = lax.dot_general(lhs_qdq, rhs_qdq, dimension_numbers, precision) - - out = out_qdq(compute_dtype, output_qdq, dout_scale, dout_amax_history) - - return out +out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) -class Fp8DenseGeneralOp(Module): +class Fp8DotGeneralOp(module.Module): amax_history_length: int = 1024 def setup(self) -> None: @@ -152,47 +141,46 @@ def setup(self) -> None: ) self.input_amax_history = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'input_amax_history', - *amax_history_args) + OVERWRITE_WITH_GRADIENT, 'input_amax_history', *amax_history_args) self.kernel_amax_history = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'kernel_amax_history', - *amax_history_args) + OVERWRITE_WITH_GRADIENT, 'kernel_amax_history', *amax_history_args) self.output_grad_amax_history = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'output_grad_amax_history', - *amax_history_args) + OVERWRITE_WITH_GRADIENT, 'output_grad_amax_history', *amax_history_args) self.input_scale = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'input_scale', - *scale_args) + OVERWRITE_WITH_GRADIENT, 'input_scale', *scale_args) self.kernel_scale = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'kernel_scale', - *scale_args) + OVERWRITE_WITH_GRADIENT, 'kernel_scale', *scale_args) self.output_grad_scale = self.variable( - FP8Helper.FP8_COLLECTION_NAME, - 'output_grad_scale', - *scale_args) + OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args) - def __call__(self, *args, **kwargs) -> Array: + def __call__(self, *args, **kwargs) -> jnp.ndarray: assert len(args) == 3 - inputs = args[0] - kernel = args[1] + x = args[0] + k = args[1] dimension_numbers = args[2] precision = kwargs['precision'] - comp_dtype = kernel.dtype - inputs = jnp.asarray(inputs, comp_dtype) - - out = fp8_dot_general(inputs, kernel, dimension_numbers, precision, - comp_dtype, self.input_scale.value, - self.input_amax_history.value, - self.kernel_scale.value, self.kernel_amax_history.value, - self.output_grad_scale.value, - self.output_grad_amax_history.value) - return out + + # Use the `k.dtype` since it aligns with the `dtype` of its layers, + # namely, the computation data type. 
+ comp_dtype = k.dtype + x = jnp.asarray(x, comp_dtype) + + x_qdq = in_qdq( + comp_dtype, x, self.input_scale.value, self.input_amax_history.value + ) + k_qdq = in_qdq( + comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value + ) + y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision) + y = out_qdq( + comp_dtype, + y_qdq, + self.output_grad_scale.value, + self.output_grad_amax_history.value + ) + + return y diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 4a23bd3c1c..77d3e65f87 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -34,7 +34,6 @@ import jax from jax import eval_shape from jax import lax -from jax import random from jax.core import ShapedArray import jax.numpy as jnp import numpy as np @@ -98,7 +97,7 @@ class DenseGeneral(Module): ) precision: PrecisionLike = None # Deprecated. Will be removed. - dot_general: DotGeneralT = lax.dot_general + dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @compact @@ -178,8 +177,10 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): if self.dot_general_cls is not None: dot_general = self.dot_general_cls() - else: + elif self.dot_general is not None: dot_general = self.dot_general + else: + dot_general = lax.dot_general out = dot_general( inputs, kernel, @@ -218,7 +219,7 @@ class Dense(Module): initializers.zeros_init() ) # Deprecated. Will be removed. - dot_general: DotGeneralT = lax.dot_general + dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @compact @@ -247,8 +248,10 @@ def __call__(self, inputs: Array) -> Array: if self.dot_general_cls is not None: dot_general = self.dot_general_cls() - else: + elif self.dot_general is not None: dot_general = self.dot_general + else: + dot_general = lax.dot_general y = dot_general( inputs, kernel, @@ -350,7 +353,7 @@ class _Conv(Module): initializers.zeros_init() ) # Deprecated. Will be removed. - conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated + conv_general_dilated: Optional[ConvGeneralDilatedT] = None conv_general_dilated_cls: Any = None @property @@ -466,8 +469,10 @@ def maybe_broadcast( # create the unshared convolution kernel. 
if self.conv_general_dilated_cls is not None: conv_general_dilated = self.conv_general_dilated_cls() - else: + elif self.conv_general_dilated is not None: conv_general_dilated = self.conv_general_dilated + else: + conv_general_dilated = lax.conv_general_dilated conv_output_shape = eval_shape( lambda lhs, rhs: conv_general_dilated( # pylint: disable=g-long-lambda lhs=lhs, @@ -517,8 +522,10 @@ def maybe_broadcast( if self.shared_weights: if self.conv_general_dilated_cls is not None: conv_general_dilated = self.conv_general_dilated_cls() - else: + elif self.conv_general_dilated is not None: conv_general_dilated = self.conv_general_dilated + else: + conv_general_dilated = lax.conv_general_dilated y = conv_general_dilated( inputs, kernel, diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 899351962f..daf79f0981 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -18,9 +18,9 @@ import functools from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union -from flax.linen.dtypes import canonicalize_dtype -from flax.linen.module import Module, compact, merge_param # pylint: disable=g-multiple-import -from flax.linen.transforms import map_variables +from flax.linen import dtypes +from flax.linen import module +from flax.linen import transforms import jax from jax import lax from jax.nn import initializers @@ -33,6 +33,13 @@ Dtype = Any # this could be a real type? Axes = Union[int, Sequence[int]] +field = dataclasses.field +canonicalize_dtype = dtypes.canonicalize_dtype +compact = module.compact +Module = module.Module +merge_param = module.merge_param +map_variables = transforms.map_variables + def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: """Returns a tuple of deduplicated, sorted, and positive axes.""" @@ -57,6 +64,7 @@ def _compute_stats( axis_index_groups: Any = None, use_mean: bool = True, use_fast_variance: bool = True, + mask: Optional[Array] = None, ): """Computes mean and variance statistics. @@ -75,13 +83,18 @@ def _compute_stats( axes: The axes in ``x`` to compute mean and variance statistics for. dtype: Optional dtype specifying the minimal precision. Statistics are always at least float32 for stability (default: dtype of x). - axis_name: Optional name for the pmapped axis to compute mean over. + axis_name: Optional name for the pmapped axis to compute mean over. Note, + this is only used for pmap and shard map. For SPMD jit, you do not need to + manually synchronize. Just make sure that the axes are correctly annotated + and XLA:SPMD will insert the necessary collectives. axis_index_groups: Optional axis indices. use_mean: If true, calculate the mean from the input and use it when computing the variance. If false, set the mean to zero and compute the variance without subtracting the mean. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. + mask: Binary array of shape broadcastable to `inputs` tensor, indicating + the positions for which the mean and variance should be computed. Returns: A pair ``(mean, var)``. 
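
The new `mask` argument documented above is threaded through `_compute_stats` and exposed on the normalization layers (see the `BatchNorm.__call__` change in the next hunk). A minimal usage sketch with illustrative shapes, assuming a padded-sequence batch whose statistics should ignore the padded positions:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 5, 8))                    # [batch, length, features]
mask = (jnp.arange(5) < 3)[None, :, None]  # True for valid positions; broadcastable to x

bn = nn.BatchNorm(use_running_average=False)
variables = bn.init(jax.random.key(0), x, mask=mask)
# Batch statistics are computed only over positions where `mask` is True.
y, updates = bn.apply(variables, x, mask=mask, mutable=['batch_stats'])
```
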
@@ -94,8 +107,8 @@ def _compute_stats( x = jnp.asarray(x, dtype) axes = _canonicalize_axes(x.ndim, axes) - def maybe_distributed_mean(*xs): - mus = tuple(x.mean(axes) for x in xs) + def maybe_distributed_mean(*xs, mask=None): + mus = tuple(x.mean(axes, where=mask) for x in xs) if axis_name is None: return mus if len(xs) > 1 else mus[0] else: @@ -112,15 +125,17 @@ def maybe_distributed_mean(*xs): if use_mean: if use_fast_variance: - mu, mu2 = maybe_distributed_mean(x, _abs_sq(x)) + mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. var = jnp.maximum(0.0, mu2 - _abs_sq(mu)) else: - mu = maybe_distributed_mean(x) - var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes))) + mu = maybe_distributed_mean(x, mask=mask) + var = maybe_distributed_mean( + _abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask + ) else: - var = maybe_distributed_mean(_abs_sq(x)) + var = maybe_distributed_mean(_abs_sq(x), mask=mask) mu = jnp.zeros_like(var) return mu, var @@ -188,7 +203,7 @@ def _normalize( ).reshape(feature_shape) y += bias args.append(bias) - dtype = canonicalize_dtype(*args, dtype=dtype) + dtype = dtypes.canonicalize_dtype(*args, dtype=dtype) return jnp.asarray(y, dtype) @@ -257,6 +272,9 @@ class BatchNorm(Module): scale_init: initializer for scale, by default, one. axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). + Note, this is only used for pmap and shard map. For SPMD jit, you do not + need to manually synchronize. Just make sure that the axes are correctly + annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the @@ -281,7 +299,7 @@ class BatchNorm(Module): use_fast_variance: bool = True @compact - def __call__(self, x, use_running_average: Optional[bool] = None): + def __call__(self, x, use_running_average: Optional[bool] = None, mask=None): """Normalizes the input using batch statistics. NOTE: @@ -295,12 +313,14 @@ def __call__(self, x, use_running_average: Optional[bool] = None): x: the input to be normalized. use_running_average: if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input. + mask: Binary array of shape broadcastable to `inputs` tensor, indicating + the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ - use_running_average = merge_param( + use_running_average = module.merge_param( 'use_running_average', self.use_running_average, use_running_average ) feature_axes = _canonicalize_axes(x.ndim, self.axis) @@ -327,6 +347,7 @@ def __call__(self, x, use_running_average: Optional[bool] = None): axis_name=self.axis_name if not self.is_initializing() else None, axis_index_groups=self.axis_index_groups, use_fast_variance=self.use_fast_variance, + mask=mask, ) if not self.is_initializing(): @@ -375,7 +396,10 @@ class LayerNorm(Module): axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the - array being normalized is sharded across devices within a pmap. 
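Not part of the patch — a small sketch of the new `mask` argument threaded through `_compute_stats` and `BatchNorm`: masked-out positions (for example padding) are excluded from the statistics, matching `jnp.mean(..., where=mask)` semantics.

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.key(0), (4, 3, 2))  # (batch, length, features)
mask = jnp.ones((4, 3, 1), dtype=bool).at[:, -1, :].set(False)  # ignore last position

bn = nn.BatchNorm(use_running_average=False, momentum=0.9)
y, variables = bn.init_with_output(jax.random.key(1), x, mask=mask)

# Statistics are computed only over unmasked positions.
assert jnp.allclose(y.mean((0, 1), where=mask), 0.0, atol=1e-4)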
+ array being normalized is sharded across devices within a pmap or shard + map. For SPMD jit, you do not need to manually synchronize. Just make sure + that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the @@ -466,7 +490,10 @@ class RMSNorm(Module): axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the - array being normalized is sharded across devices within a pmap. + array being normalized is sharded across devices within a pmap or shard + map. For SPMD jit, you do not need to manually synchronize. Just make sure + that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the @@ -546,7 +573,10 @@ class GroupNorm(Module): axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the - array being normalized is sharded across devices within a pmap. + array being normalized is sharded across devices within a pmap or shard + map. For SPMD jit, you do not need to manually synchronize. Just make sure + that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the @@ -644,8 +674,9 @@ def __call__(self, x): class SpectralNorm(Module): - """Spectral normalization. See: + """Spectral normalization. + See: - https://arxiv.org/abs/1802.05957 - https://arxiv.org/abs/1805.08318 - https://arxiv.org/abs/1809.11096 @@ -742,10 +773,10 @@ def __call__(self, *args, update_stats: bool, **kwargs): Args: *args: positional arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. - update_stats: if True, update the internal ``u`` vector and ``sigma`` value - after computing their updated values using power iteration. This will help - the power iteration method approximate the true singular value more - accurately over time. + update_stats: if True, update the internal ``u`` vector and ``sigma`` + value after computing their updated values using power iteration. This + will help the power iteration method approximate the true singular value + more accurately over time. **kwargs: keyword arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. 
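Illustrative only — a sketch of the distinction these docstrings draw: `axis_name` is needed when the batch is split with `pmap` (or `shard_map`) so statistics are combined across devices, whereas under `jit` with sharding annotations the collectives are inserted automatically. The sketch runs on a single device as well.

import functools
import jax
import jax.numpy as jnp
import flax.linen as nn

bn = nn.BatchNorm(use_running_average=False, axis_name='batch')
n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 8, 4))
rngs = jax.random.split(jax.random.key(0), n_dev)

# Per-device init/apply; the shared 'batch' axis_name makes mean/var collective.
variables = jax.pmap(bn.init, axis_name='batch')(rngs, x)
y, updates = jax.pmap(
    functools.partial(bn.apply, mutable=['batch_stats']), axis_name='batch'
)(variables, x)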
@@ -756,12 +787,11 @@ def __call__(self, *args, update_stats: bool, **kwargs): def layer_forward(layer_instance): return layer_instance(*args, **kwargs) - return map_variables( + return transforms.map_variables( layer_forward, trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( functools.partial( self._spectral_normalize, - layer_instance_name=self.layer_instance.name, update_stats=update_stats, ), vs, @@ -770,7 +800,7 @@ def layer_forward(layer_instance): mutable=True, )(self.layer_instance) - def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): + def _spectral_normalize(self, path, vs, update_stats): """Compute the largest singular value using power iteration and normalize the variables ``vs`` using this value. This is intended to be a helper function used in this Module's ``__call__`` method in conjunction with @@ -779,8 +809,6 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): Args: path: dict key path, used for naming the ``u`` and ``sigma`` variables vs: variables to be spectral normalized - layer_instance_name: name of the underlying ``self.layer_instance``, - used for naming the ``u`` and ``sigma`` variables update_stats: if True, update the ``u`` vector and ``sigma`` variables after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value @@ -789,7 +817,8 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): value = jnp.asarray(vs) value_shape = value.shape - # Skip and return value if input is scalar, vector or if number of power iterations is less than 1 + # Skip and return value if input is scalar, vector or if number of power + # iterations is less than 1 if value.ndim <= 1 or self.n_steps < 1: return value # Handle higher-order tensors. @@ -802,7 +831,7 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): value = jnp.reshape(value, (-1, value.shape[-1])) u_var_name = ( - layer_instance_name + self.layer_instance.name + '/' + '/'.join((dict_key.key for dict_key in path[1:])) + '/u' @@ -819,7 +848,7 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): ) u0 = u_var.value sigma_var_name = ( - layer_instance_name + self.layer_instance.name + '/' + '/'.join((dict_key.key for dict_key in path[1:])) + '/sigma' @@ -847,5 +876,182 @@ def _spectral_normalize(self, path, vs, layer_instance_name, update_stats): u_var.value = u0 sigma_var.value = sigma - dtype = canonicalize_dtype(vs, u0, v0, sigma, dtype=self.dtype) + dtype = dtypes.canonicalize_dtype(vs, u0, v0, sigma, dtype=self.dtype) + return jnp.asarray(value_bar, dtype) + + +class WeightNorm(Module): + """L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf). + + Weight normalization normalizes the weight params so that the l2-norm of + the matrix is equal to 1. This is implemented as a layer wrapper where + each wrapped layer will have its params l2-normalized before computing + its ``__call__`` output. 
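Not part of the patch — a minimal usage sketch for `SpectralNorm` as refactored here (the `u`/`sigma` power-iteration state is keyed by the wrapped layer's variable path and lives in the `batch_stats` collection):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Discriminator(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    x = nn.SpectralNorm(nn.Dense(8))(x, update_stats=train)
    return nn.Dense(1)(x)

x = jnp.ones((2, 4))
model = Discriminator()
variables = model.init(jax.random.key(0), x, train=False)
# e.g. variables['batch_stats']['SpectralNorm_0']['Dense_0/kernel/u']
y, updates = model.apply(variables, x, train=True, mutable=['batch_stats'])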
+ + Example:: + + class Baz(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(2)(x) + + class Bar(nn.Module): + @nn.compact + def __call__(self, x): + x = Baz()(x) + x = nn.Dense(3)(x) + x = Baz()(x) + x = nn.Dense(3)(x) + return x + + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(3)(x) + # l2-normalize all params of the second Dense layer + x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x) + x = nn.Dense(5)(x) + # l2-normalize all kernels in the Bar submodule and all params in the + # Baz submodule + x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x) + return x + + # init + x = jnp.ones((1, 2)) + model = Foo() + variables = model.init(jax.random.key(0), x) + + variables + # { + # params: { + # ... + # WeightNorm_0: { + # Dense_1/bias/scale: Array([1., 1., 1., 1.], dtype=float32), + # Dense_1/kernel/scale: Array([1., 1., 1., 1.], dtype=float32), + # }, + # ... + # WeightNorm_1: { + # Bar_0/Baz_0/Dense_0/bias/scale: Array([1., 1.], dtype=float32), + # Bar_0/Baz_0/Dense_0/kernel/scale: Array([1., 1.], dtype=float32), + # Bar_0/Baz_1/Dense_0/bias/scale: Array([1., 1.], dtype=float32), + # Bar_0/Baz_1/Dense_0/kernel/scale: Array([1., 1.], dtype=float32), + # Bar_0/Dense_0/kernel/scale: Array([1., 1., 1.], dtype=float32), + # Bar_0/Dense_1/kernel/scale: Array([1., 1., 1.], dtype=float32), + # }, + # ... + # } + # } + + Attributes: + layer_instance: Module instance that is wrapped with WeightNorm + epsilon: A small float added to l2-normalization to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_scale: If True, creates a learnable variable ``scale`` that is + multiplied to the ``layer_instance`` variables after l2-normalization. + scale_init: Initialization function for the scaling function. + feature_axes: The feature axes dimension(s). The l2-norm is calculated by + reducing the ``layer_instance`` variables over the remaining (non-feature) + axes. Therefore a separate l2-norm value is calculated and a separate + scale (if ``use_scale=True``) is learned for each specified feature. By + default, the trailing dimension is treated as the feature axis. + variable_filter: An optional iterable that contains string items. The + WeightNorm layer will selectively apply l2-normalization to the + ``layer_instance`` variables whose key path (delimited by '/') has a match + with ``variable_filter``. For example, ``variable_filter={'kernel'}`` will + only apply l2-normalization to variables whose key path contains 'kernel'. + By default, ``variable_filter={'kernel'}``. + """ + + layer_instance: Module + epsilon: float = 1e-12 + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + use_scale: bool = True + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + feature_axes: Optional[Axes] = -1 + variable_filter: Optional[Iterable] = dataclasses.field( + default_factory=lambda: {'kernel'} + ) + + @compact + def __call__(self, *args, **kwargs): + """Compute the l2-norm of the weights in ``self.layer_instance`` + and normalize the weights using this value before computing the + ``__call__`` output. + + Args: + *args: positional arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. + **kwargs: keyword arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. 
+ + Returns: + Output of the layer using l2-normalized weights. + """ + + def layer_forward(layer_instance): + return layer_instance(*args, **kwargs) + + return transforms.map_variables( + layer_forward, + trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( + self._l2_normalize, + vs, + ), + init=self.is_initializing(), + )(self.layer_instance) + + def _l2_normalize(self, path, vs): + """Compute the l2-norm and normalize the variables ``vs`` using this + value. This is intended to be a helper function used in this Module's + ``__call__`` method in conjunction with ``nn.transforms.map_variables`` + and ``jax.tree_util.tree_map_with_path``. + + Args: + path: dict key path, used for naming the ``scale`` variable + vs: variables to be l2-normalized + """ + value = jnp.asarray(vs) + str_path = ( + self.layer_instance.name + + '/' + + '/'.join((dict_key.key for dict_key in path[1:])) + ) + if self.variable_filter: + for variable_name in self.variable_filter: + if variable_name in str_path: + break + else: + return value + + if self.feature_axes is None: + feature_axes = () + reduction_axes = tuple(i for i in range(value.ndim)) + else: + feature_axes = _canonicalize_axes(value.ndim, self.feature_axes) + reduction_axes = tuple( + i for i in range(value.ndim) if i not in feature_axes + ) + + feature_shape = [1] * value.ndim + reduced_feature_shape = [] + for ax in feature_axes: + feature_shape[ax] = value.shape[ax] + reduced_feature_shape.append(value.shape[ax]) + + value_bar = _l2_normalize(value, axis=reduction_axes, eps=self.epsilon) + + args = [vs] + if self.use_scale: + scale = self.param( + str_path + '/scale', + self.scale_init, + reduced_feature_shape, + self.param_dtype, + ).reshape(feature_shape) + value_bar *= scale + args.append(scale) + + dtype = dtypes.canonicalize_dtype(*args, dtype=self.dtype) return jnp.asarray(value_bar, dtype) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 7b39db5998..3aac6563f6 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -20,7 +20,7 @@ from abc import ABCMeta from functools import partial # pylint: disable=g-importing-member -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, TypeVar, Union from absl import logging from flax.core import lift from flax.core.frozen_dict import FrozenDict @@ -130,7 +130,7 @@ class LSTMCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): Attributes: features: number of output features. - gate_fn: activation function used for gates (default: sigmoid) + gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform @@ -406,8 +406,8 @@ class GRUCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): .. math:: \begin{array}{ll} - r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\ - z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\ + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array} @@ -415,7 +415,8 @@ class GRUCell(RNNCellBase, metaclass=RNNCellCompatibilityMeta): where x is the input and h, is the output of the previous time step. Attributes: - gate_fn: activation function used for gates (default: sigmoid) + features: number of output features. 
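Roughly what the `_l2_normalize` step above computes for a 2-D kernel with the default `feature_axes=-1` (the epsilon placement here is approximate); with `use_scale=True` the result is then multiplied by a learned per-feature `scale`:

import jax.numpy as jnp

w = jnp.arange(12.0, dtype=jnp.float32).reshape(3, 4)  # (in_features, out_features)
eps = 1e-12
norm = jnp.sqrt(jnp.sum(jnp.square(w), axis=0, keepdims=True) + eps)  # reduce non-feature axes
w_bar = w / norm
assert jnp.allclose(jnp.linalg.norm(w_bar, axis=0), 1.0)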
+ gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 4d0b19d5b4..fea7c75c98 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for working with pjit and partitioned models. +"""Utilities for working with jit and partitioned models. This module introduces `axis_rules`, `logical_to_mesh_axes`, -`logical_to_mesh`, `with_logical_constraint` for appyling pjit -sharding constraints in terms of "logical named axes" rather than -pjit's default mesh axes. +`logical_to_mesh`, `with_logical_constraint` for appyling jit sharding +constraints in terms of "logical named axes" rather than jit's default mesh +axes. Additionally the `LogicallyPartitioned` metadata wrapper is defined as well as the initializer function wrapper `with_logical_partitioning` for @@ -38,8 +38,6 @@ from flax import struct from flax.core import meta -from flax.core.lift import In as ScanIn # pylint: disable=unused-import -from flax.core.lift import Out as ScanOut # pylint: disable=unused-import # Real types and dummy aliases for documentation LogicalRules = Sequence[Tuple[str, Union[str, Tuple[str], None]]] @@ -208,7 +206,7 @@ def logical_to_mesh_sharding( def _global_mesh_defined() -> bool: - """Checks if global xmap/pjit mesh resource environment is defined.""" + """Checks if global xmap/jit mesh resource environment is defined.""" maps_env = maps.thread_resources.env return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison @@ -226,7 +224,7 @@ def _with_sharding_constraint( axis_resources: Optional[jax.sharding.PartitionSpec], mesh: Optional[jax.sharding.Mesh] = None, ): - """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside pjit.""" + """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit.""" if jax.devices()[0].platform == 'cpu' or ( not _global_mesh_defined() and mesh is None ): @@ -276,7 +274,7 @@ def with_logical_constraint( mesh: Optional[jax.sharding.Mesh] = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, ): - """Version of pjit's with_sharding_constraint that uses logical axis names.""" + """Version of jit's with_sharding_constraint that uses logical axis names.""" # If no axis binding is set, this is a no-op. if rules is None: rules = _axis_rules.rules diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py index 9845f3c3d0..5ed47654aa 100644 --- a/flax/linen/stochastic.py +++ b/flax/linen/stochastic.py @@ -14,7 +14,7 @@ """Stochastic modules.""" -from typing import Optional, Sequence, Union +from typing import Optional, Sequence import jax @@ -32,10 +32,19 @@ class Dropout(Module): """Create a dropout layer. Note: When using :meth:`Module.apply() `, make sure - to include an RNG seed named `'dropout'`. For example:: - - model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': - dropout_rng})` + to include an RNG seed named `'dropout'`. Dropout isn't necessary for + variable initialization. 
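Illustrative only — the kind of usage the `spmd.py` docstrings below describe; on CPU without a mesh the constraint is a no-op, as those docstrings note. The rule names 'batch'/'embed'/'data' are placeholders.

import jax.numpy as jnp
import flax.linen as nn

rules = (('batch', 'data'), ('embed', None))  # logical name -> mesh axis (or None)
with nn.logical_axis_rules(rules):
  x = jnp.ones((8, 16))
  x = nn.with_logical_constraint(x, ('batch', 'embed'))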
Example:: + + class MLP(nn.Module): + @nn.compact + def __call__(self, x, train): + x = nn.Dense(4)(x) + x = nn.Dropout(0.5, deterministic=not train)(x) + return x + model = MLP() + x = jnp.ones((1, 3)) + variables = model.init(jax.random.key(0), x, train=False) # don't use dropout + model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout Attributes: rate: the dropout probability. (_not_ the keep rate!) diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index adaf7822fc..272ff45668 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -537,20 +537,22 @@ def vmap( RNG must also be shared. Args: - target: a ``Module`` or a function taking a ``Module`` - as its first argument. - variable_axes: the variable collections that are lifted into the - batching transformation. Use `None` to indicate a broadcasted - collection or an integer to map over an axis. - split_rngs: Split PRNG sequences will be different for each index - of the batch dimension. Unsplit PRNGs will be broadcasted. + target: a ``Module`` or a function taking a ``Module`` as its first + argument. + variable_axes: the variable collections that are lifted into the batching + transformation. Use `None` to indicate a broadcasted collection or an + integer to map over an axis. + split_rngs: Split PRNG sequences will be different for each index of the + batch dimension. Unsplit PRNGs will be broadcasted. in_axes: Specifies the mapping of the input arguments (see `jax.vmap`). out_axes: Specifies the mapping of the return value (see `jax.vmap`). - axis_size: Specifies the size of the batch axis. This only needs - to be specified if it cannot be derived from the input arguments. - axis_name: Specifies a name for the batch axis. Can be used together - with parallel reduction primitives (e.g. `jax.lax.pmean`, - `jax.lax.ppermute`, etc.) + axis_size: Specifies the size of the batch axis. This only needs to be + specified if it cannot be derived from the input arguments. + axis_name: Specifies a name for the batch axis. Can be used together with + parallel reduction primitives (e.g. `jax.lax.pmean`, `jax.lax.ppermute`, + etc.). Note, this is only used for pmap and shard map. For SPMD jit, you + do not need to manually synchronize. Just make sure that the axes are + correctly annotated and XLA:SPMD will insert the necessary collectives. methods: If `target` is a `Module`, the methods of `Module` to vmap over. spmd_axis_name: Axis name added to any pjit sharding constraints appearing in `fn`. See also diff --git a/flax/struct.py b/flax/struct.py index f9c299be78..f3c88a7274 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -15,7 +15,7 @@ """Utilities for defining custom classes that can be used with jax transformations.""" import dataclasses -from typing import TypeVar, Callable, Tuple, Union, Any +from typing import TypeVar from . 
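Not part of the patch — a short sketch of the `nn.vmap` usage the docstring above describes, here an ensemble of Dense layers with mapped parameters and split parameter RNGs:

import jax
import jax.numpy as jnp
import flax.linen as nn

VmapDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},   # one parameter set per ensemble member
    split_rngs={'params': True},   # different init RNG per member
)

x = jnp.ones((3, 2, 4))  # (members, batch, features)
model = VmapDense(features=5)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)  # (3, 2, 5)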
import serialization diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 792671f3ec..1076fb2e9a 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -40,24 +40,13 @@ from jax import monitoring from jax import process_index from jax import tree_util as jtu +from jax.experimental.array_serialization.serialization import get_tensorstore_spec +from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager from jax.experimental.multihost_utils import sync_global_devices -import numpy as np import orbax.checkpoint as ocp _READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec' _WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec' -_IMPORT_GDAM_SUCCESSFUL = False -try: - from jax.experimental.array_serialization.serialization import get_tensorstore_spec - from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager - - _IMPORT_GDAM_SUCCESSFUL = True -except ImportError: - logging.warning( - 'GlobalAsyncCheckpointManager is not imported correctly. ' - 'Checkpointing of GlobalDeviceArrays will not be available.' - 'To use the feature, install tensorstore.' - ) # Single-group reg-exps for int or float numerical substrings. @@ -262,7 +251,7 @@ def _restore_mpas( target: Optional[Any], ckpt_path: str, step: Optional[Union[int, float]], - gda_manager: Optional[Any], + gda_manager: Optional[GlobalAsyncCheckpointManager], allow_partial: bool = False, ): """Restore the multiprocess arrays given the target structure and type.""" @@ -631,6 +620,7 @@ def save_checkpoint( Returns: Filename of saved checkpoint. """ + jax.monitoring.record_event('/jax/flax/checkpoint/save') start_time = time.time() # Make sure all saves are finished before the logic of checking and removing # outdated checkpoints happens. @@ -740,7 +730,7 @@ def save_checkpoint_multiprocess( overwrite: bool = False, keep_every_n_steps: Optional[int] = None, async_manager: Optional[AsyncManager] = None, - gda_manager: Optional[Any] = None, + gda_manager: Optional[GlobalAsyncCheckpointManager] = None, orbax_checkpointer: Optional[ocp.Checkpointer] = None, ) -> str: """Save a checkpoint of the model in multi-process environment. @@ -768,19 +758,20 @@ def save_checkpoint_multiprocess( async_manager: if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly. - gda_manager: required if target contains a JAX GlobalDeviceArray. Type - should be GlobalAsyncCheckpointManager (needs Tensorstore to be imported - correctly). Will save the GDAs to a separate subdirectory with postfix - "_gda" asynchronously. Same as async_manager, this will block subsequent - saves. + gda_manager: required if target contains a JAX GlobalDeviceArray. Will save + the GDAs to a separate subdirectory with postfix "_gda" asynchronously. + Same as async_manager, this will block subsequent saves. orbax_checkpointer: if defined, the save will be done by Orbax In the - future, all Flax checkpointing features will be migrated to Orbax, - and starting to use an `orbax_checkpointer` is recommended. Please - check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers. + future, all Flax checkpointing features will be migrated to Orbax, and + starting to use an `orbax_checkpointer` is recommended. 
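Not part of the patch — a minimal sketch of the recommended Orbax-backed path through `save_checkpoint`/`restore_checkpoint` (the directory and pytree are hypothetical; the checkpoint directory must be an absolute path):

import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax.training import checkpoints

target = {'params': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}}
checkpointer = ocp.PyTreeCheckpointer()
checkpoints.save_checkpoint(
    ckpt_dir='/tmp/flax_ckpts', target=target, step=0,
    orbax_checkpointer=checkpointer)
restored = checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/flax_ckpts', target=target,
    orbax_checkpointer=checkpointer)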
Please check out + the checkpointing guide + (https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints) + for how to use Orbax checkpointers. Returns: Filename of saved checkpoint. """ + jax.monitoring.record_event('/jax/flax/checkpoint/save') start_time = time.time() # Make sure all saves are finished before the logic of checking and removing # outdated checkpoints happens. @@ -850,7 +841,7 @@ def save_checkpoint_multiprocess( target = serialization.to_state_dict(target) target, mpa_targets = _split_mp_arrays(target) target = serialization.msgpack_serialize(target) - has_mpa = mpa_targets and _IMPORT_GDAM_SUCCESSFUL + has_mpa = bool(mpa_targets) if not overwrite: _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore @@ -989,7 +980,7 @@ def restore_checkpoint( step: Optional[Union[int, float]] = None, prefix: str = 'checkpoint_', parallel: bool = True, - gda_manager: Optional[Any] = None, + gda_manager: Optional[GlobalAsyncCheckpointManager] = None, allow_partial_mpa_restoration: bool = False, orbax_checkpointer: Optional[ocp.Checkpointer] = None, orbax_transforms: Optional[Dict] = None, @@ -1014,9 +1005,8 @@ def restore_checkpoint( prefix: str: name prefix of checkpoint files. parallel: bool: whether to load seekable checkpoints in parallel, for speed. gda_manager: required if checkpoint contains a multiprocess array - (GlobalDeviceArray or jax Array from pjit). Type should be - GlobalAsyncCheckpointManager (needs Tensorstore to be imported correctly). - Will read the arrays from the separate subdirectory with postfix "_gda". + (GlobalDeviceArray or jax Array from pjit). Will read the arrays from the + separate subdirectory with postfix "_gda". allow_partial_mpa_restoration: If true, the given `target` doesn't have to contain all valid multiprocess arrays. As a result, the restored Pytree may have some MPAs not restored correctly. Use this if you cannot provide @@ -1034,6 +1024,7 @@ def restore_checkpoint( returned. This is to match the behavior of the case where a directory path is specified but the directory has not yet been created. """ + jax.monitoring.record_event('/jax/flax/checkpoint/restore') start_time = time.time() # Make sure any previous work is done before checking files. if orbax_checkpointer and isinstance( @@ -1126,15 +1117,14 @@ def read_chunk(i): checkpoint_contents = fp.read() state_dict = serialization.msgpack_restore(checkpoint_contents) - if _IMPORT_GDAM_SUCCESSFUL: - state_dict = _restore_mpas( - state_dict, - target, - ckpt_path, - step, - gda_manager, - allow_partial_mpa_restoration, - ) + state_dict = _restore_mpas( + state_dict, + target, + ckpt_path, + step, + gda_manager, + allow_partial_mpa_restoration, + ) if target is None: restored_checkpoint = state_dict diff --git a/flax/training/dynamic_scale.py b/flax/training/dynamic_scale.py index 3ed93004c8..b7b329a7f2 100644 --- a/flax/training/dynamic_scale.py +++ b/flax/training/dynamic_scale.py @@ -101,16 +101,20 @@ def value_and_grad( Args: fun: Function to be differentiated. Its arguments at positions specified by ``argnums`` should be arrays, scalars, or standard Python containers. - It should return a scalar (which includes arrays with shape ``()`` - but not arrays with shape ``(1,)`` etc.) + It should return a scalar (which includes arrays with shape ``()`` but + not arrays with shape ``(1,)`` etc.) argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0). 
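Not part of the patch — the documented `DynamicScale.value_and_grad` idiom, for reference (`loss_fn`, `params` and `x` are placeholders); the wrapper scales the loss, unscales the gradients, and reports whether they are finite:

import jax
import jax.numpy as jnp
from flax.training import dynamic_scale as dynamic_scale_lib

def loss_fn(params, x):
  # In practice the forward pass would run in float16/bfloat16.
  return jnp.sum((x @ params['w']) ** 2)

params = {'w': jnp.ones((4, 2), jnp.float32)}
x = jnp.ones((8, 4), jnp.float32)

dyn_scale = dynamic_scale_lib.DynamicScale()
dyn_scale, is_finite, loss, grads = dyn_scale.value_and_grad(loss_fn)(params, x)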
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function - to be differentiated and the second element is auxiliary data. - Default False. + to be differentiated and the second element is auxiliary data. Default + False. axis_name: If an axis is given the gradients will be averaged across - replicas (default: None). + replicas (default: None). Note, this is only used for pmap and shard + map. For SPMD jit, you do not need to manually synchronize. Just make + sure that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. + Returns: A function that takes the same arguments as `fun` and returns a DynamicScaleResult diff --git a/flax/training/train_state.py b/flax/training/train_state.py index 81c7a1ca4f..9570ebe51b 100644 --- a/flax/training/train_state.py +++ b/flax/training/train_state.py @@ -16,6 +16,7 @@ from flax import core from flax import struct +from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT import optax @@ -71,8 +72,27 @@ def apply_gradients(self, *, grads, **kwargs): and `opt_state` updated by applying `grads`, and additional attributes replaced as specified by `kwargs`. """ - updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params) - new_params = optax.apply_updates(self.params, updates) + if OVERWRITE_WITH_GRADIENT in grads: + grads_with_opt = grads['params'] + params_with_opt = self.params['params'] + else: + grads_with_opt = grads + params_with_opt = self.params + + updates, new_opt_state = self.tx.update( + grads_with_opt, self.opt_state, params_with_opt + ) + new_params_with_opt = optax.apply_updates(params_with_opt, updates) + + # As implied by the OWG name, the gradients are used directly to update the + # parameters. + if OVERWRITE_WITH_GRADIENT in grads: + new_params = { + 'params': new_params_with_opt, + OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT] + } + else: + new_params = new_params_with_opt return self.replace( step=self.step + 1, params=new_params, @@ -83,45 +103,11 @@ def apply_gradients(self, *, grads, **kwargs): @classmethod def create(cls, *, apply_fn, params, tx, **kwargs): """Creates a new instance with `step=0` and initialized `opt_state`.""" - opt_state = tx.init(params) - return cls( - step=0, - apply_fn=apply_fn, - params=params, - tx=tx, - opt_state=opt_state, - **kwargs, - ) - -class Fp8TrainState(TrainState): - """Customized train state for Fp8.""" - - def apply_gradients(self, *, grads, **kwargs): - assert 'fp8_params' in grads - updates, new_opt_state = self.tx.update(grads['params'], self.opt_state, - self.params['params']) - new_non_fp8_params = optax.apply_updates(self.params['params'], updates) - - # self.param is structured as - # {'param': {'kernel:...,'}, 'fp8_params': {...}}. For the fp8 variables - # in the fp8-params collection, we will simply replace them with their - # grads, because their grads are actually new values defined in the - # custom_vjp functions. - new_params = {'params': new_non_fp8_params, - 'fp8_params': grads['fp8_params']} - - return self.replace( - step=self.step + 1, - params=new_params, - opt_state=new_opt_state, - **kwargs, + # We exclude OWG params when present because they do not need opt states. 
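Not part of the patch — the unchanged common path through the reworked `apply_gradients`, for reference; only pytrees that carry the `OVERWRITE_WITH_GRADIENT` collection (e.g. from fp8 layers) take the new overwrite branch:

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

model = nn.Dense(3)
x = jnp.ones((1, 4))
variables = model.init(jax.random.key(0), x)
state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables['params'], tx=optax.sgd(0.1))

def loss_fn(params):
  return jnp.sum(state.apply_fn({'params': params}, x) ** 2)

grads = jax.grad(loss_fn)(state.params)
# No OVERWRITE_WITH_GRADIENT collection here, so the ordinary optax update runs.
state = state.apply_gradients(grads=grads)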
+ params_with_opt = ( + params['params'] if OVERWRITE_WITH_GRADIENT in params else params ) - - @classmethod - def create(cls, *, apply_fn, params, tx, **kwargs): - assert 'fp8_params' in params - opt_state = tx.init(params['params']) - + opt_state = tx.init(params_with_opt) return cls( step=0, apply_fn=apply_fn, diff --git a/flax/traverse_util.py b/flax/traverse_util.py index e6d409ac1c..4ef0768c44 100644 --- a/flax/traverse_util.py +++ b/flax/traverse_util.py @@ -43,7 +43,7 @@ import abc import copy import dataclasses -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Tuple import warnings import jax diff --git a/pyproject.toml b/pyproject.toml index 6d257c4c00..a1dd3adb6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,8 +147,16 @@ filterwarnings = [ "ignore:.*module 'sre_constants' is deprecated.*:DeprecationWarning", # DeprecationWarning: jax.random.KeyArray is deprecated. "ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning", + # DeprecationWarning: SelfAttention will be deprecated soon. + "ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning", + # DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead. + "ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning", + # DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed + "ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning", # DeprecationWarning: ml_dtypes.float8_e4m3b11 is deprecated. "ignore:.*ml_dtypes.float8_e4m3b11 is deprecated.*:DeprecationWarning", + # DeprecationWarning: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any]. (chex, recheck by Nov 2023) + "ignore:.*jax.core.Shape is deprecated.*:DeprecationWarning", ] [tool.coverage.report] @@ -159,3 +167,14 @@ pyink-indentation = 2 pyink-use-majority-quotes = true line-length = 80 preview = true + +[tool.ruff] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["F401"] +ignore = [] +# Allow fix for all enabled rules (when `--fix`) is provided. +# Full list of rules: https://docs.astral.sh/ruff/rules/ +fixable = ["F401"] +unfixable = [] +# Exclude a variety of commonly ignored directories. +exclude = ["__init__.py", "activation.py", "partitioning.py", "variables.py"] diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index 74cc245c7f..51ce5911af 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -14,8 +14,7 @@ import operator from flax import errors -from flax.core import Scope, init, apply, lift, nn, FrozenDict, unfreeze, copy -from flax.configurations import temp_flip_flag +from flax.core import init, apply, lift, nn, FrozenDict, copy import jax from jax import random diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py index 86634a1c8a..6f8190d7f6 100644 --- a/tests/core/core_scope_test.py +++ b/tests/core/core_scope_test.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest from flax import errors from flax.core import Scope, scope, freeze, lazy_init, init, apply, nn from flax.core.scope import LazyRng from flax.configurations import temp_flip_flag import jax -from jax import config as jax_config from jax import random from jax import numpy as jnp diff --git a/tests/core/design/core_auto_encoder_test.py b/tests/core/design/core_auto_encoder_test.py index 7eff74f4a5..8c462da5b0 100644 --- a/tests/core/design/core_auto_encoder_test.py +++ b/tests/core/design/core_auto_encoder_test.py @@ -20,12 +20,10 @@ import jax from jax import numpy as jnp, random -from flax import struct -from jax.scipy.linalg import expm -from dataclasses import dataclass, InitVar -from typing import Any, Callable, Sequence, NamedTuple, Any +from dataclasses import dataclass +from typing import Callable def mlp(scope: Scope, x: Array, hidden: int, out: int): diff --git a/tests/core/design/core_big_resnets_test.py b/tests/core/design/core_big_resnets_test.py index 01d38362bf..d47f766afc 100644 --- a/tests/core/design/core_big_resnets_test.py +++ b/tests/core/design/core_big_resnets_test.py @@ -18,10 +18,10 @@ import numpy as np -from flax.core import Scope, Array, init, apply, unfreeze, lift, nn +from flax.core import Scope, Array, init, unfreeze, lift, nn import jax -from jax import lax, random, numpy as jnp +from jax import random, numpy as jnp default_norm = partial(nn.batch_norm) diff --git a/tests/core/design/core_flow_test.py b/tests/core/design/core_flow_test.py index 872674745b..d72b76a52a 100644 --- a/tests/core/design/core_flow_test.py +++ b/tests/core/design/core_flow_test.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Callable, Sequence, NamedTuple, Any +from typing import Any, Sequence, Any from absl.testing import absltest diff --git a/tests/early_stopping_test.py b/tests/early_stopping_test.py index 92b74fb9ed..ea6ce47ddc 100644 --- a/tests/early_stopping_test.py +++ b/tests/early_stopping_test.py @@ -14,13 +14,10 @@ """Tests for flax.training.early_stopping.""" -import copy -import os from absl.testing import absltest from flax.training import early_stopping import jax -from jax import test_util as jtu # Parse absl flags test_srcdir and test_tmpdir. 
jax.config.parse_flags_with_absl() diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 2556b292c2..ef49d52b80 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -17,10 +17,10 @@ from absl.testing import absltest from absl.testing import parameterized +from flax import errors from flax import linen as nn from flax import jax_utils from flax.core import pop -from flax.configurations import temp_flip_flag import jax from jax import lax @@ -67,7 +67,6 @@ def test_dtype_infer(self): def test_multihead_encoder_decoder_attention(self): rng = random.key(0) q = jnp.ones((4, 2, 3, 5)) - kv = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, @@ -75,7 +74,7 @@ def test_multihead_encoder_decoder_attention(self): bias_init=initializers.zeros, deterministic=False, ) - y, _ = sa_module.init_with_output(rng, q, kv) + y, _ = sa_module.init_with_output(rng, q) self.assertEqual(y.shape, q.shape) def test_multihead_self_attention_w_dropout(self): @@ -91,9 +90,40 @@ def test_multihead_self_attention_w_dropout(self): ) rng1, rng2 = random.split(rng) rngs = {'params': rng1, 'dropout': rng2} - y, _ = sa_module.init_with_output(rngs, x, x) + y, _ = sa_module.init_with_output(rngs, x) self.assertEqual(y.shape, x.shape) + def test_multihead_self_attention_explicit_dropout(self): + class Foo(nn.Module): + attention_kwargs: dict + @nn.compact + def __call__(self, x, dropout_rng=None): + a = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, x, dropout_rng=dropout_rng) + b = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, x, dropout_rng=dropout_rng) + return a, b + + module = Foo( + dict( + num_heads=8, + qkv_features=16, + kernel_init=initializers.ones, + bias_init=initializers.zeros, + dropout_rate=0.5, + deterministic=False, + ) + ) + rng1, rng2, rng3 = random.split(random.key(0), 3) + x = jnp.ones((4, 2, 3, 5)) + rngs = {'params': rng1, 'dropout': rng2} + v = module.init(rngs, x) + a, b = module.apply(v, x, rngs=rngs) + self.assertTrue(not (a == b).all()) + a, b = module.apply(v, x, rngs=rngs, dropout_rng=rng3) + self.assertTrue((a == b).all()) + a, b = module.apply(v, x, dropout_rng=rng3) + self.assertTrue((a == b).all()) + self.assertTrue(a.shape == b.shape == x.shape) + def test_multihead_self_attention_w_dropout_disabled(self): rng = random.key(0) x = jnp.ones((4, 2, 3, 5)) @@ -108,11 +138,11 @@ def test_multihead_self_attention_w_dropout_disabled(self): rng1, rng2, rng3, rng4 = random.split(rng, 4) rngs1 = {'params': rng1, 'dropout': rng2} rngs2 = {'params': rng3, 'dropout': rng4} - y1, vs = sa_module0.init_with_output(rngs1, x, x) - y2, _ = sa_module0.init_with_output(rngs2, x, x) + y1, vs = sa_module0.init_with_output(rngs1, x) + y2, _ = sa_module0.init_with_output(rngs2, x) np.testing.assert_allclose(y1, y2) - y3 = sa_module0.apply(vs, x, x, rngs=rngs1) - y4 = sa_module0.apply(vs, x, x, rngs=rngs2) + y3 = sa_module0.apply(vs, x, rngs=rngs1) + y4 = sa_module0.apply(vs, x, rngs=rngs2) np.testing.assert_allclose(y3, y4) sa_module1 = nn.MultiHeadDotProductAttention( num_heads=8, @@ -121,8 +151,8 @@ def test_multihead_self_attention_w_dropout_disabled(self): bias_init=initializers.zeros, dropout_rate=0.0, ) - y5 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs1) - y6 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs2) + y5 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs1) + y6 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs2) 
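Not part of the patch — a short sketch of the `MultiHeadDotProductAttention` call signature these tests exercise: a single positional input for self-attention, and `inputs_k`/`inputs_v` keyword arguments for cross-attention (the old `inputs_kv` argument is deprecated):

import jax
import jax.numpy as jnp
import flax.linen as nn

q = jnp.ones((2, 4, 8))   # queries
kv = jnp.ones((2, 6, 8))  # keys/values from another sequence

attn = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=16)

# Self-attention: one positional input now suffices.
y_self, variables = attn.init_with_output(jax.random.key(0), q)

# Cross-attention via the new keyword arguments.
y_cross = attn.apply(variables, q, inputs_k=kv, inputs_v=kv)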
np.testing.assert_allclose(y5, y6) sa_module2 = nn.MultiHeadDotProductAttention( num_heads=8, @@ -131,8 +161,8 @@ def test_multihead_self_attention_w_dropout_disabled(self): bias_init=initializers.zeros, dropout_rate=0.5, ) - y7 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs1) - y8 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs2) + y7 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs1) + y8 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs2) np.testing.assert_allclose(y7, y8) def test_causal_mask_1d(self): @@ -204,11 +234,11 @@ def test_autoregresive_receptive_field_1d(self): deterministic=False, ) - initial_vars = module.init(rng1, inputs, inputs) + initial_vars = module.init(rng1, inputs) causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) def model_loss(inputs, pos): - out = module.apply(initial_vars, inputs, inputs, causal_mask) + out = module.apply(initial_vars, inputs, mask=causal_mask) assert out.shape == input_shape assert len(out.shape) == 3 return out[0, pos, :].sum() @@ -234,6 +264,75 @@ def get_receptive_field_1d(pos): 'autoregressive self-attention.' ) + def test_multihead_self_attention_equality(self): + rng = random.key(0) + q = jnp.ones((4, 2, 3, 5)) + module_kwargs = {'num_heads': 8, + 'qkv_features': 16, + 'kernel_init': initializers.ones, + 'bias_init': initializers.zeros, + 'deterministic': False} + sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs) + sa_module1 = nn.SelfAttention(**module_kwargs) + y0, v0 = sa_module0.init_with_output(rng, q) + with self.assertWarnsRegex(DeprecationWarning, 'SelfAttention will be deprecated soon.'): + y1, v1 = sa_module1.init_with_output(rng, q) + self.assertTrue((y0 == y1).all()) + self.assertTrue(jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1))) + + def test_multihead_kv_args(self): + key1, key2 = random.split(random.key(0), 2) + query = random.uniform(key1, (3, 5)) + key_value = random.uniform(key1, (9, 5)) + module = nn.MultiHeadDotProductAttention( + num_heads=8, + qkv_features=16, + kernel_init=initializers.ones, + bias_init=initializers.zeros, + deterministic=False, + ) + y0, v0 = module.init_with_output(key2, query, inputs_k=key_value, inputs_v=key_value) + y1, v1 = module.init_with_output(key2, query, inputs_k=key_value) + with self.assertWarnsRegex(DeprecationWarning, 'The inputs_kv arg will be deprecated soon.'): + y2, v2 = module.init_with_output(key2, query, inputs_kv=key_value) + self.assertTrue((y0 == y1).all() and (y1 == y2).all()) + self.assertTrue( + jax.tree_util.tree_all( + jax.tree_map(lambda x, y, z: (x == y).all() and (y == z).all(), + v0, v1, v2))) + + with self.assertRaisesRegex(ValueError, '`inputs_k` cannot be None if `inputs_v` is not None.'): + y3, v3 = module.init_with_output(key2, query, inputs_v=key_value) + with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'): + y3, v3 = module.init_with_output(key2, query, inputs_kv=key_value, inputs_v=key_value) + with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'): + y3, v3 = module.init_with_output(key2, query, key_value, key_value, inputs_kv=key_value) + + def test_multihead_mask_warning(self): + rng = random.key(0) + rng1, rng2 = random.split(rng, num=2) + + length = 10 + dim = 1 + num_heads = 1 + input_shape = (1, length, dim) + query = key = random.normal(rng2, input_shape) + + module = nn.MultiHeadDotProductAttention( + num_heads=num_heads, + 
kernel_init=jax.nn.initializers.ones, + deterministic=False, + ) + + initial_vars = module.init(rng1, query, key) + causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) + + module.apply(initial_vars, query, key, mask=causal_mask) + with self.assertWarnsRegex(DeprecationWarning, + "the function signature of MultiHeadDotProductAttention's `__call__` method has changed"): + with self.assertRaises(errors.ScopeParamShapeError): + module.apply(initial_vars, query, key, causal_mask) + if __name__ == '__main__': absltest.main() diff --git a/tests/linen/linen_dtypes_test.py b/tests/linen/linen_dtypes_test.py index 7233486c5e..f878960b92 100644 --- a/tests/linen/linen_dtypes_test.py +++ b/tests/linen/linen_dtypes_test.py @@ -14,12 +14,9 @@ """Tests for flax.linen.dtypes.""" -import functools -from multiprocessing.sharedctypes import Value from absl.testing import absltest -from flax import linen as nn from flax.linen import dtypes import jax diff --git a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py index 715e7a7c8e..bbfacd325e 100644 --- a/tests/linen/linen_linear_test.py +++ b/tests/linen/linen_linear_test.py @@ -15,16 +15,11 @@ """Tests for flax.linen.linear.""" import functools -from multiprocessing.sharedctypes import Value -from typing import Callable, Optional from absl.testing import absltest from absl.testing import parameterized -import optax - from flax import linen as nn -from flax.training import train_state import jax from jax import random @@ -1020,142 +1015,6 @@ def __call__(self, x): ) self.assertEqual(y.shape, (2, 8, 6)) - def test_fp8_dot_general_cls_injection(self): - # Used to cast the inputs to be representable in FP8, so that the difference - # of the results from the original gemm and fp8 gemm is small. 
- cast_to_representable = functools.partial(nn.fp8_quantize_dequantize, - scale=jnp.ones((1,)), - compute_dtype=jnp.float32) - - init_key, random_key = jax.random.split( - jax.random.PRNGKey(seed=123), 2) - - x = jax.random.uniform(random_key, (16, 32)) - x = cast_to_representable(x, jnp.float8_e4m3fn) - dy = jax.random.uniform(random_key, (16, 64)) - dy = cast_to_representable(dy, jnp.float8_e5m2) - def run(fp8_injection, expected_shapes): - p = nn.DenseGeneral(features=64, name='dense') - if fp8_injection: - p.dot_general_cls=nn.Fp8DenseGeneralOp - y, initial_vars = p.init_with_output(init_key, x) - var_shapes = jax.tree_util.tree_map(jnp.shape, initial_vars) - self.assertEqual(var_shapes, expected_shapes) - - def _train(variables, x): - y = p.apply(variables, x) - loss = y * dy - return jnp.mean(loss) - train_fn = jax.jit(jax.value_and_grad(_train, argnums=[0, 1])) - outputs, grads = train_fn(initial_vars, x) - return outputs, grads - - expected_shapes_original = { - 'params': {'kernel': (32, 64), 'bias': (64,)}, - } - expected_shapes_new = { - 'params': {'kernel': (32, 64), 'bias': (64,)}, - 'fp8_params': { - 'Fp8DenseGeneralOp_0': {'input_amax_history': (1024,), - 'kernel_amax_history': (1024,), - 'output_grad_amax_history': (1024,), - 'input_scale': (1,), - 'kernel_scale': (1,), - 'output_grad_scale': (1,), }}, - } - - output1a, output1b = run(False, expected_shapes_original) - output2a, output2b = run(True, expected_shapes_new) - dw1, dw2 = output1b[0]['params']['kernel'], output2b[0]['params']['kernel'] - dx1, dx2 = output1b[1], output2b[1] - - np.testing.assert_allclose(output1a, output2a, atol=1e-02) - np.testing.assert_allclose(dw1, dw2, atol=1e-04) - np.testing.assert_allclose(dx1, dx2, atol=1e-04) - - def test_fp8_with_train_state(self): - x = random.uniform(random.PRNGKey(1), (16, 16), dtype=jnp.float32) - dense = nn.DenseGeneral(features=32, use_bias=True, - dot_general_cls=nn.Fp8DenseGeneralOp) - key = random.PRNGKey(0) - variables = dense.init(key, x) - - opt = optax.adam(learning_rate=.1) - state = train_state.Fp8TrainState.create(params=variables, tx=opt, - apply_fn=dense.apply) - - def roll_and_update(amax_h, update): - return jnp.roll(amax_h, shift=-1, axis=0).at[0].set(update) - - def _train_loss(state, x, dy): - def loss_fn(vars): - y = state.apply_fn(vars, x) - loss = y * dy.astype(y.dtype) - return jnp.sum(loss) - - grad_fn = jax.grad(loss_fn) - grads = grad_fn(state.params) - - state = state.apply_gradients(grads=grads) - return state - - train_fn = jax.jit(_train_loss) - - amax_history_x = jnp.zeros((1024, )) - amax_history_k = jnp.zeros((1024, )) - amax_history_dy = jnp.zeros((1024, )) - scale_x = jnp.ones(()) - scale_k = jnp.ones(()) - scale_dy = jnp.ones(()) - fp8_e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32) - fp8_e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32) - for _ in range(5): - x = random.normal(random.PRNGKey(1), (16, 16), dtype=jnp.float32) - dy = random.normal(random.PRNGKey(1), (16, 32), dtype=jnp.float32) - - amax_history_x = roll_and_update(amax_history_x, jnp.max(jnp.abs(x))) - amax_history_k = roll_and_update( - amax_history_k, - jnp.max(jnp.abs(state.params['params']['kernel']))) - amax_history_dy = roll_and_update(amax_history_dy, jnp.max(jnp.abs(dy))) - - amax_from_history_x = jnp.max(amax_history_x, axis=0) - amax_from_history_k = jnp.max(amax_history_k, axis=0) - amax_from_history_dy = jnp.max(amax_history_dy, axis=0) - scale_x = nn.fp8_compute_scale(amax_from_history_x, scale_x, - fp8_e4m3_max) - scale_k = 
nn.fp8_compute_scale(amax_from_history_k, scale_k, fp8_e4m3_max) - scale_dy = nn.fp8_compute_scale(amax_from_history_dy, scale_dy, - fp8_e5m2_max) - - state = train_fn(state, x, dy) - - rtol, atol = 0.001, 0.001 - fp8_vars = state.params['fp8_params'] - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0']['input_amax_history'], - amax_history_x, rtol=rtol, atol=atol) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0']['kernel_amax_history'], - amax_history_k, rtol=rtol, atol=atol) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['output_grad_amax_history'], - amax_history_dy, rtol=rtol, atol=atol) - - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['input_scale'][0], - scale_x) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['kernel_scale'][0], - scale_k) - np.testing.assert_allclose( - fp8_vars['Fp8DenseGeneralOp_0'] - ['output_grad_scale'][0], - scale_dy) - def test_non_final_axes(self): class Foo(nn.Module): diff --git a/tests/linen/linen_meta_test.py b/tests/linen/linen_meta_test.py index eaebda2b64..9a27ea328d 100644 --- a/tests/linen/linen_meta_test.py +++ b/tests/linen/linen_meta_test.py @@ -20,7 +20,6 @@ from jax import numpy as jnp from jax import random from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit from jax.sharding import Mesh from jax.sharding import PartitionSpec diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index e894b7fa4a..adbedb0a90 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -42,8 +42,7 @@ from flax import errors from flax import linen as nn from flax import struct -from flax.configurations import temp_flip_flag -from flax.core import FrozenDict, Scope, freeze, tracers +from flax.core import FrozenDict, Scope, freeze from flax.linen import compact import jax from jax import random @@ -821,7 +820,7 @@ def __call__(self, x): precision = None kernel_init = init bias_init = zeros - dot_general = dot_general + dot_general = None dot_general_cls = None ) Dense_1 = Dense( @@ -833,7 +832,7 @@ def __call__(self, x): precision = None kernel_init = init bias_init = zeros - dot_general = dot_general + dot_general = None dot_general_cls = None ) )""" diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py index 74c900128b..a8ec60c4b0 100644 --- a/tests/linen/linen_recurrent_test.py +++ b/tests/linen/linen_recurrent_test.py @@ -19,10 +19,7 @@ import jax import jax.numpy as jnp import numpy as np -from flax import errors from flax import linen as nn -import pytest -import einops from flax.linen.recurrent import flip_sequences # Parse absl flags test_srcdir and test_tmpdir. diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 4d8ea322ad..135782e629 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -15,19 +15,21 @@ """Tests for flax.linen.""" import copy -from absl.testing import absltest, parameterized +import functools from typing import Any +from absl.testing import absltest +from absl.testing import parameterized from flax import ids from flax import linen as nn +from flax.linen import fp8_ops from flax.training import train_state - import jax from jax import random -from jax.nn import initializers import jax.numpy as jnp -import optax import numpy as np +import optax + # Parse absl flags test_srcdir and test_tmpdir. 
jax.config.parse_flags_with_absl() @@ -79,25 +81,6 @@ def test_avg_pool_no_batch(self, count_include_pad): ]).reshape((3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) - @parameterized.parameters( - {'count_include_pad': True}, {'count_include_pad': False} - ) - def test_avg_pool_padding_same(self, count_include_pad): - x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1)) - pool = lambda x: nn.avg_pool( - x, (2, 2), padding='SAME', count_include_pad=count_include_pad - ) - y = pool(x) - if count_include_pad: - expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape( - (1, 2, 2, 1) - ) - else: - expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape( - (1, 2, 2, 1) - ) - np.testing.assert_allclose(y, expected_y) - def test_max_pool(self): x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) pool = lambda x: nn.max_pool(x, (2, 2)) @@ -149,51 +132,74 @@ def test_pooling_no_batch_dims(self): class NormalizationTest(parameterized.TestCase): - def test_batch_norm(self): + @parameterized.parameters({'test_mask': True}, {'test_mask': False}) + def test_batch_norm(self, test_mask): rng = random.key(0) - key1, key2 = random.split(rng) + key1, key2, key3 = random.split(rng, 3) x = random.normal(key1, (4, 3, 2)) + if test_mask: + m = random.randint( + key2, (4, 3, 1), minval=0, maxval=2, dtype=jnp.int32 + ).astype(jnp.bool_) + x = jnp.where(m, x, jnp.ones_like(x) * jnp.nan) + else: + m = None model_cls = nn.BatchNorm(momentum=0.9, use_running_average=False) - y, initial_params = model_cls.init_with_output(key2, x) + y, initial_params = model_cls.init_with_output(key3, x, mask=m) - mean = y.mean((0, 1)) - var = y.var((0, 1)) + mean = y.mean((0, 1), where=m) + var = y.var((0, 1), where=m) np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) - - y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats']) + _, vars_out = model_cls.apply( + initial_params, x, mutable=['batch_stats'], mask=m + ) ema = vars_out['batch_stats'] np.testing.assert_allclose( - ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4 + ema['mean'], 0.1 * x.mean((0, 1), keepdims=False, where=m), atol=1e-4 ) np.testing.assert_allclose( - ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4 + ema['var'], + 0.9 + 0.1 * x.var((0, 1), keepdims=False, where=m), + rtol=1e-4, ) - def test_batch_norm_complex(self): + @parameterized.parameters({'test_mask': True}, {'test_mask': False}) + def test_batch_norm_complex(self, test_mask): rng = random.key(0) - key1, key2 = random.split(rng) + key1, key2, key3 = random.split(rng, 3) x = random.normal(key1, (4, 3, 2), dtype=jnp.complex64) + if test_mask: + m = random.randint( + key2, (4, 3, 1), minval=0, maxval=2, dtype=jnp.int32 + ).astype(jnp.bool_) + x = jnp.where(m, x, jnp.ones_like(x) * jnp.nan) + else: + m = None model_cls = nn.BatchNorm( momentum=0.9, use_running_average=False, dtype=jnp.complex64 ) - y, initial_params = model_cls.init_with_output(key2, x) + y, initial_params = model_cls.init_with_output(key3, x, mask=m) - mean = y.mean((0, 1)) - var = y.var((0, 1)) + mean = y.mean((0, 1), where=m) + var = y.var((0, 1), where=m) np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) self.assertEqual(mean.dtype, jnp.complex64) - y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats']) + _, vars_out = model_cls.apply( + initial_params, x, 
mutable=['batch_stats'], mask=m + ) ema = vars_out['batch_stats'] np.testing.assert_allclose( - ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4 + ema['mean'], 0.1 * x.mean((0, 1), keepdims=False, where=m), atol=1e-4 ) np.testing.assert_allclose( - ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4 + ema['var'], + 0.9 + 0.1 * x.var((0, 1), keepdims=False, where=m), + rtol=1e-4, ) @parameterized.parameters( @@ -293,14 +299,42 @@ def __call__(self, x): key = random.key(0) model = Foo() x = random.normal(random.key(1), (2, 4)) - (y1, y2), variables = model.init_with_output(key, x) + (y1, y2), _ = model.init_with_output(key, x) np.testing.assert_allclose(y1, y2, rtol=1e-4) + @parameterized.parameters( + { + 'model_index': 0, + 'key_paths': {'Dense_1/kernel/u', 'Dense_1/kernel/sigma'}, + }, + { + 'model_index': 1, + 'key_paths': {'Conv_0/kernel/u', 'Conv_0/kernel/sigma'}, + }, + { + 'model_index': 2, + 'key_paths': { + 'MultiHeadDotProductAttention_0/key/bias/u', + 'MultiHeadDotProductAttention_0/key/kernel/u', + 'MultiHeadDotProductAttention_0/out/kernel/u', + 'MultiHeadDotProductAttention_0/query/bias/u', + 'MultiHeadDotProductAttention_0/query/kernel/u', + 'MultiHeadDotProductAttention_0/value/bias/u', + 'MultiHeadDotProductAttention_0/value/kernel/u', + 'MultiHeadDotProductAttention_0/key/bias/sigma', + 'MultiHeadDotProductAttention_0/key/kernel/sigma', + 'MultiHeadDotProductAttention_0/out/kernel/sigma', + 'MultiHeadDotProductAttention_0/query/bias/sigma', + 'MultiHeadDotProductAttention_0/query/kernel/sigma', + 'MultiHeadDotProductAttention_0/value/bias/sigma', + 'MultiHeadDotProductAttention_0/value/kernel/sigma', + }, + }, + ) def test_spectral_norm_train( - self, + self, model_index, key_paths ): class FooDense(nn.Module): - @nn.compact def __call__(self, x, train): x = nn.Dense(8)(x) @@ -309,7 +343,6 @@ def __call__(self, x, train): return x class FooConv(nn.Module): - @nn.compact def __call__(self, x, train): x = nn.Dense(9)(x) @@ -322,7 +355,6 @@ def __call__(self, x, train): return x class FooAttention(nn.Module): - @nn.compact def __call__(self, x, train): a = nn.Dense(4)(x) @@ -337,123 +369,307 @@ def __call__(self, x, train): x = random.normal(key1, (1, 4)) y = random.normal(key2, (1, 4)) - for model_cls, var_paths in ( - (FooDense, ('Dense_1/kernel/',)), - (FooConv, ('Conv_0/kernel/',)), - ( - FooAttention, - ( - 'MultiHeadDotProductAttention_0/key/bias/', - 'MultiHeadDotProductAttention_0/key/kernel/', - 'MultiHeadDotProductAttention_0/out/kernel/', - 'MultiHeadDotProductAttention_0/query/bias/', - 'MultiHeadDotProductAttention_0/query/kernel/', - 'MultiHeadDotProductAttention_0/value/bias/', - 'MultiHeadDotProductAttention_0/value/kernel/', - ), - ), - ): - variables = model_cls().init(key3, x, train=False) - params, batch_stats = variables['params'], variables['batch_stats'] - for var_path in var_paths: - self.assertTrue(var_path + 'u' in batch_stats['SpectralNorm_0'].keys()) - self.assertTrue( - var_path + 'sigma' in batch_stats['SpectralNorm_0'].keys() + model_cls = (FooDense, FooConv, FooAttention)[model_index] + variables = model_cls().init(key3, x, train=False) + params, batch_stats = variables['params'], variables['batch_stats'] + self.assertEqual(key_paths, batch_stats['SpectralNorm_0'].keys()) + + class TrainState(train_state.TrainState): + batch_stats: Any + + state = TrainState.create( + apply_fn=model_cls().apply, + params=params, + batch_stats=batch_stats, + tx=optax.adam(1e-3), + ) + + @jax.jit + def train_step(state, batch): + def 
loss_fn(params): + logits, updates = state.apply_fn( + {'params': params, 'batch_stats': state.batch_stats}, + x=batch['image'], + train=True, + mutable=['batch_stats'], ) + loss = jnp.mean( + optax.l2_loss(predictions=logits, targets=batch['label']) + ) + return loss, updates - class TrainState(train_state.TrainState): - batch_stats: Any + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, updates), grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + state = state.replace(batch_stats=updates['batch_stats']) + return state, loss - state = TrainState.create( - apply_fn=model_cls().apply, - params=params, - batch_stats=batch_stats, - tx=optax.adam(1e-3), - ) + prev_loss = float('inf') + for _ in range(10): + state, loss = train_step(state, {'image': x, 'label': y}) + self.assertLess(loss, prev_loss) + prev_loss = loss - @jax.jit - def train_step(state, batch): - def loss_fn(params): - logits, updates = state.apply_fn( - {'params': params, 'batch_stats': state.batch_stats}, - x=batch['image'], - train=True, - mutable=['batch_stats'], - ) - loss = jnp.mean( - optax.l2_loss(predictions=logits, targets=batch['label']) - ) - return loss, updates - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss, updates), grads = grad_fn(state.params) - state = state.apply_gradients(grads=grads) - state = state.replace(batch_stats=updates['batch_stats']) - return state, loss - - prev_loss = float('inf') - for _ in range(10): - state, loss = train_step(state, {'image': x, 'label': y}) - self.assertTrue(loss < prev_loss) - prev_loss = loss - - def test_spectral_norm_sigma(self): - for n_steps, update_stats, result in ( - (1, True, 4.0), - (3, True, 4.0), - (10, True, 4.0), - (1, False, 1.0), - ): + @parameterized.parameters( + {'n_steps': 1, 'update_stats': True, 'result': 4.0}, + {'n_steps': 3, 'update_stats': True, 'result': 4.0}, + {'n_steps': 10, 'update_stats': True, 'result': 4.0}, + {'n_steps': 1, 'update_stats': False, 'result': 1.0} + ) + def test_spectral_norm_sigma(self, n_steps, update_stats, result): + class Foo(nn.Module): - class Foo(nn.Module): + @nn.compact + def __call__(self, x, train): + x = nn.SpectralNorm(nn.Dense(8, use_bias=False), n_steps=n_steps)( + x, update_stats=train + ) + return x - @nn.compact - def __call__(self, x, train): - x = nn.SpectralNorm(nn.Dense(8, use_bias=False), n_steps=n_steps)( - x, update_stats=train - ) - return x - - x = jnp.ones((1, 8)) - model_cls = Foo() - variables = model_cls.init(random.PRNGKey(0), x, train=False) - params, batch_stats = variables['params'], variables['batch_stats'] - params = jax.tree_map(lambda x: 4 * jnp.eye(*x.shape), params) - logits, updates = model_cls.apply( - {'params': params, 'batch_stats': batch_stats}, - x=x, - train=update_stats, - mutable=True, + x = jnp.ones((1, 8)) + model_cls = Foo() + variables = model_cls.init(random.PRNGKey(0), x, train=False) + params, batch_stats = variables['params'], variables['batch_stats'] + params = jax.tree_map(lambda x: 4 * jnp.eye(*x.shape), params) + _, updates = model_cls.apply( + {'params': params, 'batch_stats': batch_stats}, + x=x, + train=update_stats, + mutable=True, + ) + np.testing.assert_allclose( + updates['batch_stats']['SpectralNorm_0']['Dense_0/kernel/sigma'], + result, + atol=1e-3, + ) + + @parameterized.parameters( + {'error_on_non_matrix': True}, + {'error_on_non_matrix': False} + ) + def test_spectral_norm_3d_tensor(self, error_on_non_matrix): + class Foo(nn.Module): + + @nn.compact + def __call__(self, x, train): + x = 
nn.SpectralNorm( + nn.DenseGeneral((3, 4), use_bias=False), + error_on_non_matrix=error_on_non_matrix, + )(x, update_stats=train) + return x + + x = jnp.ones((1, 2)) + model_cls = Foo() + + if error_on_non_matrix: + with self.assertRaisesRegex( + ValueError, 'Input is 3D but error_on_non_matrix is True' + ): + _ = model_cls.init(random.PRNGKey(0), x, train=False) + else: + _ = model_cls.init(random.PRNGKey(0), x, train=False) + + @parameterized.parameters( + {'feature_axes': -1, 'reduction_axes': 0, 'variable_filter': {'kernel'}}, + {'feature_axes': 0, 'reduction_axes': 1, 'variable_filter': {'kernel'}}, + { + 'feature_axes': (0, 1), + 'reduction_axes': (), + 'variable_filter': {'kernel'}, + }, + { + 'feature_axes': (), + 'reduction_axes': (0, 1), + 'variable_filter': {'kernel'}, + }, + { + 'feature_axes': None, + 'reduction_axes': (0, 1), + 'variable_filter': {'kernel'}, + }, + {'feature_axes': 0, 'reduction_axes': (), 'variable_filter': {'bias'}}, + {'feature_axes': (), 'reduction_axes': -1, 'variable_filter': {'bias'}}, + ) + def test_manual_weight_norm( + self, feature_axes, reduction_axes, variable_filter + ): + class Foo(nn.Module): + + @nn.compact + def __call__(self, x): + return nn.WeightNorm(nn.Dense(2, bias_init=nn.initializers.normal()), + feature_axes=feature_axes, + variable_filter=variable_filter)(x) + + key1, key2 = jax.random.split(jax.random.key(1)) + x = jax.random.normal(key1, (1, 3)) + module = Foo() + v = module.init(key2, x) + v = jax.tree_map(lambda x: x + 0.5, v) + out = module.apply(v, x) + + kernel = v['params']['Dense_0']['kernel'] + if 'kernel' in variable_filter: + kernel /= jnp.sqrt( + jnp.sum(kernel**2, axis=reduction_axes, keepdims=True) ) - np.testing.assert_allclose( - updates['batch_stats']['SpectralNorm_0']['Dense_0/kernel/sigma'], - result, - atol=1e-3, + kernel_scale = jnp.expand_dims( + v['params']['WeightNorm_0']['Dense_0/kernel/scale'], + axis=reduction_axes, + ) + else: + kernel_scale = 1 + bias = v['params']['Dense_0']['bias'] + if 'bias' in variable_filter: + bias /= jnp.sqrt(jnp.sum(bias**2, axis=reduction_axes, keepdims=True)) + bias_scale = jnp.expand_dims( + v['params']['WeightNorm_0']['Dense_0/bias/scale'], axis=reduction_axes ) + else: + bias_scale = 1 + manual_out = jnp.dot(x, kernel_scale * kernel) + ( + bias_scale * bias + ).reshape(1, -1) + + self.assertTrue(jnp.allclose(out, manual_out)) + + @parameterized.parameters( + {'variable_filters': ({}, None, {'kernel', 'bias'}, {'Bar'}), + 'key_paths': {'Bar_0/Baz_0/Dense_0/kernel/scale', + 'Bar_0/Baz_0/Dense_0/bias/scale', + 'Bar_0/Dense_0/kernel/scale', + 'Bar_0/Dense_0/bias/scale', + 'Bar_0/Baz_1/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/bias/scale', + 'Bar_0/Dense_1/kernel/scale', + 'Bar_0/Dense_1/bias/scale'}}, + {'variable_filters': ({'kernel'},), + 'key_paths': {'Bar_0/Baz_0/Dense_0/kernel/scale', + 'Bar_0/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/kernel/scale', + 'Bar_0/Dense_1/kernel/scale'}}, + {'variable_filters': ({'Baz', 'kernel'},), + 'key_paths': {'Bar_0/Baz_0/Dense_0/kernel/scale', + 'Bar_0/Baz_0/Dense_0/bias/scale', + 'Bar_0/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/kernel/scale', + 'Bar_0/Baz_1/Dense_0/bias/scale', + 'Bar_0/Dense_1/kernel/scale'}} + ) + def test_weight_norm_variable_filter(self, variable_filters, key_paths): + class Baz(nn.Module): + + @nn.compact + def __call__(self, x): + return nn.Dense(2)(x) + + class Bar(nn.Module): + + @nn.compact + def __call__(self, x): + x = Baz()(x) + x = nn.Dense(3)(x) + x = Baz()(x) + x = nn.Dense(3)(x) + return x - 
def test_spectral_norm_3d_tensor(self): - for error_on_non_matrix in (True, False): + for variable_filter in variable_filters: class Foo(nn.Module): @nn.compact - def __call__(self, x, train): - x = nn.SpectralNorm( - nn.DenseGeneral((3, 4), use_bias=False), - error_on_non_matrix=error_on_non_matrix, - )(x, update_stats=train) - return x + def __call__(self, x): + return nn.WeightNorm(Bar(), variable_filter=variable_filter)(x) - x = jnp.ones((1, 2)) - model_cls = Foo() + v = Foo().init(jax.random.key(0), jnp.ones((1, 4))) + self.assertEqual(key_paths, v['params']['WeightNorm_0'].keys()) - if error_on_non_matrix: - with self.assertRaisesRegex( - ValueError, 'Input is 3D but error_on_non_matrix is True' - ): - variables = model_cls.init(random.PRNGKey(0), x, train=False) - else: - variables = model_cls.init(random.PRNGKey(0), x, train=False) + @parameterized.parameters( + {'model_index': 0, 'key_paths': {'Dense_1/kernel/scale'}}, + {'model_index': 1, 'key_paths': {'Conv_0/kernel/scale'}}, + { + 'model_index': 2, + 'key_paths': { + 'MultiHeadDotProductAttention_0/key/kernel/scale', + 'MultiHeadDotProductAttention_0/out/kernel/scale', + 'MultiHeadDotProductAttention_0/query/kernel/scale', + 'MultiHeadDotProductAttention_0/value/kernel/scale', + }, + }, + ) + def test_weight_norm_train(self, model_index, key_paths): + class FooDense(nn.Module): + + @nn.compact + def __call__( + self, + x, + ): + x = nn.Dense(8)(x) + x = nn.WeightNorm(nn.Dense(6))(x) + x = nn.Dense(4)(x) + return x + + class FooConv(nn.Module): + + @nn.compact + def __call__( + self, + x, + ): + x = nn.Dense(9)(x) + x = x.reshape((1, 3, 3)) + x = nn.WeightNorm(nn.Conv(2, kernel_size=(2, 2)))(x) + x = x.reshape(1, -1) + x = nn.Dense(4)(x) + return x + + class FooAttention(nn.Module): + + @nn.compact + def __call__(self, x): + a = nn.Dense(4)(x) + b = nn.Dense(4)(x) + x = nn.WeightNorm(nn.attention.MultiHeadDotProductAttention(4))(a, b) + x = nn.Dense(4)(x) + return x + + key1, key2, key3 = random.split(random.PRNGKey(0), 3) + x = random.normal(key1, (1, 4)) + y = random.normal(key2, (1, 4)) + + model_cls = (FooDense, FooConv, FooAttention)[model_index] + params = model_cls().init(key3, x)['params'] + self.assertEqual(key_paths, params['WeightNorm_0'].keys()) + + state = train_state.TrainState.create( + apply_fn=model_cls().apply, + params=params, + tx=optax.adam(1e-3), + ) + + @jax.jit + def train_step(state, batch): + def loss_fn(params): + logits = state.apply_fn( + {'params': params}, + x=batch['image'], + ) + loss = jnp.mean( + optax.l2_loss(predictions=logits, targets=batch['label']) + ) + return loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + return state, loss + + prev_loss = float('inf') + for _ in range(10): + state, loss = train_step(state, {'image': x, 'label': y}) + self.assertLess(loss, prev_loss) + prev_loss = loss class StochasticTest(absltest.TestCase): @@ -654,5 +870,133 @@ def test_hashable(self): self.assertNotEqual(hash(id1), hash(id1dc)) +class Fp8Test(absltest.TestCase): + + def test_fp8_dot_general_injection(self): + # Used to cast the inputs to be representable in FP8, so that the difference + # of the results from the original gemm and fp8 gemm is small. 
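+    # A hedged sketch of what this round trip is expected to do (assumed from
+    # the call below, not a spec of fp8_ops): with a unit scale,
+    # quantize_dequantize(x, q_dtype, scale=..., compute_dtype=...) should act
+    # roughly like x.astype(q_dtype).astype(jnp.float32), snapping every
+    # element onto an FP8-representable value. That way the plain fp32 matmul
+    # and the fp8-injected matmul see identical operands and their outputs and
+    # gradients can be compared with loose tolerances.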
+ cast_to_representable = functools.partial(fp8_ops.quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=jnp.float32) + + init_key, random_key = random.split(random.PRNGKey(seed=123), 2) + x = cast_to_representable( + random.uniform(random_key, (16, 32)), jnp.float8_e4m3fn) + dy = cast_to_representable( + random.uniform(random_key, (16, 64)), jnp.float8_e5m2) + + def run(fp8_injection, expected_shapes): + p = nn.DenseGeneral(features=64, name='dense') + if fp8_injection: + p.dot_general_cls=nn.Fp8DotGeneralOp + y, initial_vars = p.init_with_output(init_key, x) + var_shapes = jax.tree_util.tree_map(jnp.shape, initial_vars) + self.assertEqual(var_shapes, expected_shapes) + + def _train(variables, x): + y = p.apply(variables, x) + loss = y * dy + return jnp.mean(loss) + + train_fn = jax.jit(jax.value_and_grad(_train, argnums=[0, 1])) + outputs, grads = train_fn(initial_vars, x) + return outputs, grads + + expected_shapes_original = { + 'params': {'kernel': (32, 64), 'bias': (64,)}, + } + expected_shapes_new = { + 'params': {'kernel': (32, 64), 'bias': (64,)}, + fp8_ops.OVERWRITE_WITH_GRADIENT: { + 'Fp8DotGeneralOp_0': {'input_amax_history': (1024,), + 'kernel_amax_history': (1024,), + 'output_grad_amax_history': (1024,), + 'input_scale': (1,), + 'kernel_scale': (1,), + 'output_grad_scale': (1,), }}, + } + + output1a, output1b = run(False, expected_shapes_original) + output2a, output2b = run(True, expected_shapes_new) + dw1, dw2 = output1b[0]['params']['kernel'], output2b[0]['params']['kernel'] + dx1, dx2 = output1b[1], output2b[1] + + np.testing.assert_allclose(output1a, output2a, atol=1e-02) + np.testing.assert_allclose(dw1, dw2, atol=1e-04) + np.testing.assert_allclose(dx1, dx2, atol=1e-04) + + def test_fp8_train_state(self): + key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3) + x = random.uniform(random_key, (16, 16), dtype=jnp.float32) + dense = nn.DenseGeneral(features=32, use_bias=True, + dot_general_cls=nn.Fp8DotGeneralOp) + variables = dense.init(init_key, x) + opt = optax.adam(learning_rate=.1) + state = train_state.TrainState.create( + params=variables, tx=opt, apply_fn=dense.apply + ) + + def _roll_and_update(amax_h, update): + return jnp.roll(amax_h, shift=-1, axis=0).at[0].set(update) + + def _train_loss(state, x, dy): + def loss_fn(vars): + y = state.apply_fn(vars, x) + loss = y * dy.astype(y.dtype) + return jnp.sum(loss) + grad_fn = jax.grad(loss_fn) + grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + return state + + train_fn = jax.jit(_train_loss) + + scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,)) + scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,)) + scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,)) + e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32) + e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32) + + for _ in range(5): + key, random_key = random.split(key, 2) + x = random.normal(random_key, (16, 16), dtype=jnp.float32) + g = random.normal(random_key, (16, 32), dtype=jnp.float32) + k = state.params['params']['kernel'] + + # Manually compute the expected amax history and scaling factors. 
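+      # Sketch of the delayed-scaling bookkeeping replayed by hand here
+      # (assumed from the helpers used in this test): the newest abs-max of
+      # each tensor is rolled into its fixed-length amax history via
+      # _roll_and_update, the running amax is the max over that history, and
+      # fp8_ops.compute_scale maps (running amax, previous scale, fp8 dtype
+      # max) to the next scaling factor. The assertions below compare these
+      # hand-computed values against the variables Fp8DotGeneralOp stores
+      # under the OVERWRITE_WITH_GRADIENT collection.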
+ amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x))) + amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(k))) + amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g))) + amax_from_history_x = jnp.max(amax_history_x, axis=0) + amax_from_history_k = jnp.max(amax_history_k, axis=0) + amax_from_history_g = jnp.max(amax_history_g, axis=0) + scale_x = fp8_ops.compute_scale(amax_from_history_x, scale_x, e4m3_max) + scale_k = fp8_ops.compute_scale(amax_from_history_k, scale_k, e4m3_max) + scale_g = fp8_ops.compute_scale(amax_from_history_g, scale_g, e5m2_max) + + state = train_fn(state, x, g) + + rtol, atol = 0.001, 0.001 + fp8_vars = ( + state.params[fp8_ops.OVERWRITE_WITH_GRADIENT]['Fp8DotGeneralOp_0'] + ) + np.testing.assert_allclose( + fp8_vars['input_amax_history'], amax_history_x, rtol=rtol, atol=atol, + ) + np.testing.assert_allclose( + fp8_vars['kernel_amax_history'], amax_history_k, rtol=rtol, atol=atol, + ) + np.testing.assert_allclose( + fp8_vars['output_grad_amax_history'], + amax_history_g, + rtol=rtol, + atol=atol, + ) + + np.testing.assert_allclose(fp8_vars['input_scale'][0], scale_x) + np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k) + np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g) + + if __name__ == '__main__': absltest.main() diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 6ef18b140b..a337444603 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -15,7 +15,7 @@ """Transforms tests.""" from functools import partial -from typing import Any, Tuple, Iterable, Callable, Sequence +from typing import Any, Callable, Sequence import operator import unittest @@ -24,11 +24,9 @@ from jax import random import jax.numpy as jnp import numpy as np -from flax import config from flax import errors from flax import linen as nn from flax.core import freeze, copy -from flax.configurations import temp_flip_flag # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() diff --git a/tests/linen/toplevel_test.py b/tests/linen/toplevel_test.py deleted file mode 100644 index caf6c29038..0000000000 --- a/tests/linen/toplevel_test.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -import jax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp - -import numpy as np -from typing import Any, Tuple - -from flax import linen as nn -from flax.core import Scope - -# Parse absl flags test_srcdir and test_tmpdir. 
-jax.config.parse_flags_with_absl() - - -class Dummy(nn.Module): - - @nn.compact - def __call__(self): - self.param('foo', lambda rng: 1) - - -class ModuleTopLevelTest(absltest.TestCase): - pass - # def test_toplevel_immutable(self): - # d = Dummy(parent=None) - # with self.assertRaisesRegex(BaseException, "orphaned module"): - # d() - - # def test_toplevel_initialized_requires_rng(self): - # with self.assertRaisesRegex(BaseException, "missing 1 required.*rngs"): - # d = Dummy(parent=None).initialized() - - # def test_toplevel_initialized_with_rng(self): - # d = Dummy(parent=None).initialized(rngs={'params': random.key(0)}) - # self.assertEqual(d.variables.param.foo, 1) - - # def test_toplevel_initialized_frozen(self): - # d = Dummy(parent=None).initialized(rngs={'params': random.key(0)}) - # with self.assertRaisesRegex(BaseException, "Can't set value"): - # d.variables.param.foo = 2 - - # def test_toplevel_initialized_has_new_scope(self): - # d = Dummy(parent=None) - # # initializing should make a copy and not have any effect - # # on `d` itself. - # d_initialized = d.initialized(rngs={'params': random.key(0)}) - # # ... make sure that indeed `d` has no scope. - # self.assertIsNone(d.scope) - - # def test_can_only_call_initialized_once(self): - # d = Dummy(parent=None) - # d = d.initialized(rngs={'params': random.key(0)}) - # with self.assertRaises(BaseException): - # d.initialized(rngs={'params': random.key(0)}) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/struct_test.py b/tests/struct_test.py index 122afd5d3d..32926aea99 100644 --- a/tests/struct_test.py +++ b/tests/struct_test.py @@ -15,7 +15,6 @@ """Tests for flax.struct.""" from typing import Any -import unittest from absl.testing import absltest