fixed broken links #3161

Merged
merged 1 commit into from Jun 26, 2023
2 changes: 1 addition & 1 deletion docs/developer_notes/lift.md
@@ -295,7 +295,7 @@ Please open a GitHub issue when you find a use case that is not supported yet by
| xmap | ❌ | |

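For context, here is a brief sketch (not part of this diff) of one of the supported lifted transforms, ``nn.vmap``, mapping a ``Dense`` layer over a batch axis while sharing its parameters:

```python
import flax.linen as nn

# Illustrative only: lift Dense over a leading batch axis while sharing its params.
BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': None},   # broadcast (share) parameters across the mapped axis
    split_rngs={'params': False},     # use one rng for all mapped instances
)
```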
References:
- [Linen transforms documentation](https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.linen.transforms).
- [Linen transforms documentation](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html)
- [Linen transforms source code](https://github.com/google/flax/blob/main/flax/linen/transforms.py)
- [Core lifting source code](https://github.com/google/flax/blob/main/flax/core/lift.py)

10 changes: 5 additions & 5 deletions docs/developer_notes/module_lifecycle.rst
@@ -12,7 +12,7 @@ The Flax Module lifecycle

This design note is intended for users who are already familiar with Flax Linen Modules but want to understand more about the design principles behind the abstraction. This note should give you a good understanding of the assumptions and guarantees the Module API is built upon. If you have no practical experience with Modules yet, check out the `Getting started notebook <https://flax.readthedocs.io/en/latest/getting_started.html>`_.

Flax Linen Modules offer a Pythonic abstraction on top of Flax core. The `Module <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`_ abstraction allows you to create classes that have state, parameters and randomness on top of JAX. This is a practical guide to the design and behavior of the ``Module`` class. By the end, you should feel comfortable to go off the beaten track and use Modules in new ways.
Flax Linen Modules offer a Pythonic abstraction on top of Flax core. The `Module <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html>`_ abstraction allows you to create classes that have state, parameters and randomness on top of JAX. This is a practical guide to the design and behavior of the ``Module`` class. By the end, you should feel comfortable to go off the beaten track and use Modules in new ways.

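As an editorial illustration (not part of this diff), a minimal Linen Module combining parameters and mutable state on top of plain JAX might look like this:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # a learnable parameter and a mutable state variable
        w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.features))
        count = self.variable('state', 'count', lambda: jnp.zeros((), jnp.int32))
        count.value += 1
        return x @ w

model = Counter(features=4)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))             # randomness enters via the rng key
y, new_state = model.apply(variables, jnp.ones((1, 3)), mutable=['state'])  # updated state is returned, never mutated in place
```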

Overview
@@ -102,7 +102,7 @@ Notice that the lifecycle includes cloning the Module instance. This is done to
Variables
==========

The word “variable” is ubiquitous in programming and math. However, it's important to have a good understanding of what variables are in the context of JAX and Flax. Inside Flax Modules, `variables <https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.core.variables>`_ act like you expect from Python. They are initialized once, read, and perhaps even updated every so often. However, JAX has no concept of variables. Instead, values are stored in arrays similar to NumPy arrays - with one important difference: they are immutable.
The word “variable” is ubiquitous in programming and math. However, it's important to have a good understanding of what variables are in the context of JAX and Flax. Inside Flax Modules, `variables <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html>`_ act like you expect from Python. They are initialized once, read, and perhaps even updated every so often. However, JAX has no concept of variables. Instead, values are stored in arrays similar to NumPy arrays - with one important difference: they are immutable.

The ``init`` and ``apply`` methods return the variables as a nested dictionary with string keys and JAX arrays at the leaves. At the top level each key corresponds to a variable collection. Inside each collection the nested dict structure corresponds with the ``Module`` hierarchy. The variable dict is immutable and therefore really just a snapshot of the state the variables are in. When ``apply`` is called again, the variable dict is passed as an argument, so that the variables are in the same state as when the previous ``init`` / ``apply`` call finished.

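A small sketch (again, not part of the diff) of what that nested variable dictionary looks like:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
jax.tree_util.tree_map(jnp.shape, variables)
# {'params': {'bias': (3,), 'kernel': (2, 3)}}
y = model.apply(variables, jnp.ones((1, 2)))   # re-applies the same immutable snapshot
```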
@@ -242,7 +242,7 @@ To make this approach work reliably we need well-defined cloning behavior. Rathe
Bind
===============================================

Sometimes it's useful to have a bound, top-level Module without having to wrap the code in a function. For example: to interact with a Module inside a Jupyter notebook. The `bind <https://flax.readthedocs.io/en/latest/flax.linen.html?highlight=bind#flax.linen.Module.bind>`_ method returns a bound clone with an unlimited lifetime. The downside of this is that you cannot combine it with JAX transformations or integrate it into a vanilla JAX codebase that expects stateless code. For example, `Optax <https://github.com/deepmind/optax>`_ can optimze a Pytree of parameters but it cannot directly optimize a bound ``Module`` instance created with ``.bind`` (because that's not a Pytree). Thus, you cannot combine the ``bind`` API with a functional optimizer API like Optax.
Sometimes it's useful to have a bound, top-level Module without having to wrap the code in a function. For example: to interact with a Module inside a Jupyter notebook. The `bind <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.bind>`_ method returns a bound clone with an unlimited lifetime. The downside of this is that you cannot combine it with JAX transformations or integrate it into a vanilla JAX codebase that expects stateless code. For example, `Optax <https://github.com/deepmind/optax>`_ can optimize a Pytree of parameters but it cannot directly optimize a bound ``Module`` instance created with ``.bind`` (because that's not a Pytree). Thus, you cannot combine the ``bind`` API with a functional optimizer API like Optax.

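To illustrate the trade-off described above, a small ``bind`` sketch (not taken from the diff; names are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))

bound = model.bind(variables)      # bound clone with an unlimited lifetime
y = bound(jnp.ones((1, 2)))        # call it directly, e.g. while exploring in a notebook
# `bound` is not a Pytree, so it cannot be handed to jax transformations or Optax;
# an optimizer still needs the plain `variables['params']` Pytree.
```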

Setup
@@ -262,7 +262,7 @@ The ``setup`` method is often used like the constructor hook (``__init__``) in n
mdl = TopLevelAccess()
assert not hasattr(mdl, "foo") # foo is not defined because setup is not called

The ``setup`` method is not called immediately after the ``Module`` becomes bound but only when you interact with the ``Module`` instance (e.g.: call a method or access an attribute). This should not impact the behavior of a ``Module`` but the lazy execution does sometimes affect log statements and stack traces during debugging. The section on functionalization will explain why we need ``setup`` to be lazy in the first place.
The ``setup`` method is not called immediately after the ``Module`` becomes bound but only when you interact with the ``Module`` instance (e.g.: call a method or access an attribute). This should not impact the behavior of a ``Module`` but the lazy execution does sometimes affect log statements and stack traces during debugging. The section on :ref:`Functionalization` will explain why we need ``setup`` to be lazy in the first place.


Functionalization
@@ -314,7 +314,7 @@ For the most part functionalization is something that is handled automatically f

Here ``inner`` takes a function that closes over a Module instance. In this example, that works fine because we are not transforming the inner method with a lifted transformation. Most methods are not transformed but it is good to know how to make Module methods transformable.

The main obstacle for transformability are types that JAX does not recognize. JAX only understands `Pytree <https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html>`_ arguments. That's arbitrarily nested Python containers (dict, list, tuple) of (Jax) numpy ndarrays and Python numbers/bools. Flax allows to define dataclasses which are Pytree compatible using the `flax.struct <https://flax.readthedocs.io/en/latest/flax.struct.html>`_ API.
The main obstacle for transformability are types that JAX does not recognize. JAX only understands `Pytree <https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html>`_ arguments; i.e. arbitrarily nested Python containers (dict, list, tuple) of (Jax) numpy ndarrays and Python numbers/bools. Flax allows to define dataclasses which are Pytree compatible using the `flax.struct <https://flax.readthedocs.io/en/latest/flax.struct.html>`_ API.

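A brief sketch (not part of this diff; field names are illustrative) of a Pytree-compatible dataclass defined with ``flax.struct``:

```python
import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class EncoderOutput:
    hidden: jax.Array
    mask: jax.Array

out = EncoderOutput(hidden=jnp.zeros((2, 4)), mask=jnp.ones((2,)))
# Instances are Pytrees, so they can cross jax.jit and lifted-transform boundaries.
jax.tree_util.tree_map(jnp.shape, out)   # EncoderOutput(hidden=(2, 4), mask=(2,))
```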
Function closure is the most common way to accidentally hide a JAX array or Linen Module from a transformation. There is however an easy workaround if you want to pass closures that are also compatible with JAX and Linen transformations:

8 changes: 4 additions & 4 deletions docs/getting_started.ipynb
@@ -105,11 +105,11 @@
"## 3. Define network\n",
"\n",
"Create a convolutional neural network with the Linen API by subclassing\n",
"[Flax Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).\n",
"[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n",
"Because the architecture in this example is relatively simple—you're just\n",
"stacking layers—you can define the inlined submodules directly within the\n",
"`__call__` method and wrap it with the\n",
"[`@compact`](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)\n",
"[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact)\n",
"decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide."
]
},
@@ -156,7 +156,7 @@
"source": [
"### View model layers\n",
"\n",
"Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input."
"Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input."
]
},
{
@@ -360,7 +360,7 @@
"A function that:\n",
"\n",
"- Evaluates the neural network given the parameters and a batch of input images\n",
" with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)\n",
" with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply)\n",
" method (forward pass)).\n",
"- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding.\n",
"- Evaluates the gradient of the loss function using\n",
8 changes: 4 additions & 4 deletions docs/getting_started.md
@@ -77,11 +77,11 @@ def get_datasets(num_epochs, batch_size):
## 3. Define network

Create a convolutional neural network with the Linen API by subclassing
[Flax Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).
[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).
Because the architecture in this example is relatively simple—you're just
stacking layers—you can define the inlined submodules directly within the
`__call__` method and wrap it with the
[`@compact`](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)
[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact)
decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide.

```{code-cell}
@@ -116,7 +116,7 @@ class CNN(nn.Module):

### View model layers

Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input.
Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input.

```{code-cell}
---
@@ -221,7 +221,7 @@ def create_train_state(module, rng, learning_rate, momentum):
A function that:

- Evaluates the neural network given the parameters and a batch of input images
with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)
with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply)
method (forward pass)).
- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding.
- Evaluates the gradient of the loss function using
4 changes: 2 additions & 2 deletions docs/glossary.rst
@@ -88,7 +88,7 @@ For additional terms, refer to the `Jax glossary <https://jax.readthedocs.io/en/
Refer to :class:`flax.training.train_state.TrainState`.

Variable
The `weights / parameters / data / arrays <https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.core.variables.Variable>`__
The `weights / parameters / data / arrays <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html#flax.linen.Variable>`__
residing in the leaves of :term:`variable collections<Variable collections>`.
Variables are defined inside modules using :meth:`Module.variable() <flax.linen.Module.variable>`.
A variable of collection "params" is simply called a param and can be set using
@@ -100,7 +100,7 @@ For additional terms, refer to the `Jax glossary <https://jax.readthedocs.io/en/
They are typically differentiable, updated by an outer SGD-like loop / optimizer,
rather than modified directly by forward-pass code.

`Variable dictionary <https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module-flax.core.variables>`__
`Variable dictionary <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html>`__
A dictionary containing :term:`variable collections<Variable collections>`.
Each variable collection is a mapping from a string name
(e.g., ":term:`params<Params / parameters>`" or "batch_stats") to a (possibly nested)
2 changes: 1 addition & 1 deletion docs/guides/batch_norm.rst
@@ -69,7 +69,7 @@ The ``batch_stats`` collection
In addition to the ``params`` collection, ``BatchNorm`` also adds a ``batch_stats`` collection
that contains the running average of the batch statistics.

Note: You can learn more in the ``flax.linen`` `variables <https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module-flax.core.variables>`__
Note: You can learn more in the ``flax.linen`` `variables <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html>`__
API documentation.

The ``batch_stats`` collection must be extracted from the ``variables`` for later use.
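
As an editorial illustration (the guide's own code is collapsed in this diff; the model here is a hypothetical example), extracting ``batch_stats`` after ``init`` might look like:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(4)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        return x

variables = Net().init(jax.random.PRNGKey(0), jnp.ones((1, 3)), train=False)
params = variables['params']
batch_stats = variables['batch_stats']   # running averages of the batch statistics
```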
4 changes: 2 additions & 2 deletions docs/guides/dropout.rst
@@ -60,7 +60,7 @@ To create a model with dropout:
* Subclass :meth:`flax.linen.Module`, and then use
:meth:`flax.linen.Dropout` to add a dropout layer. Recall that
:meth:`flax.linen.Module` is the
`base class for all neural network Modules <https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module>`__,
`base class for all neural network Modules <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html>`__,
and all layers and models are subclassed from it.

* In :meth:`flax.linen.Dropout`, the ``deterministic`` argument is required to
@@ -130,7 +130,7 @@ After creating your model:
* Instantiate the model.
* Then, in the :meth:`flax.linen.init()` call, set ``training=False``.
* Finally, extract the ``params`` from the
`variable dictionary <https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module-flax.core.variables>`__.
`variable dictionary <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html>`__.

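A hedged sketch of those steps (not from this guide's diff; the model is a hypothetical example), initializing with ``training=False`` and extracting ``params``:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(8)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return nn.Dense(1)(x)

model = MLP()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)), training=False)
params = variables['params']   # the Pytree an optimizer would update
```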
Here, the main difference between the code without Flax ``Dropout``
and with ``Dropout`` is that the ``training`` (or ``train``) argument must be
2 changes: 1 addition & 1 deletion docs/guides/ensembling.rst
@@ -273,6 +273,6 @@ directly.
.. _jax.pmap(): https://jax.readthedocs.io/en/latest/jax.html#jax.pmap
.. |jax.lax.pmean()| replace:: ``jax.lax.pmean()``
.. _jax.lax.pmean(): https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html
.. _Module.init: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init
.. _Module.init: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init
.. _`JIT mechanics: tracing and static variables`: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#JIT-mechanics:-tracing-and-static-variables
.. _`MNIST example`: https://github.com/google/flax/blob/main/examples/mnist/train.py