Skip to content

Commit

Permalink
Upgrade NNX Filters guide
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Oct 6, 2024
1 parent 5d31452 commit a73e710
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 96 deletions.
101 changes: 54 additions & 47 deletions docs_nnx/guides/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
"id": "95b08e64",
"metadata": {},
"source": [
"# Using Filters\n",
"# Using `Filter`s\n",
"\n",
"> **Attention**: This page relates to the new Flax NNX API.\n",
"Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n",
"\n",
"Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n",
"such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:"
"In this guide you will learn:\n",
"\n",
"* What is a `Filter`?\n",
"* Why are types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), treated as `Filter`s?\n",
"* What is the `Filter` domain specific language (DSL)?\n",
"* How is [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) grouped / filtered?\n",
"\n",
"In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:"
]
},
{
Expand Down Expand Up @@ -59,32 +65,33 @@
"id": "8f77e99a",
"metadata": {},
"source": [
"Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:\n",
"\n",
"* What is a Filter?\n",
"* Why are types, such as `Param` or `BatchStat`, Filters?\n",
"* How is `State` grouped / filtered?"
"Let's dive deeper into `Filter`s."
]
},
{
"cell_type": "markdown",
"id": "a0413d64",
"metadata": {},
"source": [
"## The Filter Protocol\n",
"## The `Filter` Protocol\n",
"\n",
"In general Filter are predicate functions of the form:\n",
"In general, Flax `Filter`s are predicate functions of the form:\n",
"\n",
"```python\n",
"\n",
"(path: tuple[Key, ...], value: Any) -> bool\n",
"\n",
"```\n",
"where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n",
"\n",
"Types are obviously not functions of this form, so the reason why they are treated as Filters\n",
"is because, as we will see next, types and some other literals are converted to predicates. For example,\n",
"`Param` is roughly converted to a predicate like this:"
"where:\n",
"\n",
"- `Key` is a hashable and comparable type;\n",
"- `path` is a tuple of `Key`s representing the path to the value in a nested structure; and\n",
"- `value` is the value at the path.\n",
"\n",
"The function returns `True` if the value should be included in the group, and `False` otherwise.\n",
"\n",
"Types are not functions of this form. They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this:"
]
},
{
Expand Down Expand Up @@ -117,9 +124,7 @@
"id": "a8a2641e",
"metadata": {},
"source": [
"Such function matches any value that is an instance of `Param` or any value that has a\n",
"`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which\n",
"defines a callable of this form for a given type:"
"Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:"
]
},
{
Expand Down Expand Up @@ -149,14 +154,11 @@
"id": "87c06e39",
"metadata": {},
"source": [
"## The Filter DSL\n",
"## The `Filter` DSL\n",
"\n",
"To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized\n",
"as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis,\n",
"tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
"To help users avoid having to create functions mentioned in the previous section, Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. The `Filter` DSL allows users to pass types, booleans, ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
"\n",
"Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n",
"(when available):\n",
"Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):\n",
"\n",
"\n",
"| Literal | Callable | Description |\n",
Expand All @@ -170,10 +172,14 @@
"| | `All(*filters)` | Matches values that match all of the inner `filters` |\n",
"| | `Not(filter)` | Matches values that do not match the inner `filter` |\n",
"\n",
"Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n",
"and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n",
"use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`\n",
"to specify how `model`'s various substates should be vectorized:"
"\n",
"Let's check out the DSL in action by using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Consider the following:\n",
"\n",
"1) You want to vectorize all parameters;\n",
"2) Apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis; and\n",
"3) Broadcast the rest.\n",
"\n",
"To do this, you can use the following `Filter`s to define a `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes to specify how `model`'s various substates should be vectorized:"
]
},
{
Expand All @@ -183,9 +189,9 @@
"metadata": {},
"outputs": [],
"source": [
"state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})\n",
"from functools import partial\n",
"\n",
"@nnx.vmap(in_axes=(state_axes, 0))\n",
"@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})\n",
"def forward(model, x):\n",
" ..."
]
Expand All @@ -195,10 +201,9 @@
"id": "bd60f0e1",
"metadata": {},
"source": [
"Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n",
"expands to `Everything()`.\n",
"Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`.\n",
"\n",
"If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:"
"If you wish to manually convert literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate):"
]
},
{
Expand Down Expand Up @@ -235,15 +240,15 @@
"id": "db9b4cf3",
"metadata": {},
"source": [
"## Grouping States\n",
"## Grouping `State`s\n",
"\n",
"With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:\n",
"With the knowledge of `Filter`s from previous sections at hand, let's learn how to roughly implement [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). Here are the key ideas:\n",
"\n",
"* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.\n",
"* Convert all the filters to predicates.\n",
"* Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node.\n",
"* Convert all the `Filter`s to predicates.\n",
"* Use `State.flat_state` to get the flat representation of the state.\n",
"* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n",
"* Use `State.from_flat_state` to convert the flat states to nested `State`s."
"* Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s."
]
},
{
Expand Down Expand Up @@ -276,7 +281,7 @@
"KeyPath = tuple[nnx.graph.Key, ...]\n",
"\n",
"def split(node, *filters):\n",
" graphdef, state = nnx.graph.flatten(node)\n",
" graphdef, state, _ = nnx.graph.flatten(node)\n",
" predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
" flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
"\n",
Expand All @@ -293,7 +298,7 @@
" )\n",
" return graphdef, *states\n",
"\n",
"# lets test it...\n",
"# Let's test it.\n",
"foo = Foo()\n",
"\n",
"graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)\n",
Expand All @@ -307,12 +312,14 @@
"id": "7b3aeac8",
"metadata": {},
"source": [
"One very important thing to note is that **filtering is order-dependent**. The first filter that\n",
"matches a value will keep it, therefore you should place more specific filters before more general\n",
"filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar`\n",
"object that contains both types of parameters, if we try to split the `Param`s before the\n",
"`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group\n",
"will be empty because all `SpecialParam`s are also `Param`s:"
"**Note:*** It's very important to know that **filtering is order-dependent**. The first `Filter` that matches a value will keep it, and therefore you should place more specific `Filter`s before more general `Filter`s.\n",
"\n",
"For example, as demonstrated below, if you:\n",
"\n",
"1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and\n",
"2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s\n",
"\n",
"then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s:"
]
},
{
Expand Down Expand Up @@ -360,7 +367,7 @@
"id": "a9f0b7b8",
"metadata": {},
"source": [
"Reversing the order will make sure that the `SpecialParam` are captured first"
"And reversing the order will ensure that the `SpecialParam` are captured first:"
]
},
{
Expand Down
Loading

0 comments on commit a73e710

Please sign in to comment.