
Commit

Merge branch 'main' into nnx
cgarciae committed Oct 26, 2023
2 parents c9bc86b + 738078c commit e313797
Showing 72 changed files with 1,303 additions and 857 deletions.
1 change: 0 additions & 1 deletion .github/analytics/get_repo_metrics.py
@@ -15,7 +15,6 @@
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Callable, List
from absl import app, flags

4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -34,3 +34,7 @@ repos:
--extra-keys,
"metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab",
]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
hooks:
- id: ruff
2 changes: 1 addition & 1 deletion .readthedocs.yml
@@ -18,7 +18,7 @@ sphinx:
formats:
- htmlzip
- epub
- pdf
# - pdf

# Optionally set the version of Python and requirements required to build your docs
python:
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -9,7 +9,9 @@ vNext
-
-
-
-
- Refactored `MultiHeadDotProductAttention`'s call method signature by adding
`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `deterministic`
to keyword arguments (a sketch of the new call style is shown after this list).
See more details in [#3389](https://github.com/google/flax/discussions/3389).
-
-
-
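
As a hedged illustration of the new call convention described in the entry above (treat this as a sketch based on the linked discussion, not the definitive API):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=16)
q = jnp.ones((1, 8, 16))   # queries
kv = jnp.ones((1, 8, 16))  # shared key/value source

variables = attn.init(jax.random.PRNGKey(0), q, inputs_k=kv, inputs_v=kv)
# `inputs_k`/`inputs_v` are regular args; `mask` and `deterministic`
# are now passed as keyword arguments.
out = attn.apply(variables, q, inputs_k=kv, inputs_v=kv, deterministic=True)
```
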
1 change: 0 additions & 1 deletion docs/_ext/codediff.py
@@ -26,7 +26,6 @@
In order to highlight a line of code, append "#!" to it.
"""
import itertools
from typing import List, Tuple

from docutils import nodes
2 changes: 1 addition & 1 deletion docs/api_reference/flax.cursor.rst
@@ -12,7 +12,7 @@ To illustrate, consider the example below::
import dataclasses
from typing import Any

@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class A:
x: Any

9 changes: 9 additions & 0 deletions docs/api_reference/flax.linen/layers.rst
@@ -52,10 +52,18 @@ Normalization
:module: flax.linen
:class: GroupNorm

.. flax_module::
:module: flax.linen
:class: RMSNorm

.. flax_module::
:module: flax.linen
:class: SpectralNorm

.. flax_module::
:module: flax.linen
:class: WeightNorm


Combinators
------------------------
@@ -132,6 +140,7 @@ Recurrent
GroupNorm
RMSNorm
SpectralNorm
WeightNorm
Sequential
Dropout
SelfAttention
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -37,7 +37,6 @@
sys.path.append(os.path.abspath('./_ext'))

# patch sphinx
import docs.conf_sphinx_patch
# -- Project information -----------------------------------------------------

project = 'Flax'
@@ -134,6 +133,7 @@
nb_execution_excludepatterns = [
'getting_started.ipynb', # <-- times out
'optax_update_guide.ipynb', # <-- requires flax<=0.5.3
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
]
# raise exceptions on execution so CI can catch errors
nb_execution_allow_errors = False
4 changes: 0 additions & 4 deletions docs/examples_community_examples.rst
@@ -56,10 +56,6 @@ Examples
- `@vasudevgupta7 <https://github.com/vasudevgupta7>`__
- Question-Answering
- https://arxiv.org/abs/2007.14062
* - `Bayesian Networks with BlackJAX <https://blackjax-devs.github.io/blackjax/examples/SGMCMC.html>`__
- `@rlouf <https://github.com/rlouf>`__
- Bayesian Inference, SGMCMC
- https://arxiv.org/abs/1402.4102
* - `DCGAN <https://github.com/bkkaggle/jax-dcgan>`__
- `@bkkaggle <https://github.com/bkkaggle>`__
- Image Synthesis
38 changes: 38 additions & 0 deletions docs/faq.rst
@@ -0,0 +1,38 @@
Frequently Asked Questions (FAQ)
================================

This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in `GitHub Discussions <https://github.com/google/flax/discussions>`__.

Where to search for an answer to a Flax-related question?
*********************************************************

There are a number of official Flax resources to search for information:

- `Flax Documentation on ReadTheDocs <https://flax.readthedocs.io/en/latest/>`__ (this site): Use the `search bar <https://flax.readthedocs.io/en/search.html>`__ or the table of contents on the left-hand side.
- `google/flax GitHub Discussions <https://github.com/google/flax/discussions>`__: Search for an existing topic or start a new one. If you can't find what you're looking for, feel free to ask the Flax team or community a question.
- `google/flax GitHub Issues <https://github.com/google/flax/issues>`__: Use the search bar to look for an existing issue or a feature request, or start a new one.

How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`)?
************************************************************************************************

To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use :meth:`flax.linen.Module.perturb`. In the forward pass, define a zero-valued "perturbation" with :code:`self.perturb(...)` that has the same shape as the intermediate activation; define the loss function so that it accepts the :code:`'perturbations'` collection as an added standalone argument; then take the JAX derivative with :code:`jax.grad` with respect to that perturbation argument. A minimal sketch follows the links below.

For full examples and detailed documentation, go to:

- The :meth:`flax.linen.Module.perturb` API docs
- The `Extracting gradients of intermediate values <https://flax.readthedocs.io/en/latest/guides/extracting_intermediates.html#extracting-gradients-of-intermediate-values>`_ guide
- `Flax GitHub Discussions #1152 <https://github.com/google/flax/discussions/1152>`__
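
A minimal sketch, using a small hypothetical two-layer model (the guides linked above walk through a complete training example)::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(8)(x)
            # registers a zero-valued perturbation with the activation's shape
            x = self.perturb('hidden', x)
            return nn.Dense(1)(x)

    x = jnp.ones((1, 4))
    model = Model()
    variables = model.init(jax.random.PRNGKey(0), x)
    params, perturbations = variables['params'], variables['perturbations']

    def loss_fn(params, perturbations, x):
        y = model.apply({'params': params, 'perturbations': perturbations}, x)
        return jnp.mean(y ** 2)

    # gradient with respect to the intermediate activation, via the perturbation
    intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x)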

Is Flax Linen :code:`remat_scan()` the same as :code:`scan(remat(...))`?
************************************************************************

Flax :code:`remat_scan()` (:meth:`flax.linen.remat_scan()`) and :code:`scan(remat(...))` (:meth:`flax.linen.scan` over :meth:`flax.linen.remat`) are not the same, and :code:`remat_scan()` is limited in the cases it supports. Namely, :code:`remat_scan()` treats the inputs and outputs as carries (hidden states that are carried through the training loop). Using :code:`scan(remat(...))` is recommended, as you typically need the extra parameters, such as ``in_axes`` (for input array axes) or ``out_axes`` (for output array axes), which :meth:`flax.linen.remat_scan` does not expose.
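
A minimal sketch of the :code:`scan(remat(...))` pattern (the layer, shapes, and the :code:`variable_axes`/:code:`split_rngs` settings below are illustrative assumptions)::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Block(nn.Module):
        features: int

        @nn.compact
        def __call__(self, carry, x):
            carry = nn.tanh(nn.Dense(self.features)(carry) + x)
            return carry, carry  # (new carry, per-step output)

    # scan over a rematerialized block; in_axes/out_axes stay configurable
    ScannedBlock = nn.scan(
        nn.remat(Block),
        variable_axes={'params': 0},  # one set of params per scan step
        split_rngs={'params': True},
        in_axes=0,
        out_axes=0,
    )

    model = ScannedBlock(features=16)
    carry = jnp.zeros((1, 16))
    xs = jnp.zeros((4, 1, 16))  # 4 scan steps
    variables = model.init(jax.random.PRNGKey(0), carry, xs)
    carry, ys = model.apply(variables, carry, xs)  # ys has shape (4, 1, 16)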

What are the recommended training loop libraries?
*************************************************

Consider using CLU (Common Loop Utils) `google/CommonLoopUtils <https://github.com/google/CommonLoopUtils>`__. To get started, go to this `CLU Synopsis Colab <https://colab.research.google.com/github/google/CommonLoopUtils/blob/main/clu_synopsis.ipynb>`__. You can find answers to common questions about CLU with Flax on `google/flax GitHub Discussions <https://github.com/google/flax/discussions?discussions_q=clu>`__.

Check out the official `google/flax Examples <https://github.com/google/flax/tree/main/examples>`__ for examples of using the training loop with (CLU) metrics. For example, see `Flax ImageNet's train.py <https://github.com/google/flax/blob/main/examples/imagenet/train.py>`__.

For computer vision research, consider `google-research/scenic <https://github.com/google-research/scenic>`__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the `README page on GitHub <https://github.com/google-research/scenic#getting-started>`__.
36 changes: 35 additions & 1 deletion docs/guides/convert_pytorch_to_flax.rst
@@ -85,7 +85,7 @@ Convolutions and FC Layers
We have to be careful when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc.).
In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then
reshaped to [N, C * H * W] before being fed to the fc layers.
When we port our weights from PyToch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax.
When we port our weights from PyTorch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax.
Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W].
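
For example, a short sketch of the flatten step on the Flax side (shapes are illustrative)::

    import jax.numpy as jnp

    # activations after the convolutions in Flax: [N, H, W, C]
    x = jnp.ones((8, 7, 7, 64))
    # transpose back to PyTorch's [N, C, H, W] layout before flattening,
    # so the flattened feature order matches the ported fc weights
    x = jnp.transpose(x, (0, 3, 1, 2))
    x = x.reshape((x.shape[0], -1))  # [N, C * H * W]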

Consider this PyTorch model:
@@ -266,6 +266,40 @@ while ``torch.nn.ConvTranspose2d`` computes a gradient based transposed convolut
implementation of a gradient based transposed convolution in ``Jax``. However, there is a pending `pull request`_
that contains an implementation.

To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``transpose_kernel`` arg in Flax's
``nn.ConvTranspose`` layer.

.. testcode::

# padding is inverted
torch_padding = 0
flax_padding = 1 - torch_padding

t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=torch_padding)

kernel = t_conv.weight.detach().cpu().numpy()
bias = t_conv.bias.detach().cpu().numpy()

# [inC, outC, kH, kW] -> [kH, kW, outC, inC]
kernel = jnp.transpose(kernel, (2, 3, 1, 0))

key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

variables = {'params': {'kernel': kernel, 'bias': bias}}
# ConvTranspose expects the kernel to be [kH, kW, inC, outC],
# but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead
j_conv = nn.ConvTranspose(features=4, kernel_size=(2, 2), padding=flax_padding, transpose_kernel=True)
j_out = j_conv.apply(variables, x)

# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_conv(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
np.testing.assert_almost_equal(j_out, t_out, decimal=6)


.. _`pull request`: https://github.com/google/jax/pull/5772

.. |nn.ConvTranspose| replace:: ``nn.ConvTranspose``
4 changes: 2 additions & 2 deletions docs/guides/flax_on_pjit.ipynb
@@ -190,7 +190,7 @@
"\n",
"1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters.\n",
"\n",
"2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n",
"2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n",
"\n",
" * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless."
]
@@ -1281,7 +1281,7 @@
"\n",
"* **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming.\n",
"\n",
"* **logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n",
"* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n",
"\n",
"* **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful."
]
4 changes: 2 additions & 2 deletions docs/guides/flax_on_pjit.md
@@ -119,7 +119,7 @@ To shard the parameters efficiently, apply the following APIs to annotate the pa

1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters.

2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.
2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.

* This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless.
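
Putting the two steps above together, a minimal sketch (the axis names and the `PartitionSpec` are illustrative assumptions, and `with_sharding_constraint` assumes it runs under an active device mesh):

```python
import jax
import flax.linen as nn
from jax.sharding import PartitionSpec

class MLP(nn.Module):
    hidden: int

    @nn.compact
    def __call__(self, x):
        # Step 1: annotate the kernel with axis names at init time
        kernel_init = nn.with_partitioning(
            nn.initializers.lecun_normal(), ('embed', 'hidden'))
        y = nn.Dense(self.hidden, kernel_init=kernel_init)(x)
        # Step 2: optionally constrain the intermediate activation's sharding
        y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None))
        return y
```
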

@@ -551,7 +551,7 @@ Choosing when to use a device or logical axis depends on how much you want to co

* **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming.

* **logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.
* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.

* **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful.

10 changes: 6 additions & 4 deletions docs/guides/regular_dict_upgrade_guide.rst
@@ -120,9 +120,11 @@ Alternatively, the environment variable ``flax_return_frozendict``
(found `here <https://github.com/google/flax/blob/main/flax/configurations.py>`__) can be directly modified in the Flax source code.


Migration plan
Migration status
--------------

Currently ``flax_return_frozendict`` is set to True, meaning Flax will default to returning ``FrozenDicts``.
In the future this flag will be flipped to False, and Flax will instead default to returning regular dicts.
Eventually this feature flag will be removed once the migration is complete.
As of July 19th, 2023, ``flax_return_frozendict`` is set to ``False`` (see
`#3193 <https://github.com/google/flax/pull/3193>`__), meaning Flax will default to
returning regular dicts from version `0.7.1 <https://github.com/google/flax/releases/tag/v0.7.1>`__
onward. This flag can be flipped to ``True`` temporarily to have Flax return
``FrozenDicts``. However, this feature flag will eventually be removed. A sketch of flipping the flag is shown below.
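
A minimal sketch of flipping the flag at runtime (this assumes the ``flax.config.update`` helper accepts the ``flax_return_frozendict`` flag name; setting the environment variable before importing Flax is the alternative described above)::

    import jax
    import jax.numpy as jnp
    import flax
    import flax.linen as nn

    # temporarily restore the old behaviour of returning FrozenDicts
    flax.config.update('flax_return_frozendict', True)

    variables = nn.Dense(3).init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
    print(type(variables))  # FrozenDict when the flag is True, dict otherwise
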
14 changes: 5 additions & 9 deletions docs/guides/transfer_learning.ipynb
@@ -165,7 +165,7 @@
"metadata": {},
"source": [
"## Transfering the parameters\n",
"Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location. This can be done by unfreezing `params`, updating the `backbone` parameters, and freezing the `params` again:"
"Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location (i.e. the `backbone`):"
]
},
{
@@ -174,11 +174,7 @@
"metadata": {},
"outputs": [],
"source": [
"import flax\n",
"\n",
"params = flax.core.unfreeze(params)\n",
"params['backbone'] = vision_model_vars['params']\n",
"params = flax.core.freeze(params)"
"params['backbone'] = vision_model_vars['params']"
]
},
{
@@ -247,13 +243,13 @@
"import optax\n",
"\n",
"partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n",
"param_partitions = flax.core.freeze(traverse_util.path_aware_map(\n",
" lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params))\n",
"param_partitions = traverse_util.path_aware_map(\n",
" lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)\n",
"tx = optax.multi_transform(partition_optimizers, param_partitions)\n",
"\n",
"# visualize a subset of the param_partitions structure\n",
"flat = list(traverse_util.flatten_dict(param_partitions).items())\n",
"flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))"
"traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))"
]
},
{
12 changes: 4 additions & 8 deletions docs/guides/transfer_learning.md
@@ -111,14 +111,10 @@ params = variables['params']
```

## Transferring the parameters
Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location. This can be done by unfreezing `params`, updating the `backbone` parameters, and freezing the `params` again:
Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transferred to the `params` structure at the appropriate location (i.e. the `backbone`):

```{code-cell} ipython3
import flax
params = flax.core.unfreeze(params)
params['backbone'] = vision_model_vars['params']
params = flax.core.freeze(params)
```

**Note:** if the model contains other variable collections such as `batch_stats`, these have to be transfered as well.
@@ -153,13 +149,13 @@ from flax import traverse_util
import optax
partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
param_partitions = flax.core.freeze(traverse_util.path_aware_map(
lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params))
param_partitions = traverse_util.path_aware_map(
lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)
tx = optax.multi_transform(partition_optimizers, param_partitions)
# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))
traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))
```

To implement [differential learning rates](https://blog.slavv.com/differential-learning-rates-59eff5209a4f), `optax.set_to_zero` can be replaced with any other optimizer; different optimizers and partitioning schemes can be selected depending on the task, for example fine-tuning the backbone with a small learning rate instead of freezing it (see the sketch below). For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation.
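
A short sketch of that differential-learning-rate variant (the `1e-5` value is just an illustrative choice; `param_partitions` comes from the cell above):

```{code-cell} ipython3
import optax

# train the backbone slowly instead of freezing it entirely
partition_optimizers = {
    'trainable': optax.adam(5e-3),
    'frozen': optax.adam(1e-5),  # small learning rate for the 'frozen' partition
}
tx = optax.multi_transform(partition_optimizers, param_partitions)
```
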
3 changes: 2 additions & 1 deletion docs/index.rst
@@ -318,8 +318,9 @@ Notable examples in Flax include:
guides/index
examples
glossary
faq
developer_notes/index
philosophy
contributing
experimental
api_reference/index
api_reference/index
1 change: 0 additions & 1 deletion docs/requirements.txt
@@ -32,4 +32,3 @@ tensorflow_text>=2.11.0 # WMT example

# notebooks
einops
transformers[flax]
1 change: 0 additions & 1 deletion examples/imagenet/imagenet_fake_data_benchmark.py
@@ -22,7 +22,6 @@
import time

from absl.testing import absltest
from absl.testing.flagsaver import flagsaver
from flax.testing import Benchmark
import jax
import tensorflow_datasets as tfds
1 change: 0 additions & 1 deletion examples/imagenet/train.py
@@ -25,7 +25,6 @@
from absl import logging
from clu import metric_writers
from clu import periodic_actions
import flax
from flax import jax_utils
from flax.training import checkpoints
from flax.training import common_utils
6 changes: 2 additions & 4 deletions examples/linen_design_test/attention_simple.py
@@ -14,15 +14,13 @@

import functools
from pprint import pprint
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union
from flax.core import Scope
from flax.core.frozen_dict import freeze, unfreeze
from typing import Any, Callable, Optional, Sequence
from flax.core.frozen_dict import unfreeze
from flax.linen import initializers
from flax.linen import Module, compact, vmap
from flax.linen.linear import PrecisionLike
import jax
from jax import lax, numpy as jnp, random
import numpy as np


class Dense(Module):