diff --git a/docs/developer_notes/lift.md b/docs/developer_notes/lift.md index f949e9eb3c..cc32251f0a 100644 --- a/docs/developer_notes/lift.md +++ b/docs/developer_notes/lift.md @@ -295,7 +295,7 @@ Please open a GitHub issue when you find a use case that is not supported yet by | xmap | ❌ | | References: -- [Linen transforms documentation](https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.linen.transforms). +- [Linen transforms documentation](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) - [Linen transforms source code](https://github.com/google/flax/blob/main/flax/linen/transforms.py) - [Core lifting source code](https://github.com/google/flax/blob/main/flax/core/lift.py) diff --git a/docs/developer_notes/module_lifecycle.rst b/docs/developer_notes/module_lifecycle.rst index 799c58fa46..8a28f8a38e 100644 --- a/docs/developer_notes/module_lifecycle.rst +++ b/docs/developer_notes/module_lifecycle.rst @@ -12,7 +12,7 @@ The Flax Module lifecycle This design note is intended for users who are already familiar with Flax Linen Modules but want to understand more about the design principles behind the abstraction. This note should give you a good understanding of the assumptions and guarantees the Module API is built upon. If you have no practical experience with Modules yet, check out the `Getting started notebook `_. -Flax Linen Modules offer a Pythonic abstraction on top of Flax core. The `Module `_ abstraction allows you to create classes that have state, parameters and randomness on top of JAX. This is a practical guide to the design and behavior of the ``Module`` class. By the end, you should feel comfortable to go off the beaten track and use Modules in new ways. +Flax Linen Modules offer a Pythonic abstraction on top of Flax core. The `Module `_ abstraction allows you to create classes that have state, parameters and randomness on top of JAX. This is a practical guide to the design and behavior of the ``Module`` class. By the end, you should feel comfortable to go off the beaten track and use Modules in new ways. Overview @@ -102,7 +102,7 @@ Notice that the lifecycle includes cloning the Module instance. This is done to Variables ========== -The word “variable” is ubiquitous in programming and math. However, it's important to have a good understanding of what variables are in the context of JAX and Flax. Inside Flax Modules, `variables `_ act like you expect from Python. They are initialized once, read, and perhaps even updated every so often. However, JAX has no concept of variables. Instead, values are stored in arrays similar to NumPy arrays - with one important difference: they are immutable. +The word “variable” is ubiquitous in programming and math. However, it's important to have a good understanding of what variables are in the context of JAX and Flax. Inside Flax Modules, `variables `_ act like you expect from Python. They are initialized once, read, and perhaps even updated every so often. However, JAX has no concept of variables. Instead, values are stored in arrays similar to NumPy arrays - with one important difference: they are immutable. The ``init`` and ``apply`` methods return the variables as a nested dictionary with string keys and JAX arrays at the leaves. At the top level each key corresponds to a variable collection. Inside each collection the nested dict structure corresponds with the ``Module`` hierarchy. The variable dict is immutable and therefore really just a snapshot of state the variables are in. When ``apply`` is called again, the variable dict is passed as an argument. Such that the variables are in the same state as when the previous ``init`` / ``apply`` call finished. @@ -242,7 +242,7 @@ To make this approach work reliably we need well-defined cloning behavior. Rathe Bind =============================================== -Sometimes it's useful to have a bound, top-level Module without having to wrap the code in a function. For example: to interact with a Module inside a Jupyter notebook. The `bind `_ method returns a bound clone with an unlimited lifetime. The downside of this is that you cannot combine it with JAX transformations or integrate it into a vanilla JAX codebase that expects stateless code. For example, `Optax `_ can optimze a Pytree of parameters but it cannot directly optimize a bound ``Module`` instance created with ``.bind`` (because that's not a Pytree). Thus, you cannot combine the ``bind`` API with a functional optimizer API like Optax. +Sometimes it's useful to have a bound, top-level Module without having to wrap the code in a function. For example: to interact with a Module inside a Jupyter notebook. The `bind `_ method returns a bound clone with an unlimited lifetime. The downside of this is that you cannot combine it with JAX transformations or integrate it into a vanilla JAX codebase that expects stateless code. For example, `Optax `_ can optimize a Pytree of parameters but it cannot directly optimize a bound ``Module`` instance created with ``.bind`` (because that's not a Pytree). Thus, you cannot combine the ``bind`` API with a functional optimizer API like Optax. Setup @@ -262,7 +262,7 @@ The ``setup`` method is often used like the constructor hook (``__init__``) in n mdl = TopLevelAccess() assert not hasattr(mdl, "foo") # foo is not defined because setup is not called -The ``setup`` method is not called immediately after the ``Module`` becomes bound but only when you interact with the ``Module`` instance (e.g.: call a method or access an attribute). This should not impact the behavior of a ``Module`` but the lazy execution does sometimes affect log statements and stack traces during debugging. The section on functionalization will explain why we need ``setup`` to be lazy in the first place. +The ``setup`` method is not called immediately after the ``Module`` becomes bound but only when you interact with the ``Module`` instance (e.g.: call a method or access an attribute). This should not impact the behavior of a ``Module`` but the lazy execution does sometimes affect log statements and stack traces during debugging. The section on :ref:`Functionalization` will explain why we need ``setup`` to be lazy in the first place. Functionalization @@ -314,7 +314,7 @@ For the most part functionalization is something that is handled automatically f Here ``inner`` takes a function that closes over a Module instance. In this example, that works fine because we are not transforming the inner method with a lifted transformation. Most methods are not transformed but it is good to know how to make Module methods transformable. -The main obstacle for transformability are types that JAX does not recognize. JAX only understands `Pytree `_ arguments. That's arbitrarily nested Python containers (dict, list, tuple) of (Jax) numpy ndarrays and Python numbers/bools. Flax allows to define dataclasses which are Pytree compatible using the `flax.struct `_ API. +The main obstacle for transformability are types that JAX does not recognize. JAX only understands `Pytree `_ arguments; i.e. arbitrarily nested Python containers (dict, list, tuple) of (Jax) numpy ndarrays and Python numbers/bools. Flax allows to define dataclasses which are Pytree compatible using the `flax.struct `_ API. Function closure is the most common way to accidentally hide a JAX array or Linen Module from a transformation. There is however an easy workaround if you want to pass closures that are also compatible with JAX and Linen transformations: diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index aeb08ca641..0c4e91c2af 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -105,11 +105,11 @@ "## 3. Define network\n", "\n", "Create a convolutional neural network with the Linen API by subclassing\n", - "[Flax Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).\n", + "[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n", "Because the architecture in this example is relatively simple—you're just\n", "stacking layers—you can define the inlined submodules directly within the\n", "`__call__` method and wrap it with the\n", - "[`@compact`](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)\n", + "[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact)\n", "decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide." ] }, @@ -156,7 +156,7 @@ "source": [ "### View model layers\n", "\n", - "Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input." + "Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input." ] }, { @@ -360,7 +360,7 @@ "A function that:\n", "\n", "- Evaluates the neural network given the parameters and a batch of input images\n", - " with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)\n", + " with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply)\n", " method (forward pass)).\n", "- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding.\n", "- Evaluates the gradient of the loss function using\n", diff --git a/docs/getting_started.md b/docs/getting_started.md index 32dab82e5b..fec311bced 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -77,11 +77,11 @@ def get_datasets(num_epochs, batch_size): ## 3. Define network Create a convolutional neural network with the Linen API by subclassing -[Flax Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction). +[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html). Because the architecture in this example is relatively simple—you're just stacking layers—you can define the inlined submodules directly within the `__call__` method and wrap it with the -[`@compact`](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods) +[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact) decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide. ```{code-cell} @@ -116,7 +116,7 @@ class CNN(nn.Module): ### View model layers -Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. +Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. ```{code-cell} --- @@ -221,7 +221,7 @@ def create_train_state(module, rng, learning_rate, momentum): A function that: - Evaluates the neural network given the parameters and a batch of input images - with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply) + with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) method (forward pass)). - Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding. - Evaluates the gradient of the loss function using diff --git a/docs/glossary.rst b/docs/glossary.rst index 7b76422baa..2cfc2052cc 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -88,7 +88,7 @@ For additional terms, refer to the `Jax glossary `__ + The `weights / parameters / data / arrays `__ residing in the leaves of :term:`variable collections`. Variables are defined inside modules using :meth:`Module.variable() `. A variable of collection "params" is simply called a param and can be set using @@ -100,7 +100,7 @@ For additional terms, refer to the `Jax glossary `__ + `Variable dictionary `__ A dictionary containing :term:`variable collections`. Each variable collection is a mapping from a string name (e.g., ":term:`params`" or "batch_stats") to a (possibly nested) diff --git a/docs/guides/batch_norm.rst b/docs/guides/batch_norm.rst index 5b0cf0c865..c1c2c18f9e 100644 --- a/docs/guides/batch_norm.rst +++ b/docs/guides/batch_norm.rst @@ -69,7 +69,7 @@ The ``batch_stats`` collection In addition to the ``params`` collection, ``BatchNorm`` also adds a ``batch_stats`` collection that contains the running average of the batch statistics. -Note: You can learn more in the ``flax.linen`` `variables `__ +Note: You can learn more in the ``flax.linen`` `variables `__ API documentation. The ``batch_stats`` collection must be extracted from the ``variables`` for later use. diff --git a/docs/guides/dropout.rst b/docs/guides/dropout.rst index 8d01769318..efe7bc88ab 100644 --- a/docs/guides/dropout.rst +++ b/docs/guides/dropout.rst @@ -60,7 +60,7 @@ To create a model with dropout: * Subclass :meth:`flax.linen.Module`, and then use :meth:`flax.linen.Dropout` to add a dropout layer. Recall that :meth:`flax.linen.Module` is the - `base class for all neural network Modules `__, + `base class for all neural network Modules `__, and all layers and models are subclassed from it. * In :meth:`flax.linen.Dropout`, the ``deterministic`` argument is required to @@ -130,7 +130,7 @@ After creating your model: * Instantiate the model. * Then, in the :meth:`flax.linen.init()` call, set ``training=False``. * Finally, extract the ``params`` from the - `variable dictionary `__. + `variable dictionary `__. Here, the main difference between the code without Flax ``Dropout`` and with ``Dropout`` is that the ``training`` (or ``train``) argument must be diff --git a/docs/guides/ensembling.rst b/docs/guides/ensembling.rst index 9aa7cbbcde..3f3d04cb93 100644 --- a/docs/guides/ensembling.rst +++ b/docs/guides/ensembling.rst @@ -273,6 +273,6 @@ directly. .. _jax.pmap(): https://jax.readthedocs.io/en/latest/jax.html#jax.pmap .. |jax.lax.pmean()| replace:: ``jax.lax.pmean()`` .. _jax.lax.pmean(): https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html -.. _Module.init: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init +.. _Module.init: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init .. _`JIT mechanics: tracing and static variables`: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#JIT-mechanics:-tracing-and-static-variables .. _`MNIST example`: https://github.com/google/flax/blob/main/examples/mnist/train.py diff --git a/docs/guides/flax_on_pjit.ipynb b/docs/guides/flax_on_pjit.ipynb index ee2717b168..785cf16887 100644 --- a/docs/guides/flax_on_pjit.ipynb +++ b/docs/guides/flax_on_pjit.ipynb @@ -8,7 +8,7 @@ "source": [ "# Scale up Flax Modules on multiple devices\n", "\n", - "This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html)." + "This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)." ] }, { @@ -24,7 +24,7 @@ "\n", "Flax provides several functionalities that can help you use auto-SPMD on [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html), including:\n", "\n", - "1. An interface to specify partitions of your data when defining [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module).\n", + "1. An interface to specify partitions of your data when defining [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n", "2. Utility functions to generate the sharding information that `jax.jit` requires to run.\n", "3. An interface to customize your axis names called \"logical axis annotations\" to decouple both your Module code and partition plan to experiment with different partition layouts more easily.\n", "\n", @@ -262,7 +262,7 @@ "source": [ "## Define a model with `flax.linen.scan` lifted transformation\n", "\n", - "Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module)) as multiple layers of `DotReluDot`.\n", + "Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module)) as multiple layers of `DotReluDot`.\n", "\n", "To replicate identical layers, you can either use [`flax.linen.scan`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.scan.html), or a for-loop:\n", "\n", @@ -403,7 +403,7 @@ "source": [ "### The output's sharding\n", "\n", - "You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree.\n", + "You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree.\n", "\n", "To achieve this, luckily, you don't have to hardcode the output's sharding by hand. Instead, you can:\n", "\n", @@ -1137,7 +1137,7 @@ "id": "58475fffb2de" }, "source": [ - "You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous (\"non-logical\") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.apply) the same way in the above above." + "You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous (\"non-logical\") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way in the above above." ] }, { diff --git a/docs/guides/flax_on_pjit.md b/docs/guides/flax_on_pjit.md index 7f1689440d..acb3e0a950 100644 --- a/docs/guides/flax_on_pjit.md +++ b/docs/guides/flax_on_pjit.md @@ -12,7 +12,7 @@ jupytext: # Scale up Flax Modules on multiple devices -This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html). +This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html). +++ {"id": "b1e0e5fc8bc1"} @@ -22,7 +22,7 @@ This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/l Flax provides several functionalities that can help you use auto-SPMD on [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html), including: -1. An interface to specify partitions of your data when defining [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module). +1. An interface to specify partitions of your data when defining [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html). 2. Utility functions to generate the sharding information that `jax.jit` requires to run. 3. An interface to customize your axis names called "logical axis annotations" to decouple both your Module code and partition plan to experiment with different partition layouts more easily. @@ -174,7 +174,7 @@ For example: ## Define a model with `flax.linen.scan` lifted transformation -Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module)) as multiple layers of `DotReluDot`. +Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module)) as multiple layers of `DotReluDot`. To replicate identical layers, you can either use [`flax.linen.scan`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.scan.html), or a for-loop: @@ -248,7 +248,7 @@ jax.debug.visualize_array_sharding(x) ### The output's sharding -You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree. +You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree. To achieve this, luckily, you don't have to hardcode the output's sharding by hand. Instead, you can: @@ -516,7 +516,7 @@ print('sharding annotations are mesh-specific: ', +++ {"id": "58475fffb2de"} -You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous ("non-logical") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.apply) the same way in the above above. +You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous ("non-logical") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way in the above above. ```{code-cell} :id: 589ff774bb4c diff --git a/docs/guides/linen_upgrade_guide.rst b/docs/guides/linen_upgrade_guide.rst index b758a785a5..33185c65a5 100644 --- a/docs/guides/linen_upgrade_guide.rst +++ b/docs/guides/linen_upgrade_guide.rst @@ -482,29 +482,29 @@ TODO: Given an example of ``jax.scan_in_dim`` (pre-Linen) vs. ``nn.scan`` (Linen). .. _`Should I use setup or nn.compact?`: https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html -.. _`Variables documentation`: https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.core.variables +.. _`Variables documentation`: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html .. _`TrainState`: https://flax.readthedocs.io/en/latest/flax.training.html#train-state .. _`Upgrading my codebase to Optax`: https://flax.readthedocs.io/en/latest/guides/optax_update_guide.html .. _`Lifted transformations`: https://flax.readthedocs.io/en/latest/developer_notes/lift.html .. |@compact| replace:: ``@compact`` -.. _@compact: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.compact +.. _@compact: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact .. |init| replace:: ``init`` -.. _init: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init +.. _init: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init .. |init_with_output| replace:: ``init_with_output`` -.. _init_with_output: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init_with_output +.. _init_with_output: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init_with_output .. |jax.jit| replace:: ``jax.jit`` .. _jax.jit: https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit .. |self.param| replace:: ``self.param`` -.. _self.param: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.param +.. _self.param: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.param .. |setup| replace:: ``setup`` -.. _setup: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.setup +.. _setup: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.setup .. |@flax.struct.dataclass| replace:: ``@flax.struct.dataclass`` .. _@flax.struct.dataclass: https://flax.readthedocs.io/en/latest/flax.struct.html#flax.struct.dataclass diff --git a/docs/guides/transfer_learning.ipynb b/docs/guides/transfer_learning.ipynb index 5a933631be..9efd9099c4 100644 --- a/docs/guides/transfer_learning.ipynb +++ b/docs/guides/transfer_learning.ipynb @@ -92,7 +92,7 @@ "\n", "Calling `load_model` from the snippet above returns the `FlaxCLIPModule`, which is composed of `text_model` and `vision_model` submodules.\n", "\n", - "An easy way to extract the `vision_model` sub-Module defined inside `.setup()` and its variables is to use [`flax.linen.Module.bind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.bind) on the `clip` Module immediately followed by [`flax.linen.Module.unbind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.unbind) on the `vision_model` sub-Module." + "An easy way to extract the `vision_model` sub-Module defined inside `.setup()` and its variables is to use [`flax.linen.Module.bind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.bind) on the `clip` Module immediately followed by [`flax.linen.Module.unbind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.unbind) on the `vision_model` sub-Module." ] }, { diff --git a/docs/guides/transfer_learning.md b/docs/guides/transfer_learning.md index de977521b4..ecab0ea35d 100644 --- a/docs/guides/transfer_learning.md +++ b/docs/guides/transfer_learning.md @@ -68,7 +68,7 @@ Note that `FlaxCLIPVisionModel` itself is not a Flax `Module` which is why we ne Calling `load_model` from the snippet above returns the `FlaxCLIPModule`, which is composed of `text_model` and `vision_model` submodules. -An easy way to extract the `vision_model` sub-Module defined inside `.setup()` and its variables is to use [`flax.linen.Module.bind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.bind) on the `clip` Module immediately followed by [`flax.linen.Module.unbind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.unbind) on the `vision_model` sub-Module. +An easy way to extract the `vision_model` sub-Module defined inside `.setup()` and its variables is to use [`flax.linen.Module.bind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.bind) on the `clip` Module immediately followed by [`flax.linen.Module.unbind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.unbind) on the `vision_model` sub-Module. ```{code-cell} ipython3 import flax.linen as nn diff --git a/docs/notebooks/flax_sharp_bits.ipynb b/docs/notebooks/flax_sharp_bits.ipynb index 1570a162f3..76c5a4f45c 100644 --- a/docs/notebooks/flax_sharp_bits.ipynb +++ b/docs/notebooks/flax_sharp_bits.ipynb @@ -38,14 +38,14 @@ "\n", "1. Start with [`jax.random.split()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax-random-split) to explicitly create PRNG keys for `'params'` and `'dropout'`.\n", "2. Add the [`flax.linen.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Dropout.html#flax.linen.Dropout) layer(s) to your model (subclassed from Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)).\n", - "3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply)), there's no need to pass in an extra `'dropout'` PRNG key—just the `'params'` key like in a \"simpler\" model.\n", - "4. During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply), pass in `rngs={'dropout': dropout_key}`.\n", + "3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html)), there's no need to pass in an extra `'dropout'` PRNG key—just the `'params'` key like in a \"simpler\" model.\n", + "4. During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html), pass in `rngs={'dropout': dropout_key}`.\n", "\n", "Check out a full example below.\n", "\n", "### Why this works\n", "\n", - "- Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)).\n", + "- Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)).\n", "- Every time `make_rng` is called (in this case, it's done implicitly in `Dropout`), you get a new PRNG key split from the main/root PRNG key.\n", "- `make_rng` still _guarantees full reproducibility_.\n", "\n", @@ -55,13 +55,13 @@ "\n", "> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.PRNGKey(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n", "\n", - "Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to \"pull PRNG keys\". `make_rng` guarantees to provide a unique key each time you call it.\n", + "Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to \"pull PRNG keys\". `make_rng` guarantees to provide a unique key each time you call it.\n", "\n", - "> Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module) is the base class for all neural network modules. All layers and models are subclassed from it.\n", + "> Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) is the base class for all neural network modules. All layers and models are subclassed from it.\n", "\n", "### Example\n", "\n", - "Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`." + "Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`." ] }, { diff --git a/docs/notebooks/flax_sharp_bits.md b/docs/notebooks/flax_sharp_bits.md index cfe82e84d6..ac5530270f 100644 --- a/docs/notebooks/flax_sharp_bits.md +++ b/docs/notebooks/flax_sharp_bits.md @@ -30,14 +30,14 @@ When working on a model with dropout (subclassed from [Flax `Module`](https://fl 1. Start with [`jax.random.split()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax-random-split) to explicitly create PRNG keys for `'params'` and `'dropout'`. 2. Add the [`flax.linen.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Dropout.html#flax.linen.Dropout) layer(s) to your model (subclassed from Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)). -3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply)), there's no need to pass in an extra `'dropout'` PRNG key—just the `'params'` key like in a "simpler" model. -4. During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply), pass in `rngs={'dropout': dropout_key}`. +3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html)), there's no need to pass in an extra `'dropout'` PRNG key—just the `'params'` key like in a "simpler" model. +4. During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html), pass in `rngs={'dropout': dropout_key}`. Check out a full example below. ### Why this works -- Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)). +- Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)). - Every time `make_rng` is called (in this case, it's done implicitly in `Dropout`), you get a new PRNG key split from the main/root PRNG key. - `make_rng` still _guarantees full reproducibility_. @@ -47,13 +47,13 @@ The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) > Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.PRNGKey(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers). -Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to "pull PRNG keys". `make_rng` guarantees to provide a unique key each time you call it. +Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to "pull PRNG keys". `make_rng` guarantees to provide a unique key each time you call it. -> Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module) is the base class for all neural network modules. All layers and models are subclassed from it. +> Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) is the base class for all neural network modules. All layers and models are subclassed from it. ### Example -Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`. +Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`. ```{code-cell} ipython3 # Setup. diff --git a/docs/philosophy.md b/docs/philosophy.md index e353829e1f..e9a5dcd5d7 100644 --- a/docs/philosophy.md +++ b/docs/philosophy.md @@ -29,7 +29,7 @@ Flax is a neural network library built on [JAX](https://jax.readthedocs.io) that growing set of users, most notably in the JAX submissions for the MLPerf 0.7 benchmark. Our experience over the last year (and many conversations with users and JAX core devs) has guided a redesign of the API called -[Linen](https://github.com/google/flax/blob/main/flax/linen/README.md) ([`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html)) in response to the following basic design questions. +[Linen](https://github.com/google/flax/blob/main/flax/linen/README.md) ([`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)) in response to the following basic design questions. ### How does a neural network library benefit from being built on JAX and leverage JAX’s unique strengths? @@ -70,7 +70,7 @@ transform options to the various variable collections and PRNG state. This unleashes the flexibility and strength of [JAX transformations](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) – for example, one can achieve either device-parallel training or per-device ensembling by using [`jax.pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) in different ways, without any explicit -library support. Moreover, **within [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module)**, we expose lightweight +library support. Moreover, **within [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module)**, we expose lightweight wrappers around the complex JAX transforms such as [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) and [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) that annotate how each variable collection is to be transformed by JAX. Importantly, we handle the nontrivial cases of creating new variables @@ -96,7 +96,7 @@ these various kinds under the single vague rubric of “state”, but keep different logical types of variables separate that can be treated differently under JAX transformations and under mutations (e.g. training vs prediction). Similarly, we allow for multiple separate named PRNG -chains inside [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module) for separate treatment of randomness for different +chains inside [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) for separate treatment of randomness for different applications such as initialization, dropout, sampling, etc. At every stage the data associated with a neural net is not kept in a @@ -113,7 +113,7 @@ dicts, you can use any (non-JAX-aware) serialization library directly. To be broadly useful to the JAX ecosystem, users shouldn’t need to heavily refactor their code in order to add “trainability” for a given numerical task. _“The library should not get in the way.”_ Utilizing -purely functional code from within Linen is trivial: [Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module) +purely functional code from within Linen is trivial: [Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) implementations are just JAX code with named variables. Using Linen Modules inside otherwise purely functional code can be as simple as using a single top-level Module transformation to allow initialization diff --git a/flax/linen/README.md b/flax/linen/README.md index 4f523779f2..3e9b401496 100644 --- a/flax/linen/README.md +++ b/flax/linen/README.md @@ -9,7 +9,7 @@ The Linen Module API is stable and currently recommended for new projects. We ar Please open a [discussion](https://github.com/google/flax/discussions) if you have any questions or thoughts. -**See the [Linen API reference docs](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html)**, or take a look at our additional material: +**See the [Linen API reference docs](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)**, or take a look at our additional material: * 2-page intro to the [Linen Design Principles](https://docs.google.com/document/d/1ZlL_4bXCw5Xl0WstQw1GpnZqfb9JFOeUGAPcBVk-kn8) * [Slides from a talk to the JAX core team](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0)