From f16e6d6c87b883739e05d34afa713d9e8efcaaba Mon Sep 17 00:00:00 2001
From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
Date: Thu, 10 Oct 2024 22:04:33 +0000
Subject: [PATCH] Upgrade Flax NNX Randomness
---
docs_nnx/guides/randomness.ipynb | 141 ++++++++++++++++++++-----------
docs_nnx/guides/randomness.md | 141 ++++++++++++++++++++-----------
2 files changed, 188 insertions(+), 94 deletions(-)
diff --git a/docs_nnx/guides/randomness.ipynb b/docs_nnx/guides/randomness.ipynb
index a0d41462e6..9e7366584f 100644
--- a/docs_nnx/guides/randomness.ipynb
+++ b/docs_nnx/guides/randomness.ipynb
@@ -6,7 +6,19 @@
"source": [
"# Randomness\n",
"\n",
- "Random state in Flax NNX is radically simplified compared to systems like Haiku/Flax Linen in that it defines \"random state as object state\". In essence, this means that random state is just another type of state, it's stored in Variables and held by the models themselves. The main characteristic of the RNG system in Flax NNX is that its **explicit**, its **order-based**, and uses **dynamic counters**. This is a bit different from [Flax Linen's RNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) which is (path + order)-based and uses a static counters."
+ "Random state handling in Flax NNX was radically simplified compared to systems like Haiku and Flax Linen because Flax NNX _defines the random state as an object state_. In essence, this means that in Flax NNX, the random state is: 1) just another type of state; 2) stored in `nnx.Variable`s; and 3) held by the models themselves.\n",
+ "\n",
+ "The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:\n",
+ "\n",
+ "- It is **explicit**.\n",
+ "- It is **order-based**.\n",
+ "- It uses **dynamic counters**.\n",
+ "\n",
+ "This is a bit different from [Flax Linen's PRNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html), which is `(path + order)`-based, and uses static counters.\n",
+ "\n",
+ "> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).\n",
+ "\n",
+ "Let’ start with some necessary imports:"
]
},
{
@@ -24,16 +36,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Rngs, RngStream, and RngState\n",
- "Flax NNX provides the `nnx.Rngs` type a the main convenience API for managing random state. Following Flax Linen's footsteps, `Rngs` has the ability to create multiple named RNG streams, each with its own state, for the purpose of tight control over randomness in the context of JAX transforms. Here's a breakdown of the main RNG-related types in Flax NNX:\n",
+ "## `Rngs`, `RngStream`, and `RngState`\n",
+ "\n",
+ "In Flax NNX, the `nnx.Rngs` type is the primary convenience API for managing the random state(s). Following Flax Linen's footsteps, `nnx.Rngs` have the ability to create multiple named PRNG key [streams](https://jax.readthedocs.io/en/latest/jep/263-prng.html), each with its own state, for the purpose of having tight control over randomness in the context of [JAX transformations (transforms)](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n",
"\n",
- "* **Rngs**: The main user interface. It defines a set of named `RngStream` objects.\n",
- "* **nnx.RngStream**: A object that can generate a stream of RNG keys. It holds a root `key` and a `count` inside a `RngKey` and `RngCount` Variables respectively. When a new key is generated, the count is incremented.\n",
- "* **nnx.RngState**: The base type for all RNG-related state.\n",
- " * **nnx.RngKey**: Variable type for holding RNG keys, it includes a `tag` attribute containing the name of the stream.\n",
- " * **nnx.RngCount**: Variable type for holding RNG counts, it includes a `tag` attribute containing the name of the stream.\n",
+ "Here are the main PRNG-related types in Flax NNX:\n",
"\n",
- "To create an `Rngs` object you can simply pass a integer seed or `jax.random.key` instance to any keyword argument of your choice in the constructor. Here's an example:"
+ "* **`nnx.Rngs`**: The main user interface. It defines a set of named `nnx.RngStream` objects.\n",
+ "* **`nnx.RngStream`**: An object that can generate a stream of PRNG keys. It holds a root `key` and a `count` inside an `nnx.RngKey` and `nnx.RngCount` `nnx.Variable`s, respectively. When a new key is generated, the count is incremented.\n",
+ "* **`nnx.RngState`**: The base type for all RNG-related states.\n",
+ " * **`nnx.RngKey`**: NNX Variable type for holding PRNG keys. It includes a `tag` attribute containing the name of the PRNG key stream.\n",
+ " * **`nnx.RngCount`**: NNX Variable type for holding PRNG counts. It includes a `tag` attribute containing the PRNG key stream name.\n",
+ "\n",
+ "To create an `nnx.Rngs` object you can simply pass an integer seed or `jax.random.key` instance to any keyword argument of your choice in the constructor.\n",
+ "\n",
+ "Here's an example:"
]
},
{
@@ -76,7 +93,7 @@
"id": "b2c2ac86",
"metadata": {},
"source": [
- "Notice that the `key` and `count` Variables contain the stream name in a `tag` attribute. This is primarily used for filtering as we'll see later.\n",
+ "Notice that the `key` and `count` `nnx.Variable`s contain the PRNG key stream name in a `tag` attribute. This is primarily used for filtering as we'll see later.\n",
"\n",
"To generate new keys, you can access one of the streams and use its `__call__` method with no arguments. This will return a new key by using `random.fold_in` with the current `key` and `count`. The `count` is then incremented so that subsequent calls will return new keys."
]
@@ -122,17 +139,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note that the `key` attribute does not change when a new keys are generated.\n",
+ "Note that the `key` attribute does not change when new PRNG keys are generated.\n",
+ "\n",
+ "### Standard PRNG key stream names\n",
"\n",
- "### Standard stream names\n",
- "There are only two standard stream names used by Flax NNX's built-in layers, shown in the table below:\n",
+ "There are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:\n",
"\n",
- "| Name | Description |\n",
- "|----------|-------------------------------------------|\n",
- "| `params` | Used for parameter initialization |\n",
- "| `dropout`| Used by `Dropout` to create dropout masks |\n",
+ "| PRNG key stream name | Description |\n",
+ "|----------------------|-----------------------------------------------|\n",
+ "| `params` | Used for parameter initialization |\n",
+ "| `dropout` | Used by `nnx.Dropout` to create dropout masks |\n",
"\n",
- "`params` is used my most of the standard layers (`Linear`, `Conv`, `MultiHeadAttention`, etc) during construction to initialize their parameters. `dropout` is used by the `Dropout` and `MultiHeadAttention` to generate dropout masks. Here's a simple example of a model using `params` and `dropout` streams:"
+ "- `params` is used by most of the standard layers (such as `nnx.Linear`, `nnx.Conv`, `nnx.MultiHeadAttention`, and so on) during the construction to initialize their parameters.\n",
+ "- `dropout` is used by `nnx.Dropout` and `nnx.MultiHeadAttention` to generate dropout masks.\n",
+ "\n",
+ "Below is a simple example of a model that uses `params` and `dropout` PRNG key streams:"
]
},
{
@@ -167,9 +188,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Default stream\n",
- "One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be\n",
- "be used as a fallback when a stream is not found. To use the default stream you can simply pass an integer seed or `jax.random.key` as the first positional argument."
+ "### Default PRNG key stream\n",
+ "\n",
+ "One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `nnx.Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be\n",
+ "be used as a fallback when a stream is not found. To use the default PRNG key stream, you can simply pass an integer seed or `jax.random.key` as the first positional argument."
]
},
{
@@ -205,11 +227,11 @@
"source": [
"rngs = nnx.Rngs(0, params=1)\n",
"\n",
- "key1 = rngs.params() # call params\n",
- "key2 = rngs.dropout() # fallback to default\n",
- "key3 = rngs() # call default directly\n",
+ "key1 = rngs.params() # Call params.\n",
+ "key2 = rngs.dropout() # Fallback to the default stream.\n",
+ "key3 = rngs() # Call the default stream directly.\n",
"\n",
- "# test with Model that uses params and dropout\n",
+ "# Test with the `Model` that uses `params` and `dropout`.\n",
"model = Model(rngs)\n",
"y = model(jnp.ones((1, 20)))\n",
"\n",
@@ -221,10 +243,10 @@
"id": "e81510e3",
"metadata": {},
"source": [
- "As shown above, a key from the `default` stream can also be generated by calling the `Rngs` object itself.\n",
+ "As shown above, a PRNG key from the `default` stream can also be generated by calling the `nnx.Rngs` object itself.\n",
"\n",
"> **Note**\n",
- ">
For big projects it is recommended to use named streams to avoid potential conflicts. For small projects or quick prototyping just using the `default` stream is a good choice."
+ ">
For large projects it is recommended to use named streams to avoid potential conflicts. For small projects or quick prototyping just using the `default` stream is a good choice."
]
},
{
@@ -234,7 +256,7 @@
"source": [
"## Filtering random state\n",
"\n",
- "Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:"
+ "Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:"
]
},
{
@@ -271,12 +293,12 @@
"source": [
"model = Model(nnx.Rngs(params=0, dropout=1))\n",
"\n",
- "rng_state = nnx.state(model, nnx.RngState) # all random state\n",
- "key_state = nnx.state(model, nnx.RngKey) # only keys\n",
- "count_state = nnx.state(model, nnx.RngCount) # only counts\n",
- "rng_params_state = nnx.state(model, 'params') # only params\n",
- "rng_dropout_state = nnx.state(model, 'dropout') # only dropout\n",
- "params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys\n",
+ "rng_state = nnx.state(model, nnx.RngState) # All random states.\n",
+ "key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.\n",
+ "count_state = nnx.state(model, nnx.RngCount) # Only counts.\n",
+ "rng_params_state = nnx.state(model, 'params') # Only `params`.\n",
+ "rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.\n",
+ "params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.\n",
"\n",
"nnx.display(params_key_state)"
]
@@ -286,13 +308,17 @@
"metadata": {},
"source": [
"## Reseeding\n",
- "In Haiku and Flax Linen, random state explicitly passed to `apply` each time before calling the model. This makes it easy to control the randomness of the model when needed e.g. for reproducibility. In Flax NNX there are two ways to approach this:\n",
- "1. By passing an `Rngs` object through the `__call__` stack manually. Standard layers like `Dropout` and `MultiHeadAttention` accept an `rngs` argument in case you want to have tight control over the random state.\n",
+ "\n",
+ "In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` each time before you call the model. This makes it easy to control the randomness of the model when needed (for example, for reproducibility).\n",
+ "\n",
+ "In Flax NNX, there are two ways to approach this:\n",
+ "\n",
+ "1. By passing an `nnx.Rngs` object through the `__call__` stack manually. Standard layers like `nnx.Dropout` and `nnx.MultiHeadAttention` accept the `rngs` argument if you want to have tight control over the random state.\n",
"2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.\n",
"\n",
- "`reseed` is a function that accepts an arbitrary graph node (this include pytrees of Flax NNX Modules) and some keyword arguments containing the new seed or key value for the `RngStream`s specified by the argument names. `reseed` will then traverse the graph and update the random state of the matching `RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero.\n",
+ "`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero.\n",
"\n",
- "Here's an example of how to using `reseed` to reset the random state of the `Dropout` layer and verify that the computation is identical to the first time the model was called:"
+ "Here's an example of how to use `nnx.reseed` to reset the random state of the `nnx.Dropout` layer and verify that the computation is identical to the first time the model was called:"
]
},
{
@@ -318,8 +344,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Splitting Rngs\n",
- "When interacting with transforms like `vmap` or `pmap` it is often necessary to split the random state such that each replica has its own unique state. This can be done in two way, either by manually splitting a key before passing it to one of the `Rngs` streams, or using the `nnx.split_rngs` decorator which will automatically split the random state of any `RngStream`s found in the inputs of the function and automatically \"lower\" them once the function call ends. `split_rngs` is more convenient as it works nicely with transforms so we'll show an example of that here:"
+ "## Splitting PRNG keys\n",
+ "\n",
+ "When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the random state such that each replica has its own unique state. This can be done in two ways:\n",
+ "\n",
+ "- By manually splitting a key before passing it to one of the `nnx.Rngs` streams; or\n",
+ "- By using the `nnx.split_rngs` decorator which will automatically split the random state of any `nnx.RngStream`s found in the inputs of the function, and automatically \"lower\" them once the function call ends.\n",
+ "\n",
+ "It is more convenient to use `nnx.split_rngs`, since it works nicely with Flax NNX transforms, so here’s one example:"
]
},
{
@@ -410,7 +442,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note that `split_rngs` allows passing a Filter to the `only` keyword argument to select the `RngStream`s that should be split when inside the function, in this case we only split the `dropout` stream."
+ "> **Note:** `nnx.split_rngs` allows passing an NNX [`Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to the `only` keyword argument in order to select the `nnx.RngStream`s that should be split when inside the function. In such a case, you only need to split the `dropout` PRNG key stream."
]
},
{
@@ -419,10 +451,17 @@
"metadata": {},
"source": [
"## Transforms\n",
- "As stated before, in Flax NNX random state is just another type of state, this means that there is nothing special about it regarding transforms. This means that you should be able to use the state handling APIs of each transform to get the results you want. In this section we will two examples of using random state in transforms, one with `pmap` where we'll see how to split the RNG state, and another one with `scan` where we'll see how to freeze the RNG state.\n",
+ "\n",
+ "As stated before, in Flax NNX the random state is just another type of state. This means that there is nothing special about it when it comes to Flax NNX transforms, which means that you should be able to use the Flax NNX state handling APIs of each transform to get the results you want.\n",
+ "\n",
+ "In this section, you will go through two examples of using the random state in Flax NNX transforms - one with `nnx.pmap`, where you will learn how to split the PRNG state, and another one with `nnx.scan`, where you will freeze the PRNG state.\n",
"\n",
"### Data parallel dropout\n",
- "In the first example we'll explore how to use `pmap` to call our `Model` in a data parallel context. Since Model uses `Dropout` we'll need to split the random state of the `dropout` to ensure that each replica gets different dropout masks. `StateAxes` is passed to `in_axes` to specify that the `model`'s `dropout` stream will be parallelized across axis `0`, and the rest of its state will be replicated. `split_rngs` is used to split the keys of the `dropout` streams into N unique keys, one for each replica."
+ "\n",
+ "In the first example, you’ll explore how to use `nnx.pmap` to call the `nnx.Model` in a data parallel context.\n",
+ "- Since the `nnx.Model` uses `nnx.Dropout`, you’ll need to split the random state of the `dropout` to ensure that each replica gets different dropout masks.\n",
+ "- `nnx.StateAxes` is passed to `in_axes` to specify that the `model`'s `dropout` PRNG key stream will be parallelized across axis `0`, and the rest of its state will be replicated.\n",
+ "- `nnx.split_rngs` is used to split the keys of the `dropout` PRNG key streams into N unique keys, one for each replica."
]
},
{
@@ -459,7 +498,13 @@
"metadata": {},
"source": [
"### Recurrent dropout\n",
- "Next we will explore how to implement a RNNCell that uses recurrent dropout. To do this we will simply create a `Dropout` layer that will sample keys from a custom `recurrent_dropout` stream, and we will apply dropout to the hidden state `h` of the RNNCell. A `initial_state` method will be defined to create the initial state of the RNNCell."
+ "\n",
+ "Next, let’s explore how to implement an `RNNCell` that uses a recurrent dropout. To do this:\n",
+ "\n",
+ "- First, you will create an `nnx.Dropout` layer that will sample PRNG keys from a custom `recurrent_dropout` stream.\n",
+ "- You will apply dropout (`drop`) to the hidden state `h` of the `RNNCell`.\n",
+ "- Then, define an `initial_state` function to create the initial state of the `RNNCell`.\n",
+ "- Finally, instantiate `RNNCell`."
]
},
{
@@ -478,7 +523,7 @@
" self.count = Count(jnp.array(0, jnp.uint32))\n",
"\n",
" def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:\n",
- " h = self.drop(h) # recurrent dropout\n",
+ " h = self.drop(h) # Recurrent dropout.\n",
" y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))\n",
" self.count += 1\n",
" return y, y\n",
@@ -493,7 +538,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Next we will use `scan` over an `unroll` function to implement the `rnn_forward` operation. The key ingredient of recurrent dropout is to apply the same dropout mask across all time steps, to achieve this we'll pass `StateAxes` to `scan`'s `in_axes` specifying that the `cell`'s `recurrent_dropout` stream will be broadcasted, and the rest of the cell's state will be carried over. Also, the hidden state `h` will be the `scan`'s `Carry` variable, and the sequence `x` will be scanned over its axis `1`."
+ "Next, you will use `nnx.scan` over an `unroll` function to implement the `rnn_forward` operation:\n",
+ "- The key ingredient of recurrent dropout is to apply the same dropout mask across all time steps. Therefore, to achieve this you will pass `nnx.StateAxes` to `nnx.scan`'s `in_axes`, specifying that the `cell`'s `recurrent_dropout` PRNG stream will be broadcast, and the rest of the `RNNCell`'s state will be carried over.\n",
+ "- Also, the hidden state `h` will be the `nnx.scan`'s `Carry` variable, and the sequence `x` will be `scan`ned over its axis `1`."
]
},
{
@@ -515,7 +562,7 @@
"def rnn_forward(cell: RNNCell, x: jax.Array):\n",
" h = cell.initial_state(batch_size=x.shape[0])\n",
"\n",
- " # broadcast 'recurrent_dropout' RNG state to have the same mask on every step\n",
+ " # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.\n",
" state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})\n",
" @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))\n",
" def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:\n",
diff --git a/docs_nnx/guides/randomness.md b/docs_nnx/guides/randomness.md
index c4b9623ef3..6e2851305a 100644
--- a/docs_nnx/guides/randomness.md
+++ b/docs_nnx/guides/randomness.md
@@ -10,7 +10,19 @@ jupytext:
# Randomness
-Random state in Flax NNX is radically simplified compared to systems like Haiku/Flax Linen in that it defines "random state as object state". In essence, this means that random state is just another type of state, it's stored in Variables and held by the models themselves. The main characteristic of the RNG system in Flax NNX is that its **explicit**, its **order-based**, and uses **dynamic counters**. This is a bit different from [Flax Linen's RNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) which is (path + order)-based and uses a static counters.
+Random state handling in Flax NNX was radically simplified compared to systems like Haiku and Flax Linen because Flax NNX _defines the random state as an object state_. In essence, this means that in Flax NNX, the random state is: 1) just another type of state; 2) stored in `nnx.Variable`s; and 3) held by the models themselves.
+
+The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:
+
+- It is **explicit**.
+- It is **order-based**.
+- It uses **dynamic counters**.
+
+This is a bit different from [Flax Linen's PRNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html), which is `(path + order)`-based, and uses static counters.
+
+> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).
+
+Let’ start with some necessary imports:
```{code-cell} ipython3
from flax import nnx
@@ -18,23 +30,28 @@ import jax
from jax import random, numpy as jnp
```
-## Rngs, RngStream, and RngState
-Flax NNX provides the `nnx.Rngs` type a the main convenience API for managing random state. Following Flax Linen's footsteps, `Rngs` has the ability to create multiple named RNG streams, each with its own state, for the purpose of tight control over randomness in the context of JAX transforms. Here's a breakdown of the main RNG-related types in Flax NNX:
+## `Rngs`, `RngStream`, and `RngState`
+
+In Flax NNX, the `nnx.Rngs` type is the primary convenience API for managing the random state(s). Following Flax Linen's footsteps, `nnx.Rngs` have the ability to create multiple named PRNG key [streams](https://jax.readthedocs.io/en/latest/jep/263-prng.html), each with its own state, for the purpose of having tight control over randomness in the context of [JAX transformations (transforms)](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).
-* **Rngs**: The main user interface. It defines a set of named `RngStream` objects.
-* **nnx.RngStream**: A object that can generate a stream of RNG keys. It holds a root `key` and a `count` inside a `RngKey` and `RngCount` Variables respectively. When a new key is generated, the count is incremented.
-* **nnx.RngState**: The base type for all RNG-related state.
- * **nnx.RngKey**: Variable type for holding RNG keys, it includes a `tag` attribute containing the name of the stream.
- * **nnx.RngCount**: Variable type for holding RNG counts, it includes a `tag` attribute containing the name of the stream.
+Here are the main PRNG-related types in Flax NNX:
-To create an `Rngs` object you can simply pass a integer seed or `jax.random.key` instance to any keyword argument of your choice in the constructor. Here's an example:
+* **`nnx.Rngs`**: The main user interface. It defines a set of named `nnx.RngStream` objects.
+* **`nnx.RngStream`**: An object that can generate a stream of PRNG keys. It holds a root `key` and a `count` inside an `nnx.RngKey` and `nnx.RngCount` `nnx.Variable`s, respectively. When a new key is generated, the count is incremented.
+* **`nnx.RngState`**: The base type for all RNG-related states.
+ * **`nnx.RngKey`**: NNX Variable type for holding PRNG keys. It includes a `tag` attribute containing the name of the PRNG key stream.
+ * **`nnx.RngCount`**: NNX Variable type for holding PRNG counts. It includes a `tag` attribute containing the PRNG key stream name.
+
+To create an `nnx.Rngs` object you can simply pass an integer seed or `jax.random.key` instance to any keyword argument of your choice in the constructor.
+
+Here's an example:
```{code-cell} ipython3
rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)
```
-Notice that the `key` and `count` Variables contain the stream name in a `tag` attribute. This is primarily used for filtering as we'll see later.
+Notice that the `key` and `count` `nnx.Variable`s contain the PRNG key stream name in a `tag` attribute. This is primarily used for filtering as we'll see later.
To generate new keys, you can access one of the streams and use its `__call__` method with no arguments. This will return a new key by using `random.fold_in` with the current `key` and `count`. The `count` is then incremented so that subsequent calls will return new keys.
@@ -45,17 +62,21 @@ dropout_key = rngs.dropout()
nnx.display(rngs)
```
-Note that the `key` attribute does not change when a new keys are generated.
+Note that the `key` attribute does not change when new PRNG keys are generated.
+
+### Standard PRNG key stream names
-### Standard stream names
-There are only two standard stream names used by Flax NNX's built-in layers, shown in the table below:
+There are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:
-| Name | Description |
-|----------|-------------------------------------------|
-| `params` | Used for parameter initialization |
-| `dropout`| Used by `Dropout` to create dropout masks |
+| PRNG key stream name | Description |
+|----------------------|-----------------------------------------------|
+| `params` | Used for parameter initialization |
+| `dropout` | Used by `nnx.Dropout` to create dropout masks |
-`params` is used my most of the standard layers (`Linear`, `Conv`, `MultiHeadAttention`, etc) during construction to initialize their parameters. `dropout` is used by the `Dropout` and `MultiHeadAttention` to generate dropout masks. Here's a simple example of a model using `params` and `dropout` streams:
+- `params` is used by most of the standard layers (such as `nnx.Linear`, `nnx.Conv`, `nnx.MultiHeadAttention`, and so on) during the construction to initialize their parameters.
+- `dropout` is used by `nnx.Dropout` and `nnx.MultiHeadAttention` to generate dropout masks.
+
+Below is a simple example of a model that uses `params` and `dropout` PRNG key streams:
```{code-cell} ipython3
class Model(nnx.Module):
@@ -72,56 +93,61 @@ y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
```
-### Default stream
-One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be
-be used as a fallback when a stream is not found. To use the default stream you can simply pass an integer seed or `jax.random.key` as the first positional argument.
+### Default PRNG key stream
+
+One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `nnx.Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be
+be used as a fallback when a stream is not found. To use the default PRNG key stream, you can simply pass an integer seed or `jax.random.key` as the first positional argument.
```{code-cell} ipython3
rngs = nnx.Rngs(0, params=1)
-key1 = rngs.params() # call params
-key2 = rngs.dropout() # fallback to default
-key3 = rngs() # call default directly
+key1 = rngs.params() # Call params.
+key2 = rngs.dropout() # Fallback to the default stream.
+key3 = rngs() # Call the default stream directly.
-# test with Model that uses params and dropout
+# Test with the `Model` that uses `params` and `dropout`.
model = Model(rngs)
y = model(jnp.ones((1, 20)))
nnx.display(rngs)
```
-As shown above, a key from the `default` stream can also be generated by calling the `Rngs` object itself.
+As shown above, a PRNG key from the `default` stream can also be generated by calling the `nnx.Rngs` object itself.
> **Note**
->
For big projects it is recommended to use named streams to avoid potential conflicts. For small projects or quick prototyping just using the `default` stream is a good choice.
+>
For large projects it is recommended to use named streams to avoid potential conflicts. For small projects or quick prototyping just using the `default` stream is a good choice.
+++
## Filtering random state
-Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
+Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
```{code-cell} ipython3
model = Model(nnx.Rngs(params=0, dropout=1))
-rng_state = nnx.state(model, nnx.RngState) # all random state
-key_state = nnx.state(model, nnx.RngKey) # only keys
-count_state = nnx.state(model, nnx.RngCount) # only counts
-rng_params_state = nnx.state(model, 'params') # only params
-rng_dropout_state = nnx.state(model, 'dropout') # only dropout
-params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys
+rng_state = nnx.state(model, nnx.RngState) # All random states.
+key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.
+count_state = nnx.state(model, nnx.RngCount) # Only counts.
+rng_params_state = nnx.state(model, 'params') # Only `params`.
+rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.
+params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.
nnx.display(params_key_state)
```
## Reseeding
-In Haiku and Flax Linen, random state explicitly passed to `apply` each time before calling the model. This makes it easy to control the randomness of the model when needed e.g. for reproducibility. In Flax NNX there are two ways to approach this:
-1. By passing an `Rngs` object through the `__call__` stack manually. Standard layers like `Dropout` and `MultiHeadAttention` accept an `rngs` argument in case you want to have tight control over the random state.
+
+In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` each time before you call the model. This makes it easy to control the randomness of the model when needed (for example, for reproducibility).
+
+In Flax NNX, there are two ways to approach this:
+
+1. By passing an `nnx.Rngs` object through the `__call__` stack manually. Standard layers like `nnx.Dropout` and `nnx.MultiHeadAttention` accept the `rngs` argument if you want to have tight control over the random state.
2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.
-`reseed` is a function that accepts an arbitrary graph node (this include pytrees of Flax NNX Modules) and some keyword arguments containing the new seed or key value for the `RngStream`s specified by the argument names. `reseed` will then traverse the graph and update the random state of the matching `RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero.
+`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero.
-Here's an example of how to using `reseed` to reset the random state of the `Dropout` layer and verify that the computation is identical to the first time the model was called:
+Here's an example of how to use `nnx.reseed` to reset the random state of the `nnx.Dropout` layer and verify that the computation is identical to the first time the model was called:
```{code-cell} ipython3
model = Model(nnx.Rngs(params=0, dropout=1))
@@ -137,8 +163,14 @@ assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3) # same
```
-## Splitting Rngs
-When interacting with transforms like `vmap` or `pmap` it is often necessary to split the random state such that each replica has its own unique state. This can be done in two way, either by manually splitting a key before passing it to one of the `Rngs` streams, or using the `nnx.split_rngs` decorator which will automatically split the random state of any `RngStream`s found in the inputs of the function and automatically "lower" them once the function call ends. `split_rngs` is more convenient as it works nicely with transforms so we'll show an example of that here:
+## Splitting PRNG keys
+
+When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the random state such that each replica has its own unique state. This can be done in two ways:
+
+- By manually splitting a key before passing it to one of the `nnx.Rngs` streams; or
+- By using the `nnx.split_rngs` decorator which will automatically split the random state of any `nnx.RngStream`s found in the inputs of the function, and automatically "lower" them once the function call ends.
+
+It is more convenient to use `nnx.split_rngs`, since it works nicely with Flax NNX transforms, so here’s one example:
```{code-cell} ipython3
rngs = nnx.Rngs(params=0, dropout=1)
@@ -156,15 +188,22 @@ rngs.dropout() # works!
nnx.display(rngs)
```
-Note that `split_rngs` allows passing a Filter to the `only` keyword argument to select the `RngStream`s that should be split when inside the function, in this case we only split the `dropout` stream.
+> **Note:** `nnx.split_rngs` allows passing an NNX [`Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to the `only` keyword argument in order to select the `nnx.RngStream`s that should be split when inside the function. In such a case, you only need to split the `dropout` PRNG key stream.
+++
## Transforms
-As stated before, in Flax NNX random state is just another type of state, this means that there is nothing special about it regarding transforms. This means that you should be able to use the state handling APIs of each transform to get the results you want. In this section we will two examples of using random state in transforms, one with `pmap` where we'll see how to split the RNG state, and another one with `scan` where we'll see how to freeze the RNG state.
+
+As stated before, in Flax NNX the random state is just another type of state. This means that there is nothing special about it when it comes to Flax NNX transforms, which means that you should be able to use the Flax NNX state handling APIs of each transform to get the results you want.
+
+In this section, you will go through two examples of using the random state in Flax NNX transforms - one with `nnx.pmap`, where you will learn how to split the PRNG state, and another one with `nnx.scan`, where you will freeze the PRNG state.
### Data parallel dropout
-In the first example we'll explore how to use `pmap` to call our `Model` in a data parallel context. Since Model uses `Dropout` we'll need to split the random state of the `dropout` to ensure that each replica gets different dropout masks. `StateAxes` is passed to `in_axes` to specify that the `model`'s `dropout` stream will be parallelized across axis `0`, and the rest of its state will be replicated. `split_rngs` is used to split the keys of the `dropout` streams into N unique keys, one for each replica.
+
+In the first example, you’ll explore how to use `nnx.pmap` to call the `nnx.Model` in a data parallel context.
+- Since the `nnx.Model` uses `nnx.Dropout`, you’ll need to split the random state of the `dropout` to ensure that each replica gets different dropout masks.
+- `nnx.StateAxes` is passed to `in_axes` to specify that the `model`'s `dropout` PRNG key stream will be parallelized across axis `0`, and the rest of its state will be replicated.
+- `nnx.split_rngs` is used to split the keys of the `dropout` PRNG key streams into N unique keys, one for each replica.
```{code-cell} ipython3
model = Model(nnx.Rngs(params=0, dropout=1))
@@ -183,7 +222,13 @@ print(y.shape)
```
### Recurrent dropout
-Next we will explore how to implement a RNNCell that uses recurrent dropout. To do this we will simply create a `Dropout` layer that will sample keys from a custom `recurrent_dropout` stream, and we will apply dropout to the hidden state `h` of the RNNCell. A `initial_state` method will be defined to create the initial state of the RNNCell.
+
+Next, let’s explore how to implement an `RNNCell` that uses a recurrent dropout. To do this:
+
+- First, you will create an `nnx.Dropout` layer that will sample PRNG keys from a custom `recurrent_dropout` stream.
+- You will apply dropout (`drop`) to the hidden state `h` of the `RNNCell`.
+- Then, define an `initial_state` function to create the initial state of the `RNNCell`.
+- Finally, instantiate `RNNCell`.
```{code-cell} ipython3
class Count(nnx.Variable): pass
@@ -196,7 +241,7 @@ class RNNCell(nnx.Module):
self.count = Count(jnp.array(0, jnp.uint32))
def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
- h = self.drop(h) # recurrent dropout
+ h = self.drop(h) # Recurrent dropout.
y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
self.count += 1
return y, y
@@ -207,14 +252,16 @@ class RNNCell(nnx.Module):
cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))
```
-Next we will use `scan` over an `unroll` function to implement the `rnn_forward` operation. The key ingredient of recurrent dropout is to apply the same dropout mask across all time steps, to achieve this we'll pass `StateAxes` to `scan`'s `in_axes` specifying that the `cell`'s `recurrent_dropout` stream will be broadcasted, and the rest of the cell's state will be carried over. Also, the hidden state `h` will be the `scan`'s `Carry` variable, and the sequence `x` will be scanned over its axis `1`.
+Next, you will use `nnx.scan` over an `unroll` function to implement the `rnn_forward` operation:
+- The key ingredient of recurrent dropout is to apply the same dropout mask across all time steps. Therefore, to achieve this you will pass `nnx.StateAxes` to `nnx.scan`'s `in_axes`, specifying that the `cell`'s `recurrent_dropout` PRNG stream will be broadcast, and the rest of the `RNNCell`'s state will be carried over.
+- Also, the hidden state `h` will be the `nnx.scan`'s `Carry` variable, and the sequence `x` will be `scan`ned over its axis `1`.
```{code-cell} ipython3
@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
h = cell.initial_state(batch_size=x.shape[0])
- # broadcast 'recurrent_dropout' RNG state to have the same mask on every step
+ # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.
state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
@nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]: