Skip to content

Commit

Permalink
Upgrade Flax NNX Filters doc
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Sep 16, 2024
1 parent d111adf commit cea5fb0
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 72 deletions.
80 changes: 44 additions & 36 deletions docs/nnx/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@
"id": "95b08e64",
"metadata": {},
"source": [
"# Using Filters\n",
"# Using `Filter`s\n",
"\n",
"> **Attention**: This page relates to the new Flax NNX API.\n",
"\n",
"Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n",
"such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:"
"`Filter`s are used extensively in the Flax NNX API as a way to create `nnx.State` groups in APIs\n",
"such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. This guide will help you\n",
"under:\n",
"\n",
"* What is a Flax NNX `Filter`?\n",
"* Why are types, such as `Param` or `BatchStat`, Filters?\n",
"* How is `State` grouped / `Filter`ed?\n",
"\n",
"Below is an example of using `Filter`s, where `nnx.Param` and `nnx.BatchStat` are used as `Filter`s\n",
"to split the model into two groups - one with parameters, and one with batch statistics:"
]
},
{
Expand Down Expand Up @@ -59,32 +67,31 @@
"id": "8f77e99a",
"metadata": {},
"source": [
"Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:\n",
"\n",
"* What is a Filter?\n",
"* Why are types, such as `Param` or `BatchStat`, Filters?\n",
"* How is `State` grouped / filtered?"
"Let’s dive deeper into `Filter`s."
]
},
{
"cell_type": "markdown",
"id": "a0413d64",
"metadata": {},
"source": [
"## The Filter Protocol\n",
"## The `Filter` protocol\n",
"\n",
"In general Filter are predicate functions of the form:\n",
"In general, `Filter`s are predicate functions of the form:\n",
"\n",
"```python\n",
"\n",
"(path: tuple[Key, ...], value: Any) -> bool\n",
"\n",
"```\n",
"where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n",
"\n",
"Types are obviously not functions of this form, so the reason why they are treated as Filters \n",
"is because, as we will see next, types and some other literals are converted to predicates. For example, \n",
"`Param` is roughly converted to a predicate like this:"
"where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the\n",
"path to the value in a nested structure, and `value` is the value at the path. The function\n",
"returns `True` if the value should be included in the group and `False` otherwise.\n",
"\n",
"Types are obviously not functions of this form, so the reason why they are treated as `Filter`s \n",
"is because, as we will see next, types and some other literals are converted to predicates.\n",
"For example, `Param` is roughly converted to a predicate like this:"
]
},
{
Expand Down Expand Up @@ -118,7 +125,7 @@
"metadata": {},
"source": [
"Such function matches any value that is an instance of `Param` or any value that has a \n",
"`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which\n",
"`type` attribute that is a subclass of `Param`. Internally, Flax NNX uses `OfType` which\n",
"defines a callable of this form for a given type:"
]
},
Expand Down Expand Up @@ -149,30 +156,30 @@
"id": "87c06e39",
"metadata": {},
"source": [
"## The Filter DSL\n",
"## The `Filter` DSL\n",
"\n",
"To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized \n",
"as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, \n",
"tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
"To avoid users having to create functions described in the previous section, Flax NNX exposes a\n",
"small DSL, formalized as the `nnx.filterlib.Filter` type, which lets users pass types, booleans,\n",
"ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
"\n",
"Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n",
"Here is a list of all the callable `Filter`s included in Flax NNX and their DSL literals\n",
"(when available):\n",
"\n",
"\n",
"| Literal | Callable | Description |\n",
"|--------|----------------------|-------------|\n",
"| `...` or `True` | `Everything()` | Matches all values |\n",
"| `None` or `False` | `Nothing()` | Matches no values |\n",
"| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |\n",
"| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attributethat is an instance of `type` |\n",
"| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |\n",
"| `'{filter}'` <span style=\"color:gray\">str</span> | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. |\n",
"| `(*filters)` <span style=\"color:gray\">tuple</span> or `[*filters]` <span style=\"color:gray\">list</span> | `Any(*filters)` | Matches values that match any of the inner `filters` |\n",
"| | `All(*filters)` | Matches values that match all of the inner `filters` |\n",
"| | `Not(filter)` | Matches values that do not match the inner `filter` |\n",
"\n",
"Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n",
"and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n",
"use the following filters:"
"Let’s see the DSL in action using an `nnx.vmap` example. Let’s say you want to vectorize all\n",
"parameters, apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis, and broadcast the rest.\n",
"To do this, you use the following `Filter`s:"
]
},
{
Expand All @@ -194,7 +201,7 @@
"id": "bd60f0e1",
"metadata": {},
"source": [
"Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n",
"Here, `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n",
"expands to `Everything()`.\n",
"\n",
"If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:"
Expand Down Expand Up @@ -234,12 +241,13 @@
"id": "db9b4cf3",
"metadata": {},
"source": [
"## Grouping States\n",
"## Grouping `State`s\n",
"\n",
"With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:\n",
"With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. The following\n",
"are key ideas here:\n",
"\n",
"* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.\n",
"* Convert all the filters to predicates.\n",
"* Convert all the `Filter`s to predicates.\n",
"* Use `State.flat_state` to get the flat representation of the state.\n",
"* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n",
"* Use `State.from_flat_state` to convert the flat states to nested `State`s."
Expand Down Expand Up @@ -306,12 +314,12 @@
"id": "7b3aeac8",
"metadata": {},
"source": [
"One very important thing to note is that **filtering is order-dependent**. The first filter that \n",
"matches a value will keep it, therefore you should place more specific filters before more general \n",
"filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` \n",
"object that contains both types of parameters, if we try to split the `Param`s before the \n",
"`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group \n",
"will be empty because all `SpecialParam`s are also `Param`s:"
"**Note:*** It’s important to know that _`filter`ing is order-dependent_. The first `Filter` that \n",
"matches a value will keep it, and therefore you should place more specific `Filter`s before more\n",
"general `Filter`s. For examplem if you create a `SpecialParam` type that is a subclass of `Param`,\n",
"and a `Bar` object that contains both types of parameters, and if you try to split the `Param`s\n",
"before the `SpecialParam`s, then all the values will be placed in the `Param` group and the\n",
"`SpecialParam` group will be empty because all `SpecialParam`s are also `Param`s:"
]
},
{
Expand Down Expand Up @@ -359,7 +367,7 @@
"id": "a9f0b7b8",
"metadata": {},
"source": [
"Reversing the order will make sure that the `SpecialParam` are captured first"
"Reversing the order will make sure that the `SpecialParam` are captured first:"
]
},
{
Expand Down Expand Up @@ -388,7 +396,7 @@
}
],
"source": [
"graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!\n",
"graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # Correct!\n",
"print(f'{params = }')\n",
"print(f'{special_params = }')"
]
Expand Down
80 changes: 44 additions & 36 deletions docs/nnx/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@ jupytext:
jupytext_version: 1.13.8
---

# Using Filters
# Using `Filter`s

> **Attention**: This page relates to the new Flax NNX API.
Filters are used extensively in Flax NNX as a way to create `State` groups in APIs
such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:
`Filter`s are used extensively in the Flax NNX API as a way to create `nnx.State` groups in APIs
such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. This guide will help you
under:

* What is a Flax NNX `Filter`?
* Why are types, such as `Param` or `BatchStat`, Filters?
* How is `State` grouped / `Filter`ed?

Below is an example of using `Filter`s, where `nnx.Param` and `nnx.BatchStat` are used as `Filter`s
to split the model into two groups - one with parameters, and one with batch statistics:

```{code-cell} ipython3
from flax import nnx
Expand All @@ -31,28 +39,27 @@ print(f'{params = }')
print(f'{batch_stats = }')
```

Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:

* What is a Filter?
* Why are types, such as `Param` or `BatchStat`, Filters?
* How is `State` grouped / filtered?
Let’s dive deeper into `Filter`s.

+++

## The Filter Protocol
## The `Filter` protocol

In general Filter are predicate functions of the form:
In general, `Filter`s are predicate functions of the form:

```python

(path: tuple[Key, ...], value: Any) -> bool

```
where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.

Types are obviously not functions of this form, so the reason why they are treated as Filters
is because, as we will see next, types and some other literals are converted to predicates. For example,
`Param` is roughly converted to a predicate like this:
where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the
path to the value in a nested structure, and `value` is the value at the path. The function
returns `True` if the value should be included in the group and `False` otherwise.

Types are obviously not functions of this form, so the reason why they are treated as `Filter`s
is because, as we will see next, types and some other literals are converted to predicates.
For example, `Param` is roughly converted to a predicate like this:

```{code-cell} ipython3
def is_param(path, value) -> bool:
Expand All @@ -65,7 +72,7 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```

Such function matches any value that is an instance of `Param` or any value that has a
`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which
`type` attribute that is a subclass of `Param`. Internally, Flax NNX uses `OfType` which
defines a callable of this form for a given type:

```{code-cell} ipython3
Expand All @@ -75,30 +82,30 @@ print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```

## The Filter DSL
## The `Filter` DSL

To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized
as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis,
tuples/lists, etc, and converts them to the appropriate predicate internally.
To avoid users having to create functions described in the previous section, Flax NNX exposes a
small DSL, formalized as the `nnx.filterlib.Filter` type, which lets users pass types, booleans,
ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.

Here is a list of all the callable Filters included in Flax NNX and their DSL literals
Here is a list of all the callable `Filter`s included in Flax NNX and their DSL literals
(when available):


| Literal | Callable | Description |
|--------|----------------------|-------------|
| `...` or `True` | `Everything()` | Matches all values |
| `None` or `False` | `Nothing()` | Matches no values |
| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |
| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attributethat is an instance of `type` |
| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |
| `'{filter}'` <span style="color:gray">str</span> | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. |
| `(*filters)` <span style="color:gray">tuple</span> or `[*filters]` <span style="color:gray">list</span> | `Any(*filters)` | Matches values that match any of the inner `filters` |
| | `All(*filters)` | Matches values that match all of the inner `filters` |
| | `Not(filter)` | Matches values that do not match the inner `filter` |

Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters
and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can
use the following filters:
Let’s see the DSL in action using an `nnx.vmap` example. Let’s say you want to vectorize all
parameters, apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis, and broadcast the rest.
To do this, you use the following `Filter`s:

```{code-cell} ipython3
from functools import partial
Expand All @@ -108,7 +115,7 @@ def forward(model, x):
...
```

Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`
Here, `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`
expands to `Everything()`.

If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:
Expand All @@ -125,12 +132,13 @@ print(f'{nothing = }')
print(f'{params_or_dropout = }')
```

## Grouping States
## Grouping `State`s

With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:
With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. The following
are key ideas here:

* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.
* Convert all the filters to predicates.
* Convert all the `Filter`s to predicates.
* Use `State.flat_state` to get the flat representation of the state.
* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.
* Use `State.from_flat_state` to convert the flat states to nested `State`s.
Expand Down Expand Up @@ -166,12 +174,12 @@ print(f'{params = }')
print(f'{batch_stats = }')
```

One very important thing to note is that **filtering is order-dependent**. The first filter that
matches a value will keep it, therefore you should place more specific filters before more general
filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar`
object that contains both types of parameters, if we try to split the `Param`s before the
`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group
will be empty because all `SpecialParam`s are also `Param`s:
**Note:*** It’s important to know that _`filter`ing is order-dependent_. The first `Filter` that
matches a value will keep it, and therefore you should place more specific `Filter`s before more
general `Filter`s. For examplem if you create a `SpecialParam` type that is a subclass of `Param`,
and a `Bar` object that contains both types of parameters, and if you try to split the `Param`s
before the `SpecialParam`s, then all the values will be placed in the `Param` group and the
`SpecialParam` group will be empty because all `SpecialParam`s are also `Param`s:

```{code-cell} ipython3
class SpecialParam(nnx.Param):
Expand All @@ -189,10 +197,10 @@ print(f'{params = }')
print(f'{special_params = }')
```

Reversing the order will make sure that the `SpecialParam` are captured first
Reversing the order will make sure that the `SpecialParam` are captured first:

```{code-cell} ipython3
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # Correct!
print(f'{params = }')
print(f'{special_params = }')
```

0 comments on commit cea5fb0

Please sign in to comment.