From cea5fb018a0b34a6d488d852a68e0059d8940dfc Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 16 Sep 2024 21:12:18 +0000 Subject: [PATCH] Upgrade Flax NNX Filters doc --- docs/nnx/filters_guide.ipynb | 80 ++++++++++++++++++++---------------- docs/nnx/filters_guide.md | 80 ++++++++++++++++++++---------------- 2 files changed, 88 insertions(+), 72 deletions(-) diff --git a/docs/nnx/filters_guide.ipynb b/docs/nnx/filters_guide.ipynb index 21591226ac..bd2d9bbb4c 100644 --- a/docs/nnx/filters_guide.ipynb +++ b/docs/nnx/filters_guide.ipynb @@ -5,12 +5,20 @@ "id": "95b08e64", "metadata": {}, "source": [ - "# Using Filters\n", + "# Using `Filter`s\n", "\n", "> **Attention**: This page relates to the new Flax NNX API.\n", "\n", - "Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n", - "such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:" + "`Filter`s are used extensively in the Flax NNX API as a way to create `nnx.State` groups in APIs\n", + "such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. This guide will help you\n", + "understand:\n", + "\n", + "* What is a Flax NNX `Filter`?\n", + "* Why are types, such as `Param` or `BatchStat`, Filters?\n", + "* How is `State` grouped / `Filter`ed?\n", + "\n", + "Below is an example of using `Filter`s, where `nnx.Param` and `nnx.BatchStat` are used as `Filter`s\n", + "to split the model into two groups - one with parameters, and one with batch statistics:" ] }, { @@ -59,11 +67,7 @@ "id": "8f77e99a", "metadata": {}, "source": [ - "Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:\n", - "\n", - "* What is a Filter?\n", - "* Why are types, such as `Param` or `BatchStat`, Filters?\n", - "* How is `State` grouped / filtered?" + "Let’s dive deeper into `Filter`s." 
] }, { @@ -71,20 +75,23 @@ "id": "a0413d64", "metadata": {}, "source": [ - "## The Filter Protocol\n", + "## The `Filter` protocol\n", "\n", - "In general Filter are predicate functions of the form:\n", + "In general, `Filter`s are predicate functions of the form:\n", "\n", "```python\n", "\n", "(path: tuple[Key, ...], value: Any) -> bool\n", "\n", "```\n", - "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n", "\n", - "Types are obviously not functions of this form, so the reason why they are treated as Filters \n", - "is because, as we will see next, types and some other literals are converted to predicates. For example, \n", - "`Param` is roughly converted to a predicate like this:" + "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the\n", + "path to the value in a nested structure, and `value` is the value at the path. The function\n", + "returns `True` if the value should be included in the group and `False` otherwise.\n", + "\n", + "Types are obviously not functions of this form, so the reason why they are treated as `Filter`s \n", + "is because, as we will see next, types and some other literals are converted to predicates.\n", + "For example, `Param` is roughly converted to a predicate like this:" ] }, { @@ -118,7 +125,7 @@ "metadata": {}, "source": [ "Such function matches any value that is an instance of `Param` or any value that has a \n", - "`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which\n", + "`type` attribute that is a subclass of `Param`. 
Internally, Flax NNX uses `OfType` which\n", "defines a callable of this form for a given type:" ] }, @@ -149,13 +156,13 @@ "id": "87c06e39", "metadata": {}, "source": [ - "## The Filter DSL\n", + "## The `Filter` DSL\n", "\n", - "To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized \n", - "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, \n", - "tuples/lists, etc, and converts them to the appropriate predicate internally.\n", + "To avoid users having to create functions described in the previous section, Flax NNX exposes a\n", + "small DSL, formalized as the `nnx.filterlib.Filter` type, which lets users pass types, booleans,\n", + "ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.\n", "\n", - "Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n", + "Here is a list of all the callable `Filter`s included in Flax NNX and their DSL literals\n", "(when available):\n", "\n", "\n", @@ -163,16 +170,16 @@ "|--------|----------------------|-------------|\n", "| `...` or `True` | `Everything()` | Matches all values |\n", "| `None` or `False` | `Nothing()` | Matches no values |\n", - "| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |\n", + "| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |\n", "| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |\n", "| `'{filter}'` str | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. 
|\n", "| `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` |\n", "| | `All(*filters)` | Matches values that match all of the inner `filters` |\n", "| | `Not(filter)` | Matches values that do not match the inner `filter` |\n", "\n", - "Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n", - "and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n", - "use the following filters:" + "Let’s see the DSL in action using an `nnx.vmap` example. Let’s say you want to vectorize all\n", + "parameters and `'dropout'` `Rng(Keys|Counts)` on the `0`th axis, and broadcast the rest.\n", + "To do this, you use the following `Filter`s:" ] }, { @@ -194,7 +201,7 @@ "id": "bd60f0e1", "metadata": {}, "source": [ - "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n", + "Here, `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n", "expands to `Everything()`.\n", "\n", "If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:" @@ -234,12 +241,13 @@ "id": "db9b4cf3", "metadata": {}, "source": [ - "## Grouping States\n", + "## Grouping `State`s\n", "\n", - "With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:\n", + "With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. 
The following\n", + "are key ideas here:\n", "\n", "* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.\n", - "* Convert all the filters to predicates.\n", + "* Convert all the `Filter`s to predicates.\n", "* Use `State.flat_state` to get the flat representation of the state.\n", "* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n", "* Use `State.from_flat_state` to convert the flat states to nested `State`s." @@ -306,12 +314,12 @@ "id": "7b3aeac8", "metadata": {}, "source": [ - "One very important thing to note is that **filtering is order-dependent**. The first filter that \n", - "matches a value will keep it, therefore you should place more specific filters before more general \n", - "filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` \n", - "object that contains both types of parameters, if we try to split the `Param`s before the \n", - "`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group \n", - "will be empty because all `SpecialParam`s are also `Param`s:" + "**Note:** It’s important to know that _`filter`ing is order-dependent_. The first `Filter` that \n", + "matches a value will keep it, and therefore you should place more specific `Filter`s before more\n", + "general `Filter`s. 
For example, if you create a `SpecialParam` type that is a subclass of `Param`,\n", + "and a `Bar` object that contains both types of parameters, and if you try to split the `Param`s\n", + "before the `SpecialParam`s, then all the values will be placed in the `Param` group and the\n", + "`SpecialParam` group will be empty because all `SpecialParam`s are also `Param`s:" ] }, { @@ -359,7 +367,7 @@ "id": "a9f0b7b8", "metadata": {}, "source": [ - "Reversing the order will make sure that the `SpecialParam` are captured first" + "Reversing the order will make sure that the `SpecialParam`s are captured first:" ] }, { @@ -388,7 +396,7 @@ } ], "source": [ - "graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!\n", + "graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # Correct!\n", "print(f'{params = }')\n", "print(f'{special_params = }')" ] diff --git a/docs/nnx/filters_guide.md b/docs/nnx/filters_guide.md index 84bbe3fa7f..90d53dff12 100644 --- a/docs/nnx/filters_guide.md +++ b/docs/nnx/filters_guide.md @@ -8,12 +8,20 @@ jupytext: jupytext_version: 1.13.8 --- -# Using Filters +# Using `Filter`s > **Attention**: This page relates to the new Flax NNX API. -Filters are used extensively in Flax NNX as a way to create `State` groups in APIs -such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example: +`Filter`s are used extensively in the Flax NNX API as a way to create `nnx.State` groups in APIs +such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. This guide will help you +understand: + +* What is a Flax NNX `Filter`? +* Why are types, such as `Param` or `BatchStat`, Filters? +* How is `State` grouped / `Filter`ed? 
+ +Below is an example of using `Filter`s, where `nnx.Param` and `nnx.BatchStat` are used as `Filter`s +to split the model into two groups - one with parameters, and one with batch statistics: ```{code-cell} ipython3 from flax import nnx @@ -31,28 +39,27 @@ print(f'{params = }') print(f'{batch_stats = }') ``` -Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions: - -* What is a Filter? -* Why are types, such as `Param` or `BatchStat`, Filters? -* How is `State` grouped / filtered? +Let’s dive deeper into `Filter`s. +++ -## The Filter Protocol +## The `Filter` protocol -In general Filter are predicate functions of the form: +In general, `Filter`s are predicate functions of the form: ```python (path: tuple[Key, ...], value: Any) -> bool ``` -where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise. -Types are obviously not functions of this form, so the reason why they are treated as Filters -is because, as we will see next, types and some other literals are converted to predicates. For example, -`Param` is roughly converted to a predicate like this: +where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the +path to the value in a nested structure, and `value` is the value at the path. The function +returns `True` if the value should be included in the group and `False` otherwise. + +Types are obviously not functions of this form, so the reason why they are treated as `Filter`s +is because, as we will see next, types and some other literals are converted to predicates. 
+For example, `Param` is roughly converted to a predicate like this: ```{code-cell} ipython3 def is_param(path, value) -> bool: @@ -65,7 +72,7 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` Such function matches any value that is an instance of `Param` or any value that has a -`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which +`type` attribute that is a subclass of `Param`. Internally, Flax NNX uses `OfType` which defines a callable of this form for a given type: ```{code-cell} ipython3 @@ -75,13 +82,13 @@ print(f'{is_param((), nnx.Param(0)) = }') print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` -## The Filter DSL +## The `Filter` DSL -To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized -as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, -tuples/lists, etc, and converts them to the appropriate predicate internally. +To avoid users having to create functions described in the previous section, Flax NNX exposes a +small DSL, formalized as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, +ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally. 
+Here is a list of all the callable `Filter`s included in Flax NNX and their DSL literals (when available): @@ -89,16 +96,16 @@ Here is a list of all the callable Filters included in Flax NNX and their DSL li |--------|----------------------|-------------| | `...` or `True` | `Everything()` | Matches all values | | `None` or `False` | `Nothing()` | Matches no values | -| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` | +| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` | | | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` | | `'{filter}'` str | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. | | `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` | | | `All(*filters)` | Matches values that match all of the inner `filters` | | | `Not(filter)` | Matches values that do not match the inner `filter` | -Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters -and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can -use the following filters: +Let’s see the DSL in action using an `nnx.vmap` example. Let’s say you want to vectorize all +parameters and `'dropout'` `Rng(Keys|Counts)` on the `0`th axis, and broadcast the rest. +To do this, you use the following `Filter`s: ```{code-cell} ipython3 from functools import partial @@ -108,7 +115,7 @@ def forward(model, x): ... 
``` -Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` +Here, `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`. If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`: @@ -125,12 +132,13 @@ print(f'{nothing = }') print(f'{params_or_dropout = }') ``` -## Grouping States +## Grouping `State`s -With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas: +With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. The following +are key ideas here: * Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node. -* Convert all the filters to predicates. +* Convert all the `Filter`s to predicates. * Use `State.flat_state` to get the flat representation of the state. * Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates. * Use `State.from_flat_state` to convert the flat states to nested `State`s. @@ -166,12 +174,12 @@ print(f'{params = }') print(f'{batch_stats = }') ``` -One very important thing to note is that **filtering is order-dependent**. The first filter that -matches a value will keep it, therefore you should place more specific filters before more general -filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` -object that contains both types of parameters, if we try to split the `Param`s before the -`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group -will be empty because all `SpecialParam`s are also `Param`s: +**Note:** It’s important to know that _`filter`ing is order-dependent_. The first `Filter` that +matches a value will keep it, and therefore you should place more specific `Filter`s before more +general `Filter`s. 
For example, if you create a `SpecialParam` type that is a subclass of `Param`, +and a `Bar` object that contains both types of parameters, and if you try to split the `Param`s +before the `SpecialParam`s, then all the values will be placed in the `Param` group and the +`SpecialParam` group will be empty because all `SpecialParam`s are also `Param`s: ```{code-cell} ipython3 class SpecialParam(nnx.Param): @@ -189,10 +197,10 @@ print(f'{params = }') print(f'{special_params = }') ``` -Reversing the order will make sure that the `SpecialParam` are captured first +Reversing the order will make sure that the `SpecialParam`s are captured first: ```{code-cell} ipython3 -graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # Correct! +graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # Correct! print(f'{params = }') print(f'{special_params = }') ```
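The order-dependence described in the guide can be reproduced without Flax. The block below is a minimal sketch in plain Python, not the library's implementation: `Param`, `SpecialParam`, `of_type`, and `group_first_match` are hypothetical stand-ins written for this illustration, and unlike a real split, values that match no predicate are silently dropped.

```python
# Hypothetical stand-ins for illustration only; these are not Flax classes.
class Param:
    def __init__(self, value):
        self.value = value


class SpecialParam(Param):
    pass


def of_type(cls):
    """Build a (path, value) -> bool predicate, loosely mirroring `OfType`."""
    def predicate(path, value):
        return isinstance(value, cls)
    return predicate


def group_first_match(flat_state, *types):
    """Assign each (path, value) pair to the first predicate that accepts it."""
    predicates = [of_type(t) for t in types]
    groups = [{} for _ in predicates]
    for path, value in flat_state.items():
        for group, predicate in zip(groups, predicates):
            if predicate(path, value):
                group[path] = value
                break  # first match wins; later predicates never see this value
    return groups


flat_state = {('a',): Param(0), ('b',): SpecialParam(1)}

# General type first: it captures both entries, so the second group stays empty.
params, special_params = group_first_match(flat_state, Param, SpecialParam)
print(sorted(params), sorted(special_params))

# Specific type first: each entry lands in the intended group.
special_params, params = group_first_match(flat_state, SpecialParam, Param)
print(sorted(special_params), sorted(params))
```

The `break` is what makes the grouping order-dependent: once a predicate claims a value, no later predicate is consulted, which is why the more specific type has to come first.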