From e5586f465412d2ecba81eeb66f75bc53228cb932 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Sat, 5 Oct 2024 00:37:41 +0000 Subject: [PATCH] Upgrade NNX Filters guide --- docs_nnx/guides/filters_guide.ipynb | 97 ++++++++++++++++------------- docs_nnx/guides/filters_guide.md | 97 ++++++++++++++++------------- docs_nnx/nnx_basics.ipynb | 2 +- docs_nnx/nnx_basics.md | 2 +- 4 files changed, 106 insertions(+), 92 deletions(-) diff --git a/docs_nnx/guides/filters_guide.ipynb b/docs_nnx/guides/filters_guide.ipynb index ed37ad8731..57b7cf58ed 100644 --- a/docs_nnx/guides/filters_guide.ipynb +++ b/docs_nnx/guides/filters_guide.ipynb @@ -5,12 +5,18 @@ "id": "95b08e64", "metadata": {}, "source": [ - "# Using Filters\n", + "# Using `Filter`s\n", "\n", - "> **Attention**: This page relates to the new Flax NNX API.\n", + "Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n", "\n", - "Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n", - "such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. 
For example:" + "In this guide you will learn:\n", + "\n", + "* What is a `Filter`?\n", + "* Why are types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), treated as `Filter`s?\n", + "* What is the `Filter` domain-specific language (DSL)?\n", + "* How is [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) grouped / filtered?\n", + "\n", + "In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:" ] }, { @@ -59,11 +65,7 @@ "id": "8f77e99a", "metadata": {}, "source": [ - "Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:\n", - "\n", - "* What is a Filter?\n", - "* Why are types, such as `Param` or `BatchStat`, Filters?\n", - "* How is `State` grouped / filtered?" + "Let's dive deeper into `Filter`s." ] }, { @@ -71,20 +73,25 @@ "id": "a0413d64", "metadata": {}, "source": [ - "## The Filter Protocol\n", + "## The `Filter` Protocol\n", "\n", - "In general Filter are predicate functions of the form:\n", + "In general, Flax `Filter`s are predicate functions of the form:\n", "\n", "```python\n", "\n", "(path: tuple[Key, ...], value: Any) -> bool\n", "\n", "```\n", - "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. 
The function returns `True` if the value should be included in the group and `False` otherwise.\n", "\n", - "Types are obviously not functions of this form, so the reason why they are treated as Filters\n", - "is because, as we will see next, types and some other literals are converted to predicates. For example,\n", - "`Param` is roughly converted to a predicate like this:" + "where:\n", + "\n", + "- `Key` is a hashable and comparable type;\n", + "- `path` is a tuple of `Key`s representing the path to the value in a nested structure; and\n", + "- `value` is the value at the path.\n", + "\n", + "The function returns `True` if the value should be included in the group, and `False` otherwise.\n", + "\n", + "Types are not functions of this form. They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this:" ] }, { @@ -117,9 +124,7 @@ "id": "a8a2641e", "metadata": {}, "source": [ - "Such function matches any value that is an instance of `Param` or any value that has a\n", - "`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which\n", - "defines a callable of this form for a given type:" + "Such a function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). 
Internally, Flax NNX uses `OfType`, which defines a callable of this form for a given type:" ] }, { @@ -149,14 +154,11 @@ "id": "87c06e39", "metadata": {}, "source": [ - "## The Filter DSL\n", + "## The `Filter` DSL\n", "\n", - "To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized\n", - "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis,\n", - "tuples/lists, etc, and converts them to the appropriate predicate internally.\n", + "To spare users from writing the predicate functions described in the previous section, Flax NNX exposes a small domain-specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. The `Filter` DSL allows users to pass types, booleans, ellipsis, tuples/lists, etc., and converts them to the appropriate predicate internally.\n", "\n", - "Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n", - "(when available):\n", + "Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):\n", "\n", "\n", "| Literal | Callable | Description |\n", @@ -170,10 +172,14 @@ "| | `All(*filters)` | Matches values that match all of the inner `filters` |\n", "| | `Not(filter)` | Matches values that do not match the inner `filter` |\n", "\n", - "Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n", - "and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. 
To do so we can\n", - "use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`\n", - "to specify how `model`'s various substates should be vectorized:" + "\n", + "Let's see the DSL in action, using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Suppose you want to:\n", + "\n", + "1) Vectorize all parameters on the `0`th axis;\n", + "2) Vectorize the `'dropout'` `Rng(Keys|Counts)` on the `0`th axis as well; and\n", + "3) Broadcast the rest.\n", + "\n", + "To do this, you can use the following `Filter`s to define an `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes` to specify how the `model`'s various sub-states should be vectorized:" ] }, { @@ -195,10 +201,9 @@ "id": "bd60f0e1", "metadata": {}, "source": [ - "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n", - "expands to `Everything()`.\n", + "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`.\n", "\n", - "If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:" + "If you wish to manually convert a literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate):" ] }, { @@ -235,15 +240,15 @@ "id": "db9b4cf3", "metadata": {}, "source": [ - "## Grouping States\n", + "## Grouping `State`s\n", "\n", - "With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:\n", + "With the knowledge of `Filter`s from the previous sections at hand, let's see roughly how [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) is implemented. 
Here are the key ideas:\n", "\n", - "* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.\n", - "* Convert all the filters to predicates.\n", + "* Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node.\n", + "* Convert all the `Filter`s to predicates.\n", "* Use `State.flat_state` to get the flat representation of the state.\n", "* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n", - "* Use `State.from_flat_state` to convert the flat states to nested `State`s." + "* Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s." ] }, { @@ -276,7 +281,7 @@ "KeyPath = tuple[nnx.graph.Key, ...]\n", "\n", "def split(node, *filters):\n", - " graphdef, state = nnx.graph.flatten(node)\n", + " graphdef, state, _ = nnx.graph.flatten(node)\n", " predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n", " flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n", "\n", @@ -293,7 +298,7 @@ " )\n", " return graphdef, *states\n", "\n", - "# lets test it...\n", + "# Let's test it.\n", "foo = Foo()\n", "\n", "graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)\n", @@ -307,12 +312,14 @@ "id": "7b3aeac8", "metadata": {}, "source": [ - "One very important thing to note is that **filtering is order-dependent**. The first filter that\n", - "matches a value will keep it, therefore you should place more specific filters before more general\n", - "filters. 
For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar`\n", - "object that contains both types of parameters, if we try to split the `Param`s before the\n", - "`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group\n", - "will be empty because all `SpecialParam`s are also `Param`s:" + "**Note:** It's very important to know that **filtering is order-dependent**. The first `Filter` that matches a value will keep it, so you should place more specific `Filter`s before more general ones.\n", + "\n", + "For example, as demonstrated below, if you:\n", + "\n", + "1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and\n", + "2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s\n", + "\n", + "then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s:" ] }, { @@ -360,7 +367,7 @@ "id": "a9f0b7b8", "metadata": {}, "source": [ - "Reversing the order will make sure that the `SpecialParam` are captured first" + "Reversing the order ensures that the `SpecialParam`s are captured first:" ] }, { diff --git a/docs_nnx/guides/filters_guide.md b/docs_nnx/guides/filters_guide.md index 97ff439ce2..31e3b96c3f 100644 --- a/docs_nnx/guides/filters_guide.md +++ b/docs_nnx/guides/filters_guide.md @@ -8,12 +8,18 @@ jupytext: jupytext_version: 1.13.8 
--- -# Using Filters +# Using `Filter`s -> **Attention**: This page relates to the new Flax NNX API. +Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html). -Filters are used extensively in Flax NNX as a way to create `State` groups in APIs -such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example: +In this guide you will learn: + +* What is a `Filter`? +* Why are types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), treated as `Filter`s? +* What is the `Filter` domain-specific language (DSL)? +* How is [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) grouped / filtered? 
+ +In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics: ```{code-cell} ipython3 from flax import nnx @@ -31,28 +37,29 @@ print(f'{params = }') print(f'{batch_stats = }') ``` -Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions: - -* What is a Filter? -* Why are types, such as `Param` or `BatchStat`, Filters? -* How is `State` grouped / filtered? +Let's dive deeper into `Filter`s. +++ -## The Filter Protocol +## The `Filter` Protocol -In general Filter are predicate functions of the form: +In general, Flax `Filter`s are predicate functions of the form: ```python (path: tuple[Key, ...], value: Any) -> bool ``` -where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise. -Types are obviously not functions of this form, so the reason why they are treated as Filters -is because, as we will see next, types and some other literals are converted to predicates. For example, -`Param` is roughly converted to a predicate like this: +where: + +- `Key` is a hashable and comparable type; +- `path` is a tuple of `Key`s representing the path to the value in a nested structure; and +- `value` is the value at the path. + +The function returns `True` if the value should be included in the group, and `False` otherwise. + +Types are not functions of this form. 
They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this: ```{code-cell} ipython3 def is_param(path, value) -> bool: @@ -64,9 +71,7 @@ print(f'{is_param((), nnx.Param(0)) = }') print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` -Such function matches any value that is an instance of `Param` or any value that has a -`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which -defines a callable of this form for a given type: +Such a function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally, Flax NNX uses `OfType`, which defines a callable of this form for a given type: ```{code-cell} ipython3 is_param = nnx.OfType(nnx.Param) @@ -75,14 +80,11 @@ print(f'{is_param((), nnx.Param(0)) = }') print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` -## The Filter DSL +## The `Filter` DSL -To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized -as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, -tuples/lists, etc, and converts them to the appropriate predicate internally. +To spare users from writing the predicate functions described in the previous section, Flax NNX exposes a small domain-specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. 
The `Filter` DSL allows users to pass types, booleans, ellipsis, tuples/lists, etc., and converts them to the appropriate predicate internally. -Here is a list of all the callable Filters included in Flax NNX and their DSL literals -(when available): +Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available): | Literal | Callable | Description | @@ -96,10 +98,14 @@ Here is a list of all the callable Filters included in Flax NNX and their DSL li | | `All(*filters)` | Matches values that match all of the inner `filters` | | | `Not(filter)` | Matches values that do not match the inner `filter` | -Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters -and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can -use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes` -to specify how `model`'s various substates should be vectorized: + +Let's see the DSL in action, using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Suppose you want to: + +1) Vectorize all parameters on the `0`th axis; +2) Vectorize the `'dropout'` `Rng(Keys|Counts)` on the `0`th axis as well; and +3) Broadcast the rest. + +To do this, you can use the following `Filter`s to define an `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes` to specify how the `model`'s various sub-states should be vectorized: ```{code-cell} ipython3 state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None}) @@ -109,10 +115,9 @@ def forward(model, x): ... ``` -Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` -expands to `Everything()`. +Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`. 
-If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`: +If you wish to manually convert a literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate): ```{code-cell} ipython3 is_param = nnx.filterlib.to_predicate(nnx.Param) @@ -126,22 +131,22 @@ print(f'{nothing = }') print(f'{params_or_dropout = }') ``` -## Grouping States +## Grouping `State`s -With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas: +With the knowledge of `Filter`s from the previous sections at hand, let's see roughly how [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) is implemented. Here are the key ideas: -* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node. -* Convert all the filters to predicates. +* Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node. +* Convert all the `Filter`s to predicates. * Use `State.flat_state` to get the flat representation of the state. * Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates. -* Use `State.from_flat_state` to convert the flat states to nested `State`s. +* Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. ```{code-cell} ipython3 from typing import Any KeyPath = tuple[nnx.graph.Key, ...] 
def split(node, *filters): - graphdef, state = nnx.graph.flatten(node) + graphdef, state, _ = nnx.graph.flatten(node) predicates = [nnx.filterlib.to_predicate(f) for f in filters] flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates] @@ -158,7 +163,7 @@ def split(node, *filters): ) return graphdef, *states -# lets test it... +# Let's test it. foo = Foo() graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat) @@ -167,12 +172,14 @@ print(f'{params = }') print(f'{batch_stats = }') ``` -One very important thing to note is that **filtering is order-dependent**. The first filter that -matches a value will keep it, therefore you should place more specific filters before more general -filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` -object that contains both types of parameters, if we try to split the `Param`s before the -`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group -will be empty because all `SpecialParam`s are also `Param`s: +**Note:** It's very important to know that **filtering is order-dependent**. The first `Filter` that matches a value will keep it, so you should place more specific `Filter`s before more general ones. 
+ +For example, as demonstrated below, if you: + +1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and +2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s + +then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s: ```{code-cell} ipython3 class SpecialParam(nnx.Param): @@ -190,7 +197,7 @@ print(f'{params = }') print(f'{special_params = }') ``` -Reversing the order will make sure that the `SpecialParam` are captured first +Reversing the order ensures that the `SpecialParam`s are captured first: ```{code-cell} ipython3 graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct! 
diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index e326f80585..51e7480c84 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -19,7 +19,7 @@ "- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.\n", " - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).\n", " - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`\n", - " - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.\n", + " - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.\n", "\n", "## Setup\n", "\n", diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index ba422da013..dc1e103ea6 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -23,7 +23,7 @@ In this guide you will learn about: - The 
Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state. - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef). - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update` - - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. + - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. ## Setup