From b3e65f81d578d99a01f52c4f28ffbc341bcb97f6 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 7 Oct 2024 00:28:04 +0000 Subject: [PATCH] Upgrade Flax NNX Bridge guide --- docs_nnx/guides/bridge_guide.ipynb | 680 ++++++++++++++++++++--------- docs_nnx/guides/bridge_guide.md | 522 +++++++++++++--------- 2 files changed, 777 insertions(+), 425 deletions(-) diff --git a/docs_nnx/guides/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb index e41836a93e..481bd19aa8 100644 --- a/docs_nnx/guides/bridge_guide.ipynb +++ b/docs_nnx/guides/bridge_guide.ipynb @@ -4,22 +4,29 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Use Flax NNX and Linen together\n", + "# Use Flax NNX and Linen together via `nnx.bridge`\n", "\n", - "This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.\n", + "This guide is designed to assist existing Flax users who want to mix Flax NNX and Flax Linen `Module`s in their codebase. Bridging NNX and Linen code is made possible with the help of the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) API. This document should enable you to move to and try out Flax NNX at your own pace, and leverage \"the best of both worlds\". This can be particularly helpful if you:\n", "\n", - "This will be helpful if you:\n", + "* Want to migrate your codebase to [Flax NNX](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) from [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) gradually, one `Module` at a time; and/or\n", + "* Have an external dependency that has already moved to Flax NNX while your code hasn't, or one that is still in Flax Linen while you've moved your code to Flax NNX.\n", "\n", - "* Want to migrate your codebase to NNX gradually, one module at a time;\n", - "* Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX.\n", + "You will also learn how to resolve certain caveats of interoperating the two APIs, along with some of the ways in which Flax Linen and Flax NNX are fundamentally different.\n", "\n", - "We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different.\n", + "Table of contents:\n", "\n", - "**Note**:\n", + "- A sub-`Module` is all you need\n", + "- Basics\n", + " - Flax Linen to NNX with `nnx.bridge.lazy_init`\n", + " - Flax NNX to Linen with `nnx.bridge.ToLinen`\n", + "- Handling the JAX PRNG keys\n", + "- Flax NNX variable types vs Flax Linen collections\n", + "- Partition metadata\n", + "- Lifted transformations - go ahead and do it\n", "\n", - "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. 
\n", + "**Note**: Since this guide describes how to glue a [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) with a [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module), if you need to _migrate_ an existing Linen `Module` (a.k.a. `nn.Module`) to an NNX `Module`, check out the [Migrate from Haiku to Flax (Linen and NNX)](https://flax.readthedocs.io/en/latest/guides/haiku_to_flax.html) guide. In addition, all [built-in Flax Linen layers](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html) should have [equivalent Flax NNX versions](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).\n", "\n", - "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." + "First, let's import some necessary dependencies:" ] }, { @@ -44,44 +51,107 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Submodule is all you need\n", + "## A sub-`Module` is all you need\n", "\n", - "A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`). \n", + "A Flax model is a [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree) of `Module`s - either an old [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) (usually written as `nn.Module`) or a new [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module).\n", "\n", - "An `nnx.bridge` wrapper glues the two types together, in both ways:\n", + "The [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrapper API enables you to glue these two types of `Module`s together in two ways using:\n", "\n", - "* `nnx.bridge.ToNNX`: Convert a Linen module to NNX, so that it can be a submodule of another NNX module, or stand alone to be trained in NNX-style training loops.\n", - "* `nnx.bridge.ToLinen`: Vice versa, convert a NNX module to Linen.\n", + "* [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX): Converts a `flax.linen.Module` to NNX, so that it can be a sub-`Module` of another `flax.nnx.Module`, or a standalone `Module` to be trained in NNX style training loops.\n", + "* [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen): The opposite of `nnx.bridge.ToNNX` - it converts a `flax.nnx.Module` to `flax.linen.Module`.\n", "\n", - "This means you can move in either top-down or bottom-up behavior: convert the whole Linen module to NNX, then gradually move down, or convert all the lower level modules to NNX then move up.\n" + "Therefore, you can convert the entire `flax.linen.Module` to Flax NNX, and then gradually “move down” (the “top-down” way), or convert all the lower-level `flax.linen.Module`s to Flax NNX and then “move up” (the “bottom-up” way)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## The Basics\n", + "## Basics\n", "\n", - "There are two fundamental difference between Linen and NNX modules:\n", + "There are two fundamental differences between [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) and [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module):\n", "\n", - "* **Stateless vs. stateful**: Linen module instances are stateless: variables are returned from a purely functional `.init()` call and managed separately. NNX modules, however, owns its variables as instance attributes.\n", + "* **Stateless vs stateful**:\n", + " - Flax Linen `Module` instances are stateless: Variables are returned from a purely functional `Module.init()` call and managed separately.\n", + " - Flax NNX `Module`s, however, own their variables as instance attributes.\n", "\n", - "* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.\n", + "* **Lazy vs eager**:\n", + " - Flax Linen `Module`s only allocate space to create variables when they actually see their input.\n", + " - In comparison, Flax NNX `Module` instances create their [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) the moment they are instantiated without seeing a sample input.\n", "\n", - "With that in mind, let's look at how the `nnx.bridge` wrappers tackle the differences." + "With that in mind, let's investigate how the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrappers tackle these differences.\n", + "\n", + "### Basics: Flax Linen to NNX with `nnx.bridge.lazy_init`\n", + "\n", + "Since `flax.linen.Module`s may require an input to create variables, the Flax team semi-formally supports lazy initialization in the `flax.nnx.Module`s converted from Flax Linen. The Flax Linen variables are created when you give the converted `Module` a sample input. In practice, you call [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) (`nnx.bridge.ToNNX.lazy_init`) where you would call `module.init()` in Flax Linen code.\n", + "\n", + "> **Note:** To inspect all `flax.nnx.Module` variables and state, you can call [`nnx.display`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/visualization.html#flax.nnx.display)." ] }, { "cell_type": "code", "execution_count": null, "id": "a3db6428", "metadata": {}, "outputs": [], "source": [ "class LinenDot(nn.Module):\n", " out_dim: int\n", " w_init: Callable[..., Any] = nn.initializers.lecun_normal()\n", " @nn.compact\n", " def __call__(self, x):\n", " # Flax Linen might need the input shape to create the weight!\n", " w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))\n", " return x @ w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen\n", "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen\n", "y = model(x) # => `y = model.apply(var, x)` in Linen\n", "\n", "nnx.display(model)\n", "\n", "# In-place swap your weight array and the model still works!\n", "model.params['w'].value = jax.random.normal(jax.random.key(1), (32, 64))\n", "assert not jnp.allclose(y, model(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " ToNNX(\n", " module=LinenDot(\n", " # attributes\n", " out_dim = 64\n", " w_init = init\n", " ),\n", " rngs=Rngs(\n", " default=RngStream(\n", " key=RngKey(\n", " value=Array((), dtype=key) overlaying:\n", " [0 0],\n", " tag='default'\n", " ),\n", " count=RngCount(\n", " value=Array(1, dtype=uint32),\n", " tag='default'\n", " )\n", " )\n", " ),\n", " linen_collections=('params',),\n", " params={'w': Param(\n", " value=Array(shape=(32, 64), dtype=float32)\n", " )}\n", " )" ] }, { "cell_type": "markdown", "id": "982639de", "metadata": {}, "source": [ "The [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) method also works even if the top-level `Module` is a pure-NNX one, so you can perform \"sub-moduling\" as you wish:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Linen -> NNX\n", "\n", - "Since Linen modules may require an input to create variables, we semi-formally supported lazy initialization in the NNX modules converted from Linen. 
The Linen variables are created when you give it a sample input.\n", + "x = jax.random.normal(jax.random.key(42), (4, 32))\n", + "model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen\n", + "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen\n", + "y = model(x) # => `y = model.apply(var, x)` in Linen\n", "\n", - "For you, it's calling `nnx.bridge.lazy_init()` where you call `module.init()` in Linen code.\n", + "nnx.display(model)\n", "\n", - "(Note: you can call `nnx.display` upon any NNX module to inspect all its variables and state.)" + "# In-place swap your weight array and the model still works!\n", + "model.params['w'].value = jax.random.normal(jax.random.key(1), (32, 64))\n", + "assert not jnp.allclose(y, model(x))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ToNNX(\n", + " module=LinenDot(\n", + " # attributes\n", + " out_dim = 64\n", + " w_init = init\n", + " ),\n", + " rngs=Rngs(\n", + " default=RngStream(\n", + " key=RngKey(\n", + " value=Array((), dtype=key) overlaying:\n", + " [0 0],\n", + " tag='default'\n", + " ),\n", + " count=RngCount(\n", + " value=Array(1, dtype=uint32),\n", + " tag='default'\n", + " )\n", + " )\n", + " ),\n", + " linen_collections=('params',),\n", + " params={'w': Param(\n", + " value=Array(shape=(32, 64), dtype=float32)\n", + " )}\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "982639de", + "metadata": {}, + "source": [ + "The [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) method also works even if the top-level `Module` is a pure-NNX one, so you can perform \"sub-moduling\" as you wish:" ] }, { @@ -115,33 +185,62 @@ } ], "source": [ - "class LinenDot(nn.Module):\n", - " out_dim: int\n", - " w_init: Callable[..., Any] = nn.initializers.lecun_normal()\n", - " @nn.compact\n", + "class NNXOuter(nnx.Module):\n", + " def __init__(self, out_dim: int, rngs: nnx.Rngs):\n", + " self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)\n", + " self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))\n", + "\n", " def __call__(self, x):\n", - " # Linen might need the input shape to create the weight!\n", - " w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))\n", - " return x @ w\n", + " return self.dot(x) + self.b\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "model = bridge.ToNNX(LinenDot(64), \n", - " rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen\n", - "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen\n", - "y = model(x) # => `y = model.apply(var, x)` in Linen\n", - "\n", - "nnx.display(model)\n", - "\n", - "# In-place swap your weight array and the model still works!\n", - "model.w.value = jax.random.normal(jax.random.key(1), (32, 64))\n", - "assert not jnp.allclose(y, model(x))" + "model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit them into one line too\n", + "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`nnx.bridge.lazy_init` also works even if the top-level module is a pure-NNX one, so you can do sub-moduling as you wish:" + " NNXOuter(\n", + " dot=ToNNX(\n", + " module=LinenDot(\n", + " # attributes\n", + " out_dim = 64\n", + " w_init = init\n", + " ),\n", + " rngs=Rngs(\n", + " default=RngStream(\n", + " key=RngKey(\n", + " value=Array((), dtype=key) overlaying:\n", + " [0 0],\n", + " tag='default'\n", + " ),\n", + " count=RngCount(\n", + " 
value=Array(1, dtype=uint32),\n", " tag='default'\n", " )\n", " )\n", " ),\n", " linen_collections=('params',),\n", " params={'w': Param(\n", " value=Array(shape=(32, 64), dtype=float32)\n", " )}\n", " ),\n", " b=Param(\n", " value=Array(shape=(1, 64), dtype=float32)\n", " )\n", " )" ] }, { "cell_type": "markdown", "id": "a5bc171f", "metadata": {}, "source": [ "The Flax Linen weight is already converted to a typical Flax NNX variable ([`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)), which is a thin wrapper of the actual [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) value within. Here, `w` is an [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) because it belongs to the `params` collection of the `LinenDot` `flax.linen.Module`.\n", "\n", "Different collections and types are covered in more detail in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, you just need to know that they are converted to Flax `nnx.Variable`s like native ones." ] }, { @@ -175,43 +274,15 @@ } ], "source": [ - "class NNXOuter(nnx.Module):\n", - " def __init__(self, out_dim: int, rngs: nnx.Rngs):\n", - " self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)\n", - " self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))\n", - "\n", - " def __call__(self, x):\n", - " return self.dot(x) + self.b\n", - "\n", - "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line\n", - "nnx.display(model)" + "assert isinstance(model.dot.params['w'], nnx.Param)\n", + "assert isinstance(model.dot.params['w'].value, jax.Array)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The Linen weight is already converted to a typical NNX variable, which is a thin wrapper of the actual JAX array value within. Here, `w` is an `nnx.Param` because it belongs to the `params` collection of `LinenDot` module.\n", - "\n", - "We will talk more about different collections and types in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, just know that they are converted to NNX variables like native ones." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(model.dot.w, nnx.Param)\n", - "assert isinstance(model.dot.w.value, jax.Array)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If you create this model witout using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not." + "If you create this model without using [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init), the Flax [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) defined outside will be initialized as usual, but the Flax Linen part (that is wrapped inside [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX)) will not."
] }, { @@ -249,6 +320,39 @@ "nnx.display(partial_model)" ] }, + { + "cell_type": "markdown", + "id": "7a551761", + "metadata": {}, + "source": [ + " NNXOuter(\n", + " dot=ToNNX(\n", + " module=LinenDot(\n", + " # attributes\n", + " out_dim = 64\n", + " w_init = init\n", + " ),\n", + " rngs=Rngs(\n", + " default=RngStream(\n", + " key=RngKey(\n", + " value=Array((), dtype=key) overlaying:\n", + " [0 0],\n", + " tag='default'\n", + " ),\n", + " count=RngCount(\n", + " value=Array(1, dtype=uint32),\n", + " tag='default'\n", + " )\n", + " )\n", + " ),\n", + " linen_collections=()\n", + " ),\n", + " b=Param(\n", + " value=Array(shape=(1, 64), dtype=float32)\n", + " )\n", + " )" + ] + }, { "cell_type": "code", "execution_count": 6, @@ -288,11 +392,49 @@ "cell_type": "markdown", "metadata": {}, "source": [ + " NNXOuter(\n", + " dot=ToNNX(\n", + " module=LinenDot(\n", + " # attributes\n", + " out_dim = 64\n", + " w_init = init\n", + " ),\n", + " rngs=Rngs(\n", + " default=RngStream(\n", + " key=RngKey(\n", + " value=Array((), dtype=key) overlaying:\n", + " [0 0],\n", + " tag='default'\n", + " ),\n", + " count=RngCount(\n", + " value=Array(1, dtype=uint32),\n", + " tag='default'\n", + " )\n", + " )\n", + " ),\n", + " linen_collections=('params',),\n", + " params={'w': Param(\n", + " value=Array(shape=(32, 64), dtype=float32)\n", + " )}\n", + " ),\n", + " b=Param(\n", + " value=Array(shape=(1, 64), dtype=float32)\n", + " )\n", + " )" ] }, { "cell_type": "markdown", "id": "d6c2dffa", "metadata": {}, "source": [ - "### NNX -> Linen\n", + "### Basics: Flax NNX to Linen with `nnx.bridge.ToLinen`\n", "\n", - "To convert an NNX module to Linen, you should forward your creation arguments to `bridge.ToLinen` and let it handle the actual creation process.\n", + "To convert a `flax.nnx.Module` to Flax Linen, you should forward your creation arguments to [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) and let it handle the actual creation process.\n", "\n", - "This is because NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. So `bridge.to_linen` will handle the actual module creation and make sure no memory is allocated twice." + "This is because:\n", + "- The `flax.nnx.Module` instance initializes all its variables eagerly when it is created, which consumes memory and compute.\n", + "- On the other hand, `flax.linen.Module`s are stateless, and the typical `init` and `apply` process involves creating them multiple times. Therefore, [`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) will handle the actual `Module` creation and make sure no memory is allocated twice."
] }, { @@ -313,27 +455,35 @@ "source": [ "class NNXDot(nnx.Module):\n", " def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n", - " self.w = nnx.Param(nnx.initializers.lecun_normal()(\n", - " rngs.params(), (in_dim, out_dim)))\n", + " self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim)))\n", " def __call__(self, x: jax.Array):\n", " return x @ self.w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "# Pass in the arguments, not an actual module\n", - "model = bridge.to_linen(NNXDot, 32, out_dim=64)\n", + "model = bridge.to_linen(NNXDot, 32, out_dim=64) # <- Pass in the arguments, not an actual module\n", "variables = model.init(jax.random.key(0), x)\n", "y = model.apply(variables, x)\n", "\n", "print(list(variables.keys()))\n", - "print(variables['params']['w'].shape) # => (32, 64)\n", - "print(y.shape) # => (4, 64)\n" + "print(variables['params']['w'].value.shape) # => (32, 64)\n", + "print(y.shape) # => (4, 64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note that `ToLinen` modules need to track an extra variable collection - `nnx` - for the static metadata of the underlying NNX module." + " ['nnx', 'params']\n", + " (32, 64)\n", + " (4, 64)" ] }, { "cell_type": "markdown", "id": "de1b26a5", "metadata": {}, "source": [ + "Note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) `Module`s need to track an extra variable collection - `nnx` - for the static metadata of the underlying `nnx.Module`." ] }, { @@ -358,7 +508,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`bridge.to_linen` is actually a convenience wrapper around the Linen module `bridge.ToLinen`. Most likely you won't need to use `ToLinen` directly at all, unless you are using one of the built-in arguments of `ToLinen`. For example, if your NNX module doesn't want to be initialized with RNG handling:" + " " ] }, { "cell_type": "markdown", "id": "c8880236", "metadata": {}, "source": [ + "[`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) is actually a convenience wrapper around the Flax Linen `Module` [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen). Most likely you won't need to use `nnx.bridge.ToLinen` directly at all, unless you are using one of the built-in arguments of `nnx.bridge.ToLinen`. For example, if your `nnx.Module` should not be initialized with PRNG handling:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ " def __call__(self, x):\n", " return x + self.constant\n", "\n", - "# You have to use `skip_rng=True` because this module's `__init__` don't\n", - "# take `rng` as argument\n", + "# You have to use `skip_rng=True` because this module's `__init__` doesn't take `rngs` as an argument.\n", "model = bridge.ToLinen(NNXAddConstant, skip_rng=True)\n", "y, var = model.init_with_output(jax.random.key(0), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module. " + "You may notice that you need an additional `.value` to access this Flax `w` param. 
This is because all Flax NNX variables will be wrapped with an [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) class, which allows them to be annotated with various information, such as their partitioning. On the Flax Linen side, this wrapper is translated into an equivalent `nnx.bridge.NNXMeta` wrapper.\n", "\n", "If you use [Partition metadata in Flax Linen](https://flax-linen.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html), you can learn more about how that works in Flax NNX in the [Partition metadata](#partition-metadata) section below." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ] } ], "source": [ "print(type(variables['params']['w'])) # => nnx.bridge.NNXMeta\n", "print(type(variables['params']['w'].value)) # => jax.Array" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " \n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX), you can use [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) to create a sub-`Module` of another `flax.linen.Module`." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class LinenOuter(nn.Module):\n", " out_dim: int\n", " @nn.compact\n", " def __call__(self, x):\n", " dot = bridge.ToLinen(NNXDot, args=(x.shape[-1], self.out_dim))\n", " b = self.param('b', nn.initializers.lecun_normal(), (1, self.out_dim))\n", " return dot(x) + b\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = LinenOuter(out_dim=64)\n", "y, variables = model.init_with_output(jax.random.key(0), x)\n", "w, b = variables['params']['ToLinen_0']['w'].value, variables['params']['b']\n", "print(w.shape, b.shape, y.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Handling RNG keys\n", - "\n", - "All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys.\n", - "\n", - "Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves." + " (32, 64) (1, 64) (4, 64)" ] }, { "cell_type": "markdown", "id": "3ded9b4d", "metadata": {}, "source": [ - "### Linen to NNX\n", + "## Handling the JAX PRNG keys\n", + "\n", + "All Flax `Module`s - in [Linen](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) or [NNX](https://flax.readthedocs.io/en/latest/guides/randomness.html) - can automatically handle the JAX [pseudorandom number generator (PRNG)](https://jax.readthedocs.io/en/latest/random-numbers.html) keys for variable creation and random layers like dropout. 
However, the specific logic of PRNG key splitting is different, so you cannot generate the same params between Linen and NNX `Module`s, even if you pass in the same keys.\n", "\n", "Another difference is that Flax NNX `Module`s are stateful, so they can track and update the PRNG keys within themselves.\n", "\n", "> **Note:** To refresh your memory of PRNG key handling, review [JAX PRNG 101](https://jax.readthedocs.io/en/latest/random-numbers.html), [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng), [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html), and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html).\n", "\n", "### PRNG keys: Flax Linen to NNX - Enjoy the stateful benefits!\n", "\n", "If you convert a Flax Linen `Module` to NNX, you can enjoy the stateful benefits and don't need to pass in extra PRNG keys on every `nnx.Module` call. And you can always use [`nnx.reseed`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.reseed) to reset the PRNG state within." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The RNG key in state: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "Number of key splits: 0\n", "Number of key splits after y2: 2\n" ] } ], "source": [ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))\n", - "# We don't really need to call lazy_init because no extra params were created here,\n", - "# but it's a good practice to always add this line.\n", - "bridge.lazy_init(model, x)\n", + "bridge.lazy_init(model, x) # You don't really need this because no extra params were created here,\n", + " # but it's a good practice to always add this line.\n", "y1, y2 = model(x), model(x)\n", "assert not jnp.allclose(y1, y2) # Two runs yield different outputs!\n", "\n", - "# Reset the dropout RNG seed, so that next model run will be the same as the first.\n", + "# Reset the dropout PRNG seed, so that the next model run will be the same as the first.\n", "nnx.reseed(model, dropout=0)\n", "assert jnp.allclose(y1, model(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### NNX to Linen\n", + "### PRNG keys: Flax NNX to Linen - Two handling style options\n", "\n", - "If you convert an NNX module to Linen, the underlying NNX module's RNG states will still be part of the top-level `variables`. On the other hand, Linen `apply()` call accepts different RNG keys on each call, which resets the internal Linen environment and allow different random data to be generated.\n", + "If you convert a Flax NNX `Module` to Linen, the underlying `flax.nnx.Module`'s PRNG states will still be part of the top-level `variables`. 
On the other hand, the `flax.linen.Module.apply()` call accepts different PRNG keys on each call, which _resets the internal Flax Linen environment and allows different random data to be generated_.\n", "\n", "Now, it really depends on whether your underlying `nnx.Module` generates new random data from its PRNG state, or from the passed-in argument. Fortunately, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) supports both - using passed-in keys if there are any, and using its own PRNG state if not.\n", "\n", "And this leaves you with two options for handling the PRNG keys:\n", "\n", "* The Flax NNX style (recommended): Let the underlying NNX state manage the PRNG keys, no need to pass in extra keys in `apply()`. This means a few more lines to mutate the `variables` for every apply call, but things will look easier once your whole model no longer needs [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen).\n", "* The Flax Linen style: Just pass different PRNG keys for every `apply()` call.\n", "\n", "> **Note:** Don't forget - there are [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html) and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) tutorials that can help you understand PRNG handling in Flax."
] }, { "cell_type": "code", "execution_count": null, "id": "d175b29a", "metadata": {}, "outputs": [], "source": [ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.to_linen(nnx.Dropout, rate=0.5)\n", "variables = model.init({'dropout': jax.random.key(0)}, x)\n", "\n", "# The Flax NNX PRNG state was stored inside `variables`.\n", "print('The RNG key in state:', variables['RngKey']['rngs']['dropout']['key'].value)\n", "print('Number of key splits:', variables['RngCount']['rngs']['dropout']['count'].value)\n", "\n", "# Flax NNX style: Must set `RngCount` as mutable and update the variables after every `apply`.\n", "y1, updates = model.apply(variables, x, mutable=['RngCount'])\n", "variables |= updates\n", "y2, updates = model.apply(variables, x, mutable=['RngCount'])\n", "variables |= updates\n", "print('Number of key splits after y2:', variables['RngCount']['rngs']['dropout']['count'].value)\n", "assert not jnp.allclose(y1, y2) # Every call yields a different output!\n", "\n", "# Flax Linen style: Just pass different PRNG keys for every `apply()` call.\n", "y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})\n", "y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})\n", "assert not jnp.allclose(y3, y4) # Every call yields a different output!\n", "y5 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})\n", "assert jnp.allclose(y3, y5) # When you use the same top-level PRNG key, the outputs are the same." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## NNX variable types vs. Linen collections\n", - "\n", - "When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types.\n", - "\n", - "Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically. \n", - "\n", - "Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mapping of NNX variable type and Linen collection names using `nnx.register_variable_name_type_pair`." + " The RNG key in state: Array((), dtype=key) overlaying:\n", + " [1428664606 3351135085]\n", + " Number of key splits: 0\n", + " Number of key splits after y2: 2" ] }, { "cell_type": "markdown", "id": "252b478b", "metadata": {}, "source": [ + "## Flax NNX variable types vs Flax Linen collections\n", + "\n", + "When you want to group certain variables into one category, in Flax Linen you use different collections. 
In Flax NNX, because all variables are top-level Python attributes, you use different variable types.\n", "\n", "Therefore, when mixing Flax Linen and NNX `Module`s, Flax must know the 1-to-1 mapping between Flax Linen collections and Flax NNX variable types, so that [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) can do the conversion automatically.\n", "\n", "Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mappings between Flax NNX variable types and Flax Linen collection names using [`nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html).\n", "\n", "### Variables and collections: Flax Linen to NNX\n", "\n", "For any collection of your Flax Linen `Module`, [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) will convert all its endpoint arrays (a.k.a. [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree) [leaves](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#mistaking-pytree-nodes-for-leaves)) to a subtype of [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), either from the registry or automatically created on-the-fly.\n", "\n", "> **Note:** However, you still keep the entire collection(s) as one class attribute, because `flax.linen.Module`s may have duplicate names across different collections." ] }, { "\n", "x = jax.random.normal(jax.random.key(42), (2, 4))\n", "model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)\n", - "print(model.w) # Of type `nnx.Param` - note this is still under attribute `params`\n", - "print(model.b) # Of type `nnx.Param`\n", - "print(model.count) # Of type `counter` - auto-created type from the collection name\n", - "print(type(model.count))\n", + "print(model.params['w']) # Of type `nnx.Param` - note this is still under attribute `params`\n", + "print(model.params['b']) # Of type `nnx.Param`\n", + "print(model.counter['count']) # Of type `counter` - an auto-created dummy type from the name \"counter\"\n", + "print(type(model.counter['count']))\n", "\n", - "y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger\n", - "print(model.dot_sum) # Of type `nnx.Intermediates`" + "y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger\n", + "print(model.intermediates['dot_sum']) # Of type `nnx.Intermediate`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can quickly separate different types of NNX variables apart using `nnx.split`.\n", - "\n", - "This can be handy when you only want to set some variables as trainable."
+ " Param(\n", + " value=Array([[ 0.35401407, 0.38010964, -0.20674096],\n", + " [-0.7356256 , 0.35613298, -0.5099556 ],\n", + " [-0.4783049 , 0.4310735 , 0.30137998],\n", + " [-0.6102254 , -0.2668519 , -1.053598 ]], dtype=float32)\n", + " )\n", + " Param(\n", + " value=Array([0., 0., 0.], dtype=float32)\n", + " )\n", + " counter(\n", + " value=Array(0, dtype=int32)\n", + " )\n", + " \n", + " (Intermediate(\n", + " value=Array(6.932987, dtype=float32)\n", + " ),)" + ] + }, + { + "cell_type": "markdown", + "id": "149f3886", + "metadata": {}, + "source": [ + "You can quickly separate different types of Flax NNX variables apart using [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). This can be handy when you only want to set certain variables as trainable." ] }, { @@ -615,24 +832,33 @@ } ], "source": [ - "# Separate variables of different types with nnx.split\n", - "CountType = type(model.count)\n", + "# Separate variables of different types with `nnx.split`.\n", + "CountType = type(model.counter['count'])\n", "static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)\n", - "print('All Params:', list(params.keys()))\n", - "print('All Counters:', list(counter.keys()))\n", + "print('All Params:', list(params['params'].keys()))\n", + "print('All Counters:', list(counter['counter'].keys()))\n", "print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))\n", "\n", - "model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time\n", - "y = model(x, mutable=True) # still works!" + "model = nnx.merge(static, params, counter, the_rest) # You can merge them (back) anytime." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### NNX to Linen\n", + " All Params: ['b', 'w']\n", + " All Counters: ['count']\n", + " All the rest (intermediates and RNG keys): ['intermediates', 'rngs']" + ] + }, + { + "cell_type": "markdown", + "id": "a694b265", + "metadata": {}, + "source": [ + "### Variables and collections: Flax NNX to Linen\n", "\n", - "If you define custom NNX variable types, you should register their names with `nnx.register_variable_name_type_pair` so that they go to the desired collections." + "If you define custom Flax NNX variable types, you should register their names with [`nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html) so that they go to the desired collections." ] }, { @@ -678,26 +904,36 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Partition metadata\n", - "\n", - "Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.\n", - "\n", - "In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.\n", - "\n", - "The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX)." 
+ " All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts']\n", + " {'w': NNXMeta(var_type=, value=Array([[ 0.2916921 , 0.22780475, 0.06553137],\n", + " [ 0.17487915, -0.34043145, 0.24764155],\n", + " [ 0.6420431 , 0.6220095 , -0.44769976],\n", + " [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32), metadata={'get_value_hooks': (), 'set_value_hooks': (), 'create_value_hooks': (), 'add_axis_hooks': (), 'remove_axis_hooks': ()})}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Linen to NNX\n", + "## Partition metadata\n", + "\n", + "Flax uses a metadata wrapper box over the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) to annotate how a variable should be sharded.\n", + "\n", + "In Flax Linen, this is an optional feature that is triggered by using [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) on initializers. In Flax NNX, since all Flax NNX variables are wrapped by [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) class anyway, that class will hold the sharding annotations too.\n", + "\n", + "> **Note:** If you are new to `jax.Array`s and _data sharding_, go to [Key concepts](https://jax.readthedocs.io/en/latest/key-concepts.html#array-devices-and-sharding) and [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation) on the JAX documentation site.\n", + "\n", + "Both [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) will automatically convert the sharding annotations if you use the built-in annotation methods, such as [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) or [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning).\n", + "\n", + "> **Note:** To get more familiarized with sharding metadata with Flax and JAX, refer to Flax NNX’s [Scale up](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide, JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation), and the Flax Linen [Scale up](https://flax-linen.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) guide.\n", + "\n", + "### Partition metadata: Flax Linen to NNX\n", "\n", - "Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wraps the true JAX array within. \n", + "Even if you are not using any partition metadata in your Flax Linen `Module`, the variable [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) will be converted to [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) that wraps the true `jax.Array` within.\n", "\n", - "If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`. 
\n", + "If you use [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) to annotate your Flax Linen `Module` variables, the annotation will be converted to the `.sharding` field in the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", "\n", - "You can then use `nnx.with_sharding_constraint` to explicitly put the arrays into the annotated partitions within a `jax.jit`-compiled function, to initialize the whole model with every array at the right sharding." + "You can then use [`nnx.with_sharding_constraint`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_sharding_constraint) to explicitly put the arrays into the annotated partitions within a [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html)-compiled function, to initialize the whole model with every array at the right sharding." ] }, { @@ -721,15 +957,13 @@ " out_dim: int\n", " @nn.compact\n", " def __call__(self, x):\n", - " w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), \n", - " ('in', 'out')), \n", + " w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')),\n", " (x.shape[-1], self.out_dim))\n", " return x @ w\n", "\n", "@nnx.jit\n", "def create_sharded_nnx_module(x):\n", - " model = bridge.lazy_init(\n", - " bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)\n", + " model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)\n", " state = nnx.state(model)\n", " sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))\n", " nnx.update(model, sharded_state)\n", @@ -737,28 +971,38 @@ "\n", "\n", "print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n", - "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), \n", - " axis_names=('in', 'out'))\n", + "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), axis_names=('in', 'out'))\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "with mesh:\n", " model = create_sharded_nnx_module(x)\n", "\n", - "print(type(model.w)) # `nnx.Param`\n", - "print(model.w.sharding) # The partition annotation attached with `w`\n", - "print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh" + "print(type(model.params['w'])) # `nnx.Param`\n", + "print(model.params['w'].sharding) # The partition annotation attached with the weight `w`\n", + "print(model.params['w'].value.sharding) # The underlying JAX array is sharded across the 2x4 mesh" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### NNX to Linen\n", + " We have 8 fake JAX devices now to partition this model...\n", + " \n", + " ('in', 'out')\n", + " GSPMDSharding({devices=[2,4]<=[8]})" + ] + }, + { + "cell_type": "markdown", + "id": "5c3ca5f1", + "metadata": {}, + "source": [ + "### Partition metadata: Flax NNX to Linen\n", "\n", - "If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it.\n", + "Since all Flax NNX variables are wrapped with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) box, the converted Flax Linen 
module will have all variables boxed too. There is a default partition metadata class called [`flax.nnx.bridge.NNXMeta`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.NNXMeta) for storing these converted NNX variables.\n", "\n", "[`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) will automatically shard the array with the annotation if it is called within a [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh) context, so you don't need to apply `with_sharding_constraint` yourself.\n", "\n", "As with any Flax Linen metadata wrapper, you can use [`nn.unbox`](https://github.com/google/flax/blob/5d31452889b8d106d7c722b5eaac14cb9784fec2/flax/core/meta.py#L160) to get the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree)." ] }, { " return x @ self.w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "\n", - "@jax.jit\n", - "def create_sharded_variables(key, x):\n", - " model = bridge.to_linen(NNXDotWithParititioning, 32, 64)\n", - " variables = model.init(key, x)\n", - " # A `NNXMeta` wrapper of the underlying `nnx.Param`\n", - " assert type(variables['params']['w']) == bridge.NNXMeta\n", - " # The annotation coming from the `nnx.Param` => (in, out)\n", - " assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n", - " \n", - " unboxed_variables = nn.unbox(variables)\n", - " variable_pspecs = nn.get_partition_spec(variables)\n", - " assert isinstance(unboxed_variables['params']['w'], jax.Array)\n", - " assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')\n", - " \n", - " sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint, \n", - " nn.unbox(variables),\n", - " nn.get_partition_spec(variables))\n", - " return sharded_vars\n", + "model = bridge.to_linen(NNXDotWithParititioning, 32, 64)\n", "\n", "with mesh:\n", - " variables = create_sharded_variables(jax.random.key(0), x)\n", + " variables = jax.jit(model.init)(jax.random.key(0), x)\n", "\n", - "# The underlying JAX array is sharded across the 2x4 mesh\n", - "print(variables['params']['w'].sharding)" + "print(type(variables['params']['w'])) # A `NNXMeta` wrapper of the underlying `nnx.Param`\n", + "print(variables['params']['w'].metadata['sharding']) # The annotation coming from the `nnx.Param`\n", + "print(variables['params']['w'].value.sharding) # The underlying JAX array is sharded across the 2x4 mesh\n", + "\n", + "unboxed = nn.unbox(variables)\n", + "print(type(unboxed['params']['w'])) # The raw jax.Array" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Lifted transforms\n", - "\n", - "In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax. 
\n", - "\n", - "For Linen-style transforms, note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases)" + " \n", + " ('in', 'out')\n", + " GSPMDSharding({devices=[2,4]<=[8]})\n", + " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Linen to NNX\n", + "## Lifted transformations - go ahead and do it\n", + "\n", + "In general, if you want to apply [Flax Linen-](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html) or [Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) on an [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html)-converted `Module`, just go ahead and do it in the usual Flax Linen or NNX syntax.\n", "\n", - "NNX style lifted transforms are similar to JAX transforms, and they work on functions." + "For [Flax Linen style transforms](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html), note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `flax.linen.Module` class in most cases).\n", + "\n", + "### Lifted transforms: Flax Linen to NNX\n", + "\n", + "[Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) are similar to [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations), and they too work on functions." ] }, { @@ -862,7 +1097,7 @@ "x = jax.random.normal(jax.random.key(0), (4, 32))\n", "model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)\n", "\n", - "print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped\n", + "print(model.linen_dot.params['kernel'].shape) # (4, 32, 64) - first axis with dim 4 got `vmap`ped.\n", "y = model(x)\n", "print(y.shape)" ] @@ -871,11 +1106,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### NNX to Linen\n", - "\n", - "Note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases).\n", + " (4, 32, 64)\n", + " (4, 64)" + ] + }, + { + "cell_type": "markdown", + "id": "93419218", + "metadata": {}, + "source": [ + "### Lifted transforms: Flax NNX to Linen\n", "\n", - "Also, since `bridge.ToLinen` introduced this extra `nnx` collection, you need to mark it when using the axis-changing transforms (`linen.vmap`, `linen.scan`, etc) to make sure they are passed inside." + "As mentioned before, [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases). And, since `nnx.bridge.ToLinen` introduces this extra `nnx` collection, you need to mark it when using the axis-changing transforms ([`flax.linen.vmap`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.vmap), [`flax.linen.scan`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.scan), and so on) to make sure they are passed inside." 
] }, { "source": [ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = LinenVmapped(64)\n", "var = model.init(jax.random.key(0), x)\n", - "print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped\n", + "print(var['params']['VmapToLinen_0']['kernel'].value.shape) # (4, 32, 64) - leading dim 4 got `vmap`ped.\n", "y = model.apply(var, x)\n", "print(y.shape)" ] }, + { + "cell_type": "markdown", + "id": "920324e4", + "metadata": {}, + "source": [ + " (4, 32, 64)\n", + " (4, 64)" + ] + } ], "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,md:myst", + "main_language": "python" + }, "language_info": { "codemirror_mode": { "name": "ipython", diff --git a/docs_nnx/guides/bridge_guide.md b/docs_nnx/guides/bridge_guide.md index 3f243ae2ab..27c22932b8 100644 --- a/docs_nnx/guides/bridge_guide.md +++ b/docs_nnx/guides/bridge_guide.md @@ -1,22 +1,40 @@ -# Use Flax NNX and Linen together +--- +jupytext: + cell_metadata_filter: -all + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- -This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API. +# Use Flax NNX and Linen together via `nnx.bridge` -This will be helpful if you: +This guide is designed to assist existing Flax users who want to mix Flax NNX and Flax Linen `Module`s in their codebase. Bridging NNX and Linen code is made possible with the help of the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) API. This document should enable you to move to and try out Flax NNX at your own pace, and leverage "the best of both worlds". This can be particularly helpful if you: -* Want to migrate your codebase to NNX gradually, one module at a time; -* Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX. +* Want to migrate your codebase to [Flax NNX](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) from [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) gradually, one `Module` at a time; and/or +* Have an external dependency that has already moved to Flax NNX while your code hasn't, or one that is still in Flax Linen while you've moved your code to Flax NNX. -We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different. +You will also learn how to resolve certain caveats of interoperating the two APIs, along with some of the ways in which Flax Linen and Flax NNX are fundamentally different. -**Note**: +Table of contents: -This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. 
+- A sub-`Module` is all you need +- Basics + - Flax Linen to NNX with `nnx.bridge.lazy_init` + - Flax NNX to Linen with `nnx.bridge.ToLinen` +- Handling the JAX PRNG keys +- Flax NNX variable types vs Flax Linen collections +- Partition metadata +- Lifted transformations - go ahead and do it -And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). +**Note**: Since this guide describes how to glue a [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) with a [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module), if you need to _migrate_ an existing Linen `Module` (a.k.a. `nn.Module`) to an NNX `Module`, check out the [Migrate from Haiku to Flax (Linen and NNX)](https://flax.readthedocs.io/en/latest/guides/haiku_to_flax.html) guide. In addition, all [built-in Flax Linen layers](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html) should have [equivalent Flax NNX versions](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). +First, let's import some necessary dependencies: -```python +```{code-cell} ipython3 import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' @@ -29,72 +47,91 @@ from jax.experimental import mesh_utils from typing import * ``` -## Submodule is all you need +## A sub-`Module` is all you need -A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`). +A Flax model is a [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree) of `Module`s - either an old [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) (usually written as `nn.Module`) or a new [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). -An `nnx.bridge` wrapper glues the two types together, in both ways: +The [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrapper API enables you to glue these two types of `Module`s together in two ways using: -* `nnx.bridge.ToNNX`: Convert a Linen module to NNX, so that it can be a submodule of another NNX module, or stand alone to be trained in NNX-style training loops. -* `nnx.bridge.ToLinen`: Vice versa, convert a NNX module to Linen. +* [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX): Converts a `flax.linen.Module` to NNX, so that it can be a sub-`Module` of another `flax.nnx.Module`, or a standalone `Module` to be trained in NNX style training loops. +* [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen): The opposite of `nnx.bridge.ToNNX` - it converts a `flax.nnx.Module` to `flax.linen.Module`. -This means you can move in either top-down or bottom-up behavior: convert the whole Linen module to NNX, then gradually move down, or convert all the lower level modules to NNX then move up. +Therefore, you can convert the entire `flax.linen.Module` to Flax NNX, and then gradually “move down” (the “top-down” way), or convert all the lower-level `flax.linen.Module`s to Flax NNX and then “move up” (the “bottom-up” way). 
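+
+To make these two directions concrete, here is a minimal sketch. The class names `MixedNNX` and `MixedLinen` are illustrative only, not part of any API:
+
+```{code-cell} ipython3
+class MixedNNX(nnx.Module):
+  # "Top-down": a Flax NNX model that owns a Flax Linen sub-`Module`.
+  def __init__(self, rngs: nnx.Rngs):
+    self.dense = bridge.ToNNX(nn.Dense(3), rngs=rngs)
+
+  def __call__(self, x):
+    return self.dense(x)
+
+class MixedLinen(nn.Module):
+  # "Bottom-up": a Flax Linen model that owns a Flax NNX sub-`Module`.
+  @nn.compact
+  def __call__(self, x):
+    return bridge.to_linen(nnx.Linear, x.shape[-1], 3)(x)
+```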
++++

-## The Basics
+## Basics

-There are two fundamental difference between Linen and NNX modules:
+There are two fundamental differences between [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) and [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module):

-* **Stateless vs. stateful**: Linen module instances are stateless: variables are returned from a purely functional `.init()` call and managed separately. NNX modules, however, owns its variables as instance attributes.
+* **Stateless vs stateful**:
+  - Flax Linen `Module` instances are stateless: Variables are returned from a purely functional `Module.init()` call and managed separately.
+  - Flax NNX `Module`s, however, own their variables as instance attributes.

-* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.
+* **Lazy vs eager**:
+  - Flax Linen `Module`s only allocate space to create variables when they actually see their input.
+  - In comparison, Flax NNX `Module` instances create their [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) the moment they are instantiated without seeing a sample input.

-With that in mind, let's look at how the `nnx.bridge` wrappers tackle the differences.
+With that in mind, let's investigate how the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrappers tackle these differences.

-### Linen -> NNX
+### Basics: Flax Linen to NNX with `nnx.bridge.lazy_init`

-Since Linen modules may require an input to create variables, we semi-formally supported lazy initialization in the NNX modules converted from Linen. The Linen variables are created when you give it a sample input.
+Since `flax.linen.Module`s may require an input to create variables, the Flax team semi-formally supports lazy initialization in the `flax.nnx.Module`s converted from Flax Linen. The Flax Linen variables are created when you give the `Module` a sample input. Where you would call `module.init()` in Flax Linen code, call [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) (`nnx.bridge.ToNNX.lazy_init`) instead.

-For you, it's calling `nnx.bridge.lazy_init()` where you call `module.init()` in Linen code.
+> **Note:** To inspect all `flax.nnx.Module` variables and state, you can call [`nnx.display`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/visualization.html#flax.nnx.display).

-(Note: you can call `nnx.display` upon any NNX module to inspect all its variables and state.)
-
-
-```python
+```{code-cell} ipython3
 class LinenDot(nn.Module):
   out_dim: int
   w_init: Callable[..., Any] = nn.initializers.lecun_normal()
   @nn.compact
   def __call__(self, x):
-    # Linen might need the input shape to create the weight!
+    # Flax Linen might need the input shape to create the weight!
w = self.param('w', self.w_init, (x.shape[-1], self.out_dim)) return x @ w x = jax.random.normal(jax.random.key(42), (4, 32)) -model = bridge.ToNNX(LinenDot(64), - rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen -bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen -y = model(x) # => `y = model.apply(var, x)` in Linen +model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen +bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen +y = model(x) # => `y = model.apply(var, x)` in Linen nnx.display(model) # In-place swap your weight array and the model still works! -model.w.value = jax.random.normal(jax.random.key(1), (32, 64)) +model.params['w'].value = jax.random.normal(jax.random.key(1), (32, 64)) assert not jnp.allclose(y, model(x)) ``` + ToNNX( + module=LinenDot( + # attributes + out_dim = 64 + w_init = init + ), + rngs=Rngs( + default=RngStream( + key=RngKey( + value=Array((), dtype=key) overlaying: + [0 0], + tag='default' + ), + count=RngCount( + value=Array(1, dtype=uint32), + tag='default' + ) + ) + ), + linen_collections=('params',), + params={'w': Param( + value=Array(shape=(32, 64), dtype=float32) + )} + ) -
- - - -
- - -`nnx.bridge.lazy_init` also works even if the top-level module is a pure-NNX one, so you can do sub-moduling as you wish: ++++ +The [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) method also works even if the top-level `Module` is a pure-NNX one, so you can perform "sub-moduling" as you wish: -```python +```{code-cell} ipython3 class NNXOuter(nnx.Module): def __init__(self, out_dim: int, rngs: nnx.Rngs): self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs) @@ -104,121 +141,195 @@ class NNXOuter(nnx.Module): return self.dot(x) + self.b x = jax.random.normal(jax.random.key(42), (4, 32)) -model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line +model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit them into one line too nnx.display(model) ``` + NNXOuter( + dot=ToNNX( + module=LinenDot( + # attributes + out_dim = 64 + w_init = init + ), + rngs=Rngs( + default=RngStream( + key=RngKey( + value=Array((), dtype=key) overlaying: + [0 0], + tag='default' + ), + count=RngCount( + value=Array(1, dtype=uint32), + tag='default' + ) + ) + ), + linen_collections=('params',), + params={'w': Param( + value=Array(shape=(32, 64), dtype=float32) + )} + ), + b=Param( + value=Array(shape=(1, 64), dtype=float32) + ) + ) -
- - - -
- - -The Linen weight is already converted to a typical NNX variable, which is a thin wrapper of the actual JAX array value within. Here, `w` is an `nnx.Param` because it belongs to the `params` collection of `LinenDot` module. ++++ -We will talk more about different collections and types in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, just know that they are converted to NNX variables like native ones. +The Flax Linen weight is already converted to a typical Flax NNX variable ([`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)), which is a thin wrapper of the actual [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) value within. Here, `w` is an [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) because it belongs to the `params` collection of `LinenDot` `flax.linen.Module`. +Different collections and types are covered in more detail in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, you just need to know that they are converted to Flax `nnx.Variable`s like native ones. -```python -assert isinstance(model.dot.w, nnx.Param) -assert isinstance(model.dot.w.value, jax.Array) +```{code-cell} ipython3 +assert isinstance(model.dot.params['w'], nnx.Param) +assert isinstance(model.dot.params['w'].value, jax.Array) ``` -If you create this model witout using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not. - +If you create this model without using [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init), the Flax [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) defined outside will be initialized as usual, but the Flax Linen part (that is wrapped inside of [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX)) will not. -```python +```{code-cell} ipython3 partial_model = NNXOuter(64, rngs=nnx.Rngs(0)) nnx.display(partial_model) ``` + NNXOuter( + dot=ToNNX( + module=LinenDot( + # attributes + out_dim = 64 + w_init = init + ), + rngs=Rngs( + default=RngStream( + key=RngKey( + value=Array((), dtype=key) overlaying: + [0 0], + tag='default' + ), + count=RngCount( + value=Array(1, dtype=uint32), + tag='default' + ) + ) + ), + linen_collections=() + ), + b=Param( + value=Array(shape=(1, 64), dtype=float32) + ) + ) -
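+
+As a quick check, the `linen_collections` attribute shown in the `nnx.display` output above stays empty until `lazy_init` actually runs - that is, the Flax Linen variables don't exist yet:
+
+```{code-cell} ipython3
+print(partial_model.dot.linen_collections)  # => () - no Linen variables have been created yet
+```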
- - - -
- - - -```python +```{code-cell} ipython3 full_model = bridge.lazy_init(partial_model, x) nnx.display(full_model) ``` + NNXOuter( + dot=ToNNX( + module=LinenDot( + # attributes + out_dim = 64 + w_init = init + ), + rngs=Rngs( + default=RngStream( + key=RngKey( + value=Array((), dtype=key) overlaying: + [0 0], + tag='default' + ), + count=RngCount( + value=Array(1, dtype=uint32), + tag='default' + ) + ) + ), + linen_collections=('params',), + params={'w': Param( + value=Array(shape=(32, 64), dtype=float32) + )} + ), + b=Param( + value=Array(shape=(1, 64), dtype=float32) + ) + ) -
- - - -
-
-
-
-### NNX -> Linen
++++

-To convert an NNX module to Linen, you should forward your creation arguments to `bridge.ToLinen` and let it handle the actual creation process.
+### Basics: Flax NNX to Linen with `nnx.bridge.ToLinen`

-This is because NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. So `bridge.to_linen` will handle the actual module creation and make sure no memory is allocated twice.
+To convert a `flax.nnx.Module` to Flax Linen, you should forward your creation arguments to [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) and let it handle the actual creation process.

+This is because:
+- The `flax.nnx.Module` instance initializes all its variables eagerly when it is created, which consumes memory and compute.
+- On the other hand, `flax.linen.Module`s are stateless, and the typical `init` and `apply` process involves multiple creation of them. Therefore, [`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) will handle the actual `Module` creation and make sure no memory is allocated twice.

-```python
+```{code-cell} ipython3
 class NNXDot(nnx.Module):
   def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
-    self.w = nnx.Param(nnx.initializers.lecun_normal()(
-      rngs.params(), (in_dim, out_dim)))
+    self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim)))
  def __call__(self, x: jax.Array):
    return x @ self.w

 x = jax.random.normal(jax.random.key(42), (4, 32))
-# Pass in the arguments, not an actual module
-model = bridge.to_linen(NNXDot, 32, out_dim=64)
+model = bridge.to_linen(NNXDot, 32, out_dim=64)  # <- Pass in the arguments, not an actual module
 variables = model.init(jax.random.key(0), x)
 y = model.apply(variables, x)

 print(list(variables.keys()))
-print(variables['params']['w'].shape)  # => (32, 64)
-print(y.shape)  # => (4, 64)
-
+print(variables['params']['w'].value.shape)  # => (32, 64)
+print(y.shape)  # => (4, 64)
 ```

     ['nnx', 'params']
     (32, 64)
     (4, 64)

++++

-Note that `ToLinen` modules need to track an extra variable collection - `nnx` - for the static metadata of the underlying NNX module.
+Note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) `Module`s need to track an extra variable collection - `nnx` - for the static metadata of the underlying `nnx.Module`.

-
-```python
+```{code-cell} ipython3
 # This new field stores the static data that defines the underlying `NNXDot`
 print(type(variables['nnx']['graphdef']))  # => `nnx.graph.NodeDef`
 ```

++++

-`bridge.to_linen` is actually a convenience wrapper around the Linen module `bridge.ToLinen`. Most likely you won't need to use `ToLinen` directly at all, unless you are using one of the built-in arguments of `ToLinen`. For example, if your NNX module doesn't want to be initialized with RNG handling:
-
+[`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) is actually a convenience wrapper around the Flax Linen `Module` [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen). 
Most likely you won't need to use `nnx.bridge.ToLinen` directly at all, unless you are using one of the built-in arguments of `nnx.bridge.ToLinen`. For example, if your `nnx.Module`'s initialization doesn't need any PRNG handling:

-```python
+```{code-cell} ipython3
 class NNXAddConstant(nnx.Module):
   def __init__(self):
     self.constant = nnx.Variable(jnp.array(1))
   def __call__(self, x):
     return x + self.constant

-# You have to use `skip_rng=True` because this module's `__init__` don't
-# take `rng` as argument
+# You have to use `skip_rng=True` because your module's `__init__` doesn't take `rng` as an argument.
 model = bridge.ToLinen(NNXAddConstant, skip_rng=True)
 y, var = model.init_with_output(jax.random.key(0), x)
 ```

-Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module.
+You may notice that you need an additional `.value` to access this Flax `w` param. This is because all Flax NNX variables will be wrapped with an [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) class, which allows them to be annotated with various information, such as partitioning. This was translated into an equivalent `nnx.bridge.NNXMeta` wrapper.

+If you use [partition metadata in Flax Linen](https://flax-linen.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html), you can learn more about how that works in Flax NNX in the [Partition metadata](#partition-metadata) section below.

-```python
+```{code-cell} ipython3
+print(type(variables['params']['w']))  # => nnx.bridge.NNXMeta
+print(type(variables['params']['w'].value))  # => jax.Array
+```
+
+    <class 'flax.nnx.bridge.variables.NNXMeta'>
+    <class 'jaxlib.xla_extension.ArrayImpl'>
+
++++
+
+Similar to [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX), you can use [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) to create a sub-`Module` of another `flax.linen.Module`.
+
+```{code-cell} ipython3
 class LinenOuter(nn.Module):
   out_dim: int
   @nn.compact
@@ -230,72 +341,73 @@
 x = jax.random.normal(jax.random.key(42), (4, 32))
 model = LinenOuter(out_dim=64)
 y, variables = model.init_with_output(jax.random.key(0), x)
-w, b = variables['params']['ToLinen_0']['w'], variables['params']['b']
+w, b = variables['params']['ToLinen_0']['w'].value, variables['params']['b']
 print(w.shape, b.shape, y.shape)
 ```

     (32, 64) (1, 64) (4, 64)

++++

-## Handling RNG keys
+## Handling the JAX PRNG keys

-All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys.
+All Flax `Module`s - in [Linen](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) or [NNX](https://flax.readthedocs.io/en/latest/guides/randomness.html) - can automatically handle the JAX [pseudorandom number generator (PRNG)](https://jax.readthedocs.io/en/latest/random-numbers.html) keys for variable creation and random layers like dropouts. However, the specific logics of PRNG key splitting are different, so you cannot generate the same params between Linen and NNX `Module`s, even if you pass in the same keys.

 Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves.
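+
+A minimal sanity check of the first point, assuming the default `lecun_normal` initializers - both layers below use seed 0, but the internal key-splitting paths differ, so the kernels are expected to differ:
+
+```{code-cell} ipython3
+linen_kernel = nn.Dense(4).init(jax.random.key(0), jnp.ones((1, 3)))['params']['kernel']
+nnx_kernel = nnx.Linear(3, 4, rngs=nnx.Rngs(params=0)).kernel.value
+print(linen_kernel.shape == nnx_kernel.shape)        # True - both are (3, 4)
+print(bool(jnp.allclose(linen_kernel, nnx_kernel)))  # Expected: False - different key splitting
+```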
-
-### Linen to NNX
+> **Note:** To refresh your memory of PRNG key handling, review [JAX PRNG 101](https://jax.readthedocs.io/en/latest/random-numbers.html), [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng), [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html), and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html).

-If you convert a Linen module to NNX, you enjoy the stateful benefit and don't need to pass in extra RNG keys on every module call. You can use always `nnx.reseed` to reset the RNG state within.
+### PRNG keys: Flax Linen to NNX - Enjoy the stateful benefits!

+If you convert a Flax Linen `Module` to NNX, you can enjoy the stateful benefits and don't need to pass in extra PRNG keys on every `nnx.Module` call. And you can always use [`nnx.reseed`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.reseed) to reset the PRNG state within.

-```python
+```{code-cell} ipython3
 x = jax.random.normal(jax.random.key(42), (4, 32))
 model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))

-# We don't really need to call lazy_init because no extra params were created here,
-# but it's a good practice to always add this line.
-bridge.lazy_init(model, x)
+bridge.lazy_init(model, x)  # You don't really need this because no extra params were created here,
+                            # but it's a good practice to always add this line.
 y1, y2 = model(x), model(x)
 assert not jnp.allclose(y1, y2)  # Two runs yield different outputs!

-# Reset the dropout RNG seed, so that next model run will be the same as the first.
+# Reset the dropout PRNG seed, so that the next model run will be the same as the first.
 nnx.reseed(model, dropout=0)
 assert jnp.allclose(y1, model(x))
 ```

-### NNX to Linen
-
-If you convert an NNX module to Linen, the underlying NNX module's RNG states will still be part of the top-level `variables`. On the other hand, Linen `apply()` call accepts different RNG keys on each call, which resets the internal Linen environment and allow different random data to be generated.
+### PRNG keys: Flax NNX to Linen - Two handling style options

-Now, it really depends on whether your underlying NNX module generates new random data from its RNG state, or from the passed-in argument. Fortunately, `nnx.Dropout` supports both - using passed-in keys if there is any, and use its own RNG state if not.
+If you convert a Flax NNX `Module` to Linen, the underlying `flax.nnx.Module`'s PRNG states will still be part of the top-level variables. On the other hand, the `flax.linen.Module.apply()` call accepts different PRNG keys on each call, which _resets the internal Flax Linen environment and allows different random data to be generated_.

-And this leaves you with two style options of handling the RNG keys:
+Now, it really depends on whether your underlying NNX module generates new random data from its PRNG state, or from the passed-in argument. Fortunately, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) supports both - using passed-in keys if there are any, and using its own PRNG state if not.

+And this leaves you with two style options of handling the PRNG keys:

-* The NNX style (recommended): Let the underlying NNX state manage the RNG keys, no need to pass in extra keys in `apply()`. 
This means a few more lines to mutate the `variables` for every apply call, but things will look easier once your whole model no longer needs `ToLinen`.

-* The Linen style: Just pass different RNG keys for every `apply()` call.

+* The Flax NNX style (recommended): Let the underlying NNX state manage the PRNG keys, no need to pass in extra keys in `apply()`. This means a few more lines to mutate the `variables` for every apply call, but things will look easier once your whole model no longer needs [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen).
+* The Flax Linen style: Just pass different PRNG keys for every `apply()` call.

+> **Note:** Don't forget - the [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html) and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) tutorials can help you understand PRNG handling in Flax.

-```python
+```{code-cell} ipython3
 x = jax.random.normal(jax.random.key(42), (4, 32))
 model = bridge.to_linen(nnx.Dropout, rate=0.5)
 variables = model.init({'dropout': jax.random.key(0)}, x)

-# The NNX RNG state was stored inside `variables`
+# The Flax NNX PRNG state was stored inside `variables`.
 print('The RNG key in state:', variables['RngKey']['rngs']['dropout']['key'].value)
 print('Number of key splits:', variables['RngCount']['rngs']['dropout']['count'].value)

-# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
+# Flax NNX style: Must set `RngCount` as mutable and update the variables after every `apply`.
 y1, updates = model.apply(variables, x, mutable=['RngCount'])
 variables |= updates
 y2, updates = model.apply(variables, x, mutable=['RngCount'])
 variables |= updates
 print('Number of key splits after y2:', variables['RngCount']['rngs']['dropout']['count'].value)
-assert not jnp.allclose(y1, y2)  # Every call yields different output!
+assert not jnp.allclose(y1, y2)  # Every call yields a different output!

-# Linen style: Just pass different RNG keys for every `apply()` call.
+# Flax Linen style: Just pass different PRNG keys for every `apply()` call.
 y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
 y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
-assert not jnp.allclose(y3, y4)  # Every call yields different output!
+assert not jnp.allclose(y3, y4)  # Every call yields a different output!
 y5 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
 assert jnp.allclose(y3, y5)  # When you use same top-level RNG, outputs are same
 ```

@@ -305,23 +417,23 @@
     Number of key splits: 0
    Number of key splits after y2: 2

++++

-## NNX variable types vs. Linen collections
-
-When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types.
+## Flax NNX variable types vs Flax Linen collections

+When you want to group certain variables in one category, in Flax Linen you use different collections. 
In Flax NNX, because all variables shall be top-level Python attributes, you use different variable types.

-Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically.
+Therefore, when mixing Flax Linen and NNX `Module`s, Flax must know the 1-to-1 mapping between Flax Linen collections and Flax NNX variable types, so that [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) can do the conversion automatically.

-Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mapping of NNX variable type and Linen collection names using `nnx.register_variable_name_type_pair`.
+Flax keeps a registry for this, and it already covers all of Flax's built-in Linen collections. You can register extra mappings of Flax NNX variable types and Flax Linen collection names using [`nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html).

-### Linen to NNX
+### Variables and collections: Flax Linen to NNX

-For any collection of your Linen module, `ToNNX` will convert all its endpoint arrays (aka. leaves) to a subtype of `nnx.Variable`, either from registry or automatically created on-the-fly.
+For any collection of your Linen module, [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) will convert all its endpoint arrays (a.k.a. [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree) [leaves](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#mistaking-pytree-nodes-for-leaves)) to a subtype of [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), either taken from the registry or created automatically on the fly.

-(However, we still keep the whole collection as one class attribute, because Linen modules may have duplicated names over different collections.)
+> **Note:** However, you still keep the entire collection(s) as one class attribute, because `flax.linen.Module`s may have duplicated names over different collections.

-```python
+```{code-cell} ipython3
 class LinenMultiCollections(nn.Module):
   out_dim: int
   def setup(self):
@@ -338,13 +450,13 @@
 x = jax.random.normal(jax.random.key(42), (2, 4))
 model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)

-print(model.w)        # Of type `nnx.Param` - note this is still under attribute `params`
-print(model.b)        # Of type `nnx.Param`
-print(model.count)    # Of type `counter` - auto-created type from the collection name
-print(type(model.count))
+print(model.params['w'])       # Of type `nnx.Param` - note this is still under attribute `params`
+print(model.params['b'])       # Of type `nnx.Param`
+print(model.counter['count'])  # Of type `counter` - an auto-created dummy type from the name "counter"
+print(type(model.counter['count']))

-y = model(x, mutable=True)  # Linen's `sow()` needs `mutable=True` to trigger
-print(model.dot_sum)        # Of type `nnx.Intermediates`
+y = model(x, mutable=True)  # Linen's `sow()` needs `mutable=True` to trigger
+print(model.intermediates['dot_sum'])  # Of type `nnx.Intermediates`
 ```

     Param(
@@ -364,35 +476,32 @@
     value=Array(6.932987, dtype=float32)
     ),)

++++

-You can quickly separate different types of NNX variables apart using `nnx.split`. 
+You can quickly separate different types of Flax NNX variables apart using [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). This can be handy when you only want to set certain variables as trainable. -This can be handy when you only want to set some variables as trainable. - - -```python -# Separate variables of different types with nnx.split -CountType = type(model.count) +```{code-cell} ipython3 +# Separate variables of different types with `nnx.split`. +CountType = type(model.counter['count']) static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...) -print('All Params:', list(params.keys())) -print('All Counters:', list(counter.keys())) +print('All Params:', list(params['params'].keys())) +print('All Counters:', list(counter['counter'].keys())) print('All the rest (intermediates and RNG keys):', list(the_rest.keys())) -model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time -y = model(x, mutable=True) # still works! +model = nnx.merge(static, params, counter, the_rest) # You can merge them (back) anytime. ``` All Params: ['b', 'w'] All Counters: ['count'] - All the rest (intermediates and RNG keys): ['dot_sum', 'rngs'] - + All the rest (intermediates and RNG keys): ['intermediates', 'rngs'] -### NNX to Linen ++++ -If you define custom NNX variable types, you should register their names with `nnx.register_variable_name_type_pair` so that they go to the desired collections. +### Variables and collections: Flax NNX to Linen +If you define custom Flax NNX variable types, you should register their names with [`nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html) so that they go to the desired collections. -```python +```{code-cell} ipython3 class Count(nnx.Variable): pass nnx.register_variable_name_type_pair('counts', Count, overwrite=True) @@ -415,43 +524,45 @@ print(var['params']) ``` All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts'] - {'w': Array([[ 0.2916921 , 0.22780475, 0.06553137], + {'w': NNXMeta(var_type=, value=Array([[ 0.2916921 , 0.22780475, 0.06553137], [ 0.17487915, -0.34043145, 0.24764155], [ 0.6420431 , 0.6220095 , -0.44769976], - [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)} + [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32), metadata={'get_value_hooks': (), 'set_value_hooks': (), 'create_value_hooks': (), 'add_axis_hooks': (), 'remove_axis_hooks': ()})} ++++ ## Partition metadata -Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded. +Flax uses a metadata wrapper box over the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) to annotate how a variable should be sharded. -In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too. +In Flax Linen, this is an optional feature that is triggered by using [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) on initializers. 
In Flax NNX, since all Flax NNX variables are wrapped by the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) class anyway, that class will hold the sharding annotations too.

-The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX).
+> **Note:** If you are new to `jax.Array`s and _data sharding_, go to [Key concepts](https://jax.readthedocs.io/en/latest/key-concepts.html#array-devices-and-sharding) and [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation) on the JAX documentation site.

-### Linen to NNX
+Both [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) will automatically convert the sharding annotations if you use the built-in annotation methods, such as [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) or [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning).

-Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wraps the true JAX array within.
+> **Note:** To get more familiar with sharding metadata in Flax and JAX, refer to Flax NNX’s [Scale up](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide, JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation), and the Flax Linen [Scale up](https://flax-linen.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) guide.

-If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`.
+### Partition metadata: Flax Linen to NNX

+Even if you are not using any partition metadata in your Flax Linen `Module`, the variable [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) will be converted to [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) that wrap the true `jax.Array` within.

+If you use [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) to annotate your Flax Linen `Module` variables, the annotation will be converted to the `.sharding` field in the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). 
-```python +You can then use [`nnx.with_sharding_constraint`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_sharding_constraint) to explicitly put the arrays into the annotated partitions within a [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html)-compiled function, to initialize the whole model with every array at the right sharding. + +```{code-cell} ipython3 class LinenDotWithPartitioning(nn.Module): out_dim: int @nn.compact def __call__(self, x): - w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), - ('in', 'out')), + w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), (x.shape[-1], self.out_dim)) return x @ w @nnx.jit def create_sharded_nnx_module(x): - model = bridge.lazy_init( - bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x) + model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x) state = nnx.state(model) sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state)) nnx.update(model, sharded_state) @@ -459,15 +570,14 @@ def create_sharded_nnx_module(x): print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...') -mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), - axis_names=('in', 'out')) +mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), axis_names=('in', 'out')) x = jax.random.normal(jax.random.key(42), (4, 32)) with mesh: model = create_sharded_nnx_module(x) -print(type(model.w)) # `nnx.Param` -print(model.w.sharding) # The partition annotation attached with `w` -print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh +print(type(model.params['w'])) # `nnx.Param` +print(model.params['w'].sharding) # The partition annotation attached with the weight `w` +print(model.params['w'].value.sharding) # The underlying JAX array is sharded across the 2x4 mesh ``` We have 8 fake JAX devices now to partition this model... @@ -475,17 +585,17 @@ print(model.w.value.sharding) # The underlying JAX array is sharded across the ('in', 'out') GSPMDSharding({devices=[2,4]<=[8]}) ++++ -### NNX to Linen - -If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it. +### Partition metadata: Flax NNX to Linen -But if you did add sharding annotations to your NNX variables, `ToLinen` will convert them to a default Linen partition metadata class called `bridge.NNXMeta`, retaining all the metadata you put into the NNX variable. +Since all Flax NNX variables are wrapped with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) box, the converted Flax Linen module will have all variables boxed too. There is a default partition metadata class called [`flax.nnx.bridge.NNXMeta`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.NNXMeta) for storing these converted NNX variables. -Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the raw JAX array tree. 
[`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) will automatically shard the array with the annotation if it is called within a [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh) context, so you don't need to apply `with_sharding_constraint` yourself.

+Similar to any Flax Linen metadata wrappers, you can use [`flax.linen.meta.unbox`](https://github.com/google/flax/blob/5d31452889b8d106d7c722b5eaac14cb9784fec2/flax/core/meta.py#L160) (exposed as `nn.unbox`) to get the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree).

-```python
+```{code-cell} ipython3
-class NNXDotWithParititioning(nnx.Module):
+class NNXDotWithPartitioning(nnx.Module):
   def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
     init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
@@ -494,48 +604,37 @@
     return x @ self.w

 x = jax.random.normal(jax.random.key(42), (4, 32))
-
-@jax.jit
-def create_sharded_variables(key, x):
-  model = bridge.to_linen(NNXDotWithParititioning, 32, 64)
-  variables = model.init(key, x)
-  # A `NNXMeta` wrapper of the underlying `nnx.Param`
-  assert type(variables['params']['w']) == bridge.NNXMeta
-  # The annotation coming from the `nnx.Param` => (in, out)
-  assert variables['params']['w'].metadata['sharding'] == ('in', 'out')
-
-  unboxed_variables = nn.unbox(variables)
-  variable_pspecs = nn.get_partition_spec(variables)
-  assert isinstance(unboxed_variables['params']['w'], jax.Array)
-  assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')
-
-  sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint,
-                              nn.unbox(variables),
-                              nn.get_partition_spec(variables))
-  return sharded_vars
+model = bridge.to_linen(NNXDotWithPartitioning, 32, 64)

 with mesh:
-  variables = create_sharded_variables(jax.random.key(0), x)
+  variables = jax.jit(model.init)(jax.random.key(0), x)

-# The underlying JAX array is sharded across the 2x4 mesh
-print(variables['params']['w'].sharding)
+print(type(variables['params']['w']))  # A `NNXMeta` wrapper of the underlying `nnx.Param`
+print(variables['params']['w'].metadata['sharding'])  # The annotation coming from the `nnx.Param`
+print(variables['params']['w'].value.sharding)  # The underlying JAX array is sharded across the 2x4 mesh

+unboxed = nn.unbox(variables)
+print(type(unboxed['params']['w']))  # The raw jax.Array
 ```

+    <class 'flax.nnx.bridge.variables.NNXMeta'>
+    ('in', 'out')
+    GSPMDSharding({devices=[2,4]<=[8]})
+    <class 'jaxlib.xla_extension.ArrayImpl'>

++++

-## Lifted transforms
-
-In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax.
+## Lifted transformations - go ahead and do it

-For Linen-style transforms, note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases)
+In general, if you want to apply [Flax Linen-](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html) or [Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) on an [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html)-converted `Module`, just go ahead and do it in the usual Flax Linen or NNX syntax. 
-### Linen to NNX +For [Flax Linen style transforms](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html), note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `flax.linen.Module` class in most cases). -NNX style lifted transforms are similar to JAX transforms, and they work on functions. +### Lifted transforms: Flax Linen to NNX +[Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) are similar to [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations), and they too work on functions. -```python +```{code-cell} ipython3 class NNXVmapped(nnx.Module): def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs): self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs) @@ -553,7 +652,7 @@ class NNXVmapped(nnx.Module): x = jax.random.normal(jax.random.key(0), (4, 32)) model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x) -print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped +print(model.linen_dot.params['kernel'].shape) # (4, 32, 64) - first axis with dim 4 got `vmap`ped. y = model(x) print(y.shape) ``` @@ -561,15 +660,13 @@ print(y.shape) (4, 32, 64) (4, 64) ++++ -### NNX to Linen +### Lifted transforms: Flax NNX to Linen -Note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases). +As mentioned before, [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases). And, since `nnx.bridge.ToLinen` introduces this extra `nnx` collection, you need to mark it when using the axis-changing transforms ([`flax.linen.vmap`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.vmap), [`flax.linen.scan`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.scan), and so on) to make sure they are passed inside. -Also, since `bridge.ToLinen` introduced this extra `nnx` collection, you need to mark it when using the axis-changing transforms (`linen.vmap`, `linen.scan`, etc) to make sure they are passed inside. - - -```python +```{code-cell} ipython3 class LinenVmapped(nn.Module): dout: int @nn.compact @@ -581,11 +678,10 @@ class LinenVmapped(nn.Module): x = jax.random.normal(jax.random.key(42), (4, 32)) model = LinenVmapped(64) var = model.init(jax.random.key(0), x) -print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped +print(var['params']['VmapToLinen_0']['kernel'].value.shape) # (4, 32, 64) - leading dim 4 got `vmap`ped. y = model.apply(var, x) print(y.shape) ``` (4, 32, 64) (4, 64) -