From cea5fb018a0b34a6d488d852a68e0059d8940dfc Mon Sep 17 00:00:00 2001
From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
Date: Mon, 16 Sep 2024 21:12:18 +0000
Subject: [PATCH] Upgrade Flax NNX Filters doc
---
docs/nnx/filters_guide.ipynb | 80 ++++++++++++++++++++----------------
docs/nnx/filters_guide.md | 80 ++++++++++++++++++++----------------
2 files changed, 88 insertions(+), 72 deletions(-)
diff --git a/docs/nnx/filters_guide.ipynb b/docs/nnx/filters_guide.ipynb
index 21591226ac..bd2d9bbb4c 100644
--- a/docs/nnx/filters_guide.ipynb
+++ b/docs/nnx/filters_guide.ipynb
@@ -5,12 +5,20 @@
"id": "95b08e64",
"metadata": {},
"source": [
- "# Using Filters\n",
+ "# Using `Filter`s\n",
"\n",
"> **Attention**: This page relates to the new Flax NNX API.\n",
"\n",
- "Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n",
- "such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:"
+ "`Filter`s are used extensively in the Flax NNX API as a way to create `nnx.State` groups in APIs\n",
+ "such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. This guide will help you\n",
+ "under:\n",
+ "\n",
+ "* What is a Flax NNX `Filter`?\n",
+ "* Why are types, such as `Param` or `BatchStat`, Filters?\n",
+ "* How is `State` grouped / `Filter`ed?\n",
+ "\n",
+ "Below is an example of using `Filter`s, where `nnx.Param` and `nnx.BatchStat` are used as `Filter`s\n",
+ "to split the model into two groups - one with parameters, and one with batch statistics:"
]
},
{
@@ -59,11 +67,7 @@
"id": "8f77e99a",
"metadata": {},
"source": [
- "Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:\n",
- "\n",
- "* What is a Filter?\n",
- "* Why are types, such as `Param` or `BatchStat`, Filters?\n",
- "* How is `State` grouped / filtered?"
+ "Let’s dive deeper into `Filter`s."
]
},
{
@@ -71,20 +75,23 @@
"id": "a0413d64",
"metadata": {},
"source": [
- "## The Filter Protocol\n",
+ "## The `Filter` protocol\n",
"\n",
- "In general Filter are predicate functions of the form:\n",
+ "In general, `Filter`s are predicate functions of the form:\n",
"\n",
"```python\n",
"\n",
"(path: tuple[Key, ...], value: Any) -> bool\n",
"\n",
"```\n",
- "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n",
"\n",
- "Types are obviously not functions of this form, so the reason why they are treated as Filters \n",
- "is because, as we will see next, types and some other literals are converted to predicates. For example, \n",
- "`Param` is roughly converted to a predicate like this:"
+ "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the\n",
+ "path to the value in a nested structure, and `value` is the value at the path. The function\n",
+ "returns `True` if the value should be included in the group and `False` otherwise.\n",
+ "\n",
+ "Types are obviously not functions of this form, so the reason why they are treated as `Filter`s \n",
+ "is because, as we will see next, types and some other literals are converted to predicates.\n",
+ "For example, `Param` is roughly converted to a predicate like this:"
]
},
{
@@ -118,7 +125,7 @@
"metadata": {},
"source": [
"Such function matches any value that is an instance of `Param` or any value that has a \n",
- "`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which\n",
+ "`type` attribute that is a subclass of `Param`. Internally, Flax NNX uses `OfType` which\n",
"defines a callable of this form for a given type:"
]
},
@@ -149,13 +156,13 @@
"id": "87c06e39",
"metadata": {},
"source": [
- "## The Filter DSL\n",
+ "## The `Filter` DSL\n",
"\n",
- "To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized \n",
- "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, \n",
- "tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
+ "To avoid users having to create functions described in the previous section, Flax NNX exposes a\n",
+ "small DSL, formalized as the `nnx.filterlib.Filter` type, which lets users pass types, booleans,\n",
+ "ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
"\n",
- "Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n",
+ "Here is a list of all the callable `Filter`s included in Flax NNX and their DSL literals\n",
"(when available):\n",
"\n",
"\n",
@@ -163,16 +170,16 @@
"|--------|----------------------|-------------|\n",
"| `...` or `True` | `Everything()` | Matches all values |\n",
"| `None` or `False` | `Nothing()` | Matches no values |\n",
- "| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |\n",
+ "| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attributethat is an instance of `type` |\n",
"| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |\n",
"| `'{filter}'` str | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. |\n",
"| `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` |\n",
"| | `All(*filters)` | Matches values that match all of the inner `filters` |\n",
"| | `Not(filter)` | Matches values that do not match the inner `filter` |\n",
"\n",
- "Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n",
- "and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n",
- "use the following filters:"
+ "Let’s see the DSL in action using an `nnx.vmap` example. Let’s say you want to vectorize all\n",
+ "parameters, apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis, and broadcast the rest.\n",
+ "To do this, you use the following `Filter`s:"
]
},
{
@@ -194,7 +201,7 @@
"id": "bd60f0e1",
"metadata": {},
"source": [
- "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n",
+ "Here, `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n",
"expands to `Everything()`.\n",
"\n",
"If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:"
@@ -234,12 +241,13 @@
"id": "db9b4cf3",
"metadata": {},
"source": [
- "## Grouping States\n",
+ "## Grouping `State`s\n",
"\n",
- "With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:\n",
+ "With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. The following\n",
+ "are key ideas here:\n",
"\n",
"* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.\n",
- "* Convert all the filters to predicates.\n",
+ "* Convert all the `Filter`s to predicates.\n",
"* Use `State.flat_state` to get the flat representation of the state.\n",
"* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n",
"* Use `State.from_flat_state` to convert the flat states to nested `State`s."
@@ -306,12 +314,12 @@
"id": "7b3aeac8",
"metadata": {},
"source": [
- "One very important thing to note is that **filtering is order-dependent**. The first filter that \n",
- "matches a value will keep it, therefore you should place more specific filters before more general \n",
- "filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` \n",
- "object that contains both types of parameters, if we try to split the `Param`s before the \n",
- "`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group \n",
- "will be empty because all `SpecialParam`s are also `Param`s:"
+ "**Note:*** It’s important to know that _`filter`ing is order-dependent_. The first `Filter` that \n",
+ "matches a value will keep it, and therefore you should place more specific `Filter`s before more\n",
+ "general `Filter`s. For examplem if you create a `SpecialParam` type that is a subclass of `Param`,\n",
+ "and a `Bar` object that contains both types of parameters, and if you try to split the `Param`s\n",
+ "before the `SpecialParam`s, then all the values will be placed in the `Param` group and the\n",
+ "`SpecialParam` group will be empty because all `SpecialParam`s are also `Param`s:"
]
},
{
@@ -359,7 +367,7 @@
"id": "a9f0b7b8",
"metadata": {},
"source": [
- "Reversing the order will make sure that the `SpecialParam` are captured first"
+ "Reversing the order will make sure that the `SpecialParam` are captured first:"
]
},
{
@@ -388,7 +396,7 @@
}
],
"source": [
- "graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!\n",
+ "graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # Correct!\n",
"print(f'{params = }')\n",
"print(f'{special_params = }')"
]
diff --git a/docs/nnx/filters_guide.md b/docs/nnx/filters_guide.md
index 84bbe3fa7f..90d53dff12 100644
--- a/docs/nnx/filters_guide.md
+++ b/docs/nnx/filters_guide.md
@@ -8,12 +8,20 @@ jupytext:
jupytext_version: 1.13.8
---
-# Using Filters
+# Using `Filter`s
> **Attention**: This page relates to the new Flax NNX API.
-Filters are used extensively in Flax NNX as a way to create `State` groups in APIs
-such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:
+`Filter`s are used extensively in the Flax NNX API as a way to create `nnx.State` groups in APIs
+such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. This guide will help you
+under:
+
+* What is a Flax NNX `Filter`?
+* Why are types, such as `Param` or `BatchStat`, Filters?
+* How is `State` grouped / `Filter`ed?
+
+Below is an example of using `Filter`s, where `nnx.Param` and `nnx.BatchStat` are used as `Filter`s
+to split the model into two groups - one with parameters, and one with batch statistics:
```{code-cell} ipython3
from flax import nnx
@@ -31,28 +39,27 @@ print(f'{params = }')
print(f'{batch_stats = }')
```
-Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:
-
-* What is a Filter?
-* Why are types, such as `Param` or `BatchStat`, Filters?
-* How is `State` grouped / filtered?
+Let’s dive deeper into `Filter`s.
+++
-## The Filter Protocol
+## The `Filter` protocol
-In general Filter are predicate functions of the form:
+In general, `Filter`s are predicate functions of the form:
```python
(path: tuple[Key, ...], value: Any) -> bool
```
-where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.
-Types are obviously not functions of this form, so the reason why they are treated as Filters
-is because, as we will see next, types and some other literals are converted to predicates. For example,
-`Param` is roughly converted to a predicate like this:
+where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the
+path to the value in a nested structure, and `value` is the value at the path. The function
+returns `True` if the value should be included in the group and `False` otherwise.
+
+Types are obviously not functions of this form, so the reason why they are treated as `Filter`s
+is because, as we will see next, types and some other literals are converted to predicates.
+For example, `Param` is roughly converted to a predicate like this:
```{code-cell} ipython3
def is_param(path, value) -> bool:
@@ -65,7 +72,7 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```
Such function matches any value that is an instance of `Param` or any value that has a
-`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which
+`type` attribute that is a subclass of `Param`. Internally, Flax NNX uses `OfType` which
defines a callable of this form for a given type:
```{code-cell} ipython3
@@ -75,13 +82,13 @@ print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```
-## The Filter DSL
+## The `Filter` DSL
-To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized
-as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis,
-tuples/lists, etc, and converts them to the appropriate predicate internally.
+To avoid users having to create functions described in the previous section, Flax NNX exposes a
+small DSL, formalized as the `nnx.filterlib.Filter` type, which lets users pass types, booleans,
+ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.
-Here is a list of all the callable Filters included in Flax NNX and their DSL literals
+Here is a list of all the callable `Filter`s included in Flax NNX and their DSL literals
(when available):
@@ -89,16 +96,16 @@ Here is a list of all the callable Filters included in Flax NNX and their DSL li
|--------|----------------------|-------------|
| `...` or `True` | `Everything()` | Matches all values |
| `None` or `False` | `Nothing()` | Matches no values |
-| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |
+| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attributethat is an instance of `type` |
| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |
| `'{filter}'` str | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. |
| `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` |
| | `All(*filters)` | Matches values that match all of the inner `filters` |
| | `Not(filter)` | Matches values that do not match the inner `filter` |
-Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters
-and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can
-use the following filters:
+Let’s see the DSL in action using an `nnx.vmap` example. Let’s say you want to vectorize all
+parameters, apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis, and broadcast the rest.
+To do this, you use the following `Filter`s:
```{code-cell} ipython3
from functools import partial
@@ -108,7 +115,7 @@ def forward(model, x):
...
```
-Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`
+Here, `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`
expands to `Everything()`.
If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:
@@ -125,12 +132,13 @@ print(f'{nothing = }')
print(f'{params_or_dropout = }')
```
-## Grouping States
+## Grouping `State`s
-With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:
+With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. The following
+are key ideas here:
* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.
-* Convert all the filters to predicates.
+* Convert all the `Filter`s to predicates.
* Use `State.flat_state` to get the flat representation of the state.
* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.
* Use `State.from_flat_state` to convert the flat states to nested `State`s.
@@ -166,12 +174,12 @@ print(f'{params = }')
print(f'{batch_stats = }')
```
-One very important thing to note is that **filtering is order-dependent**. The first filter that
-matches a value will keep it, therefore you should place more specific filters before more general
-filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar`
-object that contains both types of parameters, if we try to split the `Param`s before the
-`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group
-will be empty because all `SpecialParam`s are also `Param`s:
+**Note:*** It’s important to know that _`filter`ing is order-dependent_. The first `Filter` that
+matches a value will keep it, and therefore you should place more specific `Filter`s before more
+general `Filter`s. For examplem if you create a `SpecialParam` type that is a subclass of `Param`,
+and a `Bar` object that contains both types of parameters, and if you try to split the `Param`s
+before the `SpecialParam`s, then all the values will be placed in the `Param` group and the
+`SpecialParam` group will be empty because all `SpecialParam`s are also `Param`s:
```{code-cell} ipython3
class SpecialParam(nnx.Param):
@@ -189,10 +197,10 @@ print(f'{params = }')
print(f'{special_params = }')
```
-Reversing the order will make sure that the `SpecialParam` are captured first
+Reversing the order will make sure that the `SpecialParam` are captured first:
```{code-cell} ipython3
-graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
+graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # Correct!
print(f'{params = }')
print(f'{special_params = }')
```