diff --git a/docs/guides/flax_basics.ipynb b/docs/guides/flax_basics.ipynb index 92acef3cd1..771b26d9aa 100644 --- a/docs/guides/flax_basics.ipynb +++ b/docs/guides/flax_basics.ipynb @@ -1,1072 +1,1042 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "yf-nWLh0naJi" - }, - "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_basics.ipynb)\n", - "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_basics.ipynb)\n", - "\n", - "# Flax Basics\n", - "\n", - "This notebook will walk you through the following workflow:\n", - "\n", - "* Instantiating a model from Flax built-in layers or third-party models.\n", - "* Initializing parameters of the model and manually written training.\n", - "* Using optimizers provided by Flax to ease training.\n", - "* Serialization of parameters and other objects.\n", - "* Creating your own models and managing state." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KyANAaZtbs86" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "yf-nWLh0naJi" + }, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_basics.ipynb)\n", + "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_basics.ipynb)\n", + "\n", + "# Flax Basics\n", + "\n", + "This notebook will walk you through the following workflow:\n", + "\n", + "* Instantiating a model from Flax built-in layers or third-party models.\n", + "* Initializing parameters of the model and manually written training.\n", + "* Using optimizers provided by Flax to ease training.\n", + "* Serialization of parameters and other objects.\n", + "* Creating your own models and managing state." + ] }, - "source": [ - "## Setting up our environment\n", - "\n", - "Here we provide the code needed to set up the environment for our notebook." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "qdrEVv9tinJn", - "outputId": "e30aa464-fa52-4f35-df96-716c68a4b3ee", - "tags": [ - "skip-execution" + { + "cell_type": "markdown", + "metadata": { + "id": "KyANAaZtbs86" + }, + "source": [ + "## Setting up our environment\n", + "\n", + "Here we provide the code needed to set up the environment for our notebook." ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n", - "\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n" + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "qdrEVv9tinJn", + "outputId": "e30aa464-fa52-4f35-df96-716c68a4b3ee", + "tags": [ + "skip-execution" ] - } - ], - "source": [ - "# Install the latest JAXlib version.\n", - "!pip install --upgrade -q pip jax jaxlib\n", - "# Install Flax at head:\n", - "!pip install --upgrade -q git+https://github.com/google/flax.git" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "kN6bZDaReZO2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n", + "\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n" + ] + } + ], + "source": [ + "# Install the latest JAXlib version.\n", + "!pip install --upgrade -q pip jax jaxlib\n", + "# Install Flax at head:\n", + "!pip install --upgrade -q git+https://github.com/google/flax.git" + ] }, - "outputs": [], - "source": [ - "import jax\n", - "from typing import Any, Callable, Sequence\n", - "from jax import lax, random, numpy as jnp\n", - "from flax.core import freeze, unfreeze\n", - "from flax import linen as nn" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pCCwAbOLiscA" + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "kN6bZDaReZO2" + }, + "outputs": [], + "source": [ + "import jax\n", + "from typing import Any, Callable, Sequence\n", + "from jax import lax, random, numpy as jnp\n", + "from flax.core import freeze, unfreeze\n", + "from flax import linen as nn" + ] }, - "source": [ - "## Linear regression with Flax\n", - "\n", - "In the previous *JAX for the impatient* notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it's done.\n", - "\n", - "A dense layer is a layer that has a kernel parameter $W\\in\\mathcal{M}_{m,n}(\\mathbb{R})$ where $m$ is the number of features as an output of the model, and $n$ the dimensionality of the input, and a bias parameter $b\\in\\mathbb{R}^m$. The dense layers returns $Wx+b$ from an input $x\\in\\mathbb{R}^n$.\n", - "\n", - "This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`)." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "zWX2zEtphT4Y" + { + "cell_type": "markdown", + "metadata": { + "id": "pCCwAbOLiscA" + }, + "source": [ + "## Linear regression with Flax\n", + "\n", + "In the previous *JAX for the impatient* notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it's done.\n", + "\n", + "A dense layer is a layer that has a kernel parameter $W\\in\\mathcal{M}_{m,n}(\\mathbb{R})$ where $m$ is the number of features as an output of the model, and $n$ the dimensionality of the input, and a bias parameter $b\\in\\mathbb{R}^m$. The dense layers returns $Wx+b$ from an input $x\\in\\mathbb{R}^n$.\n", + "\n", + "This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`)." + ] }, - "outputs": [], - "source": [ - "# We create one dense layer instance (taking 'features' parameter as input)\n", - "model = nn.Dense(features=5)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UmzP1QoQYAAN" + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "zWX2zEtphT4Y" + }, + "outputs": [], + "source": [ + "# We create one dense layer instance (taking 'features' parameter as input)\n", + "model = nn.Dense(features=5)" + ] }, - "source": [ - "Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class.\n", - "\n", - "### Model parameters & initialization\n", - "\n", - "Parameters are not stored with the models themselves. You need to initialize parameters by calling the `init` function, using a PRNGKey and dummy input data." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "K529lhzeYtl8", - "outputId": "06feb9d2-db50-4f41-c169-6df4336f43a5" + { + "cell_type": "markdown", + "metadata": { + "id": "UmzP1QoQYAAN" + }, + "source": [ + "Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class.\n", + "\n", + "### Model parameters & initialization\n", + "\n", + "Parameters are not stored with the models themselves. You need to initialize parameters by calling the `init` function, using a PRNGKey and dummy input data." + ] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "K529lhzeYtl8", + "outputId": "06feb9d2-db50-4f41-c169-6df4336f43a5" }, - { - "data": { - "text/plain": [ - "FrozenDict({\n", - " params: {\n", - " bias: (5,),\n", - " kernel: (10, 5),\n", - " },\n", - "})" + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, - "execution_count": 4, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "key1, key2 = random.split(random.PRNGKey(0))\n", - "x = random.normal(key1, (10,)) # Dummy input data\n", - "params = model.init(key2, x) # Initialization call\n", - "jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NH7Y9xMEewmO" + { + "data": { + "text/plain": [ + "FrozenDict({\n", + " params: {\n", + " bias: (5,),\n", + " kernel: (10, 5),\n", + " },\n", + "})" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "key1, key2 = random.split(random.PRNGKey(0))\n", + "x = random.normal(key1, (10,)) # Dummy input data\n", + "params = model.init(key2, x) # Initialization call\n", + "jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes" + ] }, - "source": [ - "*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.*\n", - "\n", - "The result is what we expect: bias and kernel parameters of the correct size. Under the hood:\n", - "\n", - "* The dummy input data `x` is used to trigger shape inference: we only declared the number of features we wanted in the output of the model, not the size of the input. Flax finds out by itself the correct size of the kernel.\n", - "* The random PRNG key is used to trigger the initialization functions (those have default values provided by the module here).\n", - "* Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`.\n", - "* The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3yL9mKk7naJn" + { + "cell_type": "markdown", + "metadata": { + "id": "NH7Y9xMEewmO" + }, + "source": [ + "*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.*\n", + "\n", + "The result is what we expect: bias and kernel parameters of the correct size. Under the hood:\n", + "\n", + "* The dummy input data `x` is used to trigger shape inference: we only declared the number of features we wanted in the output of the model, not the size of the input. Flax finds out by itself the correct size of the kernel.\n", + "* The random PRNG key is used to trigger the initialization functions (those have default values provided by the module here).\n", + "* Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`.\n", + "* The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`." + ] }, - "source": [ - "The output shows that the parameters are stored in a `FrozenDict` instance, which helps deal with the functional nature of JAX by preventing any mutation of the underlying dict and making the user aware of it. Read more about it in the [`flax.core.frozen_dict.FrozenDict` API docs](https://flax.readthedocs.io/en/latest/api_reference/flax.core.frozen_dict.html#flax.core.frozen_dict.FrozenDict).\n", - "\n", - "As a consequence, the following doesn't work:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "HtOFWeiynaJo", - "outputId": "689b4230-2a3d-4823-d103-2858e6debc4d" + { + "cell_type": "markdown", + "metadata": { + "id": "M1qo9M3_naJo" + }, + "source": [ + "To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input:" + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: FrozenDict is immutable.\n" - ] - } - ], - "source": [ - "try:\n", - " params['new_key'] = jnp.ones((2,2))\n", - "except ValueError as e:\n", - " print(\"Error: \", e)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M1qo9M3_naJo" + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "J8ietJecWiuK", + "outputId": "7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([-0.7358944, 1.3583755, -0.7976872, 0.8168598, 0.6297793], dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "model.apply(params, x)" + ] }, - "source": [ - "To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "J8ietJecWiuK", - "outputId": "7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae" + { + "cell_type": "markdown", + "metadata": { + "id": "lVsjgYzuSBGL" + }, + "source": [ + "### Gradient descent\n", + "\n", + "If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\\{(x_i,y_i), i\\in \\{1,\\ldots, k\\}, x_i\\in\\mathbb{R}^n,y_i\\in\\mathbb{R}^m\\}$, we try to find a set of parameters $W\\in \\mathcal{M}_{m,n}(\\mathbb{R}), b\\in\\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:\n", + "\n", + "$$\\mathcal{L}(W,b)\\rightarrow\\frac{1}{k}\\sum_{i=1}^{k} \\frac{1}{2}\\|y_i-f_{W,b}(x_i)\\|^2_2$$\n", + "\n", + "Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example." + ] }, - "outputs": [ - { - "data": { - "text/plain": [ - "DeviceArray([-0.7358944, 1.3583755, -0.7976872, 0.8168598, 0.6297793], dtype=float32)" + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "bFIiMnL4dl-e", + "outputId": "6eae59dc-0632-4f53-eac8-c22a7c646a52" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x shape: (20, 10) ; y shape: (20, 5)\n" ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "model.apply(params, x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lVsjgYzuSBGL" + } + ], + "source": [ + "# Set problem dimensions.\n", + "n_samples = 20\n", + "x_dim = 10\n", + "y_dim = 5\n", + "\n", + "# Generate random ground truth W and b.\n", + "key = random.PRNGKey(0)\n", + "k1, k2 = random.split(key)\n", + "W = random.normal(k1, (x_dim, y_dim))\n", + "b = random.normal(k2, (y_dim,))\n", + "# Store the parameters in a FrozenDict pytree.\n", + "true_params = freeze({'params': {'bias': b, 'kernel': W}})\n", + "\n", + "# Generate samples with additional noise.\n", + "key_sample, key_noise = random.split(k1)\n", + "x_samples = random.normal(key_sample, (n_samples, x_dim))\n", + "y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))\n", + "print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)" + ] }, - "source": [ - "### Gradient descent\n", - "\n", - "If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\\{(x_i,y_i), i\\in \\{1,\\ldots, k\\}, x_i\\in\\mathbb{R}^n,y_i\\in\\mathbb{R}^m\\}$, we try to find a set of parameters $W\\in \\mathcal{M}_{m,n}(\\mathbb{R}), b\\in\\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:\n", - "\n", - "$$\\mathcal{L}(W,b)\\rightarrow\\frac{1}{k}\\sum_{i=1}^{k} \\frac{1}{2}\\|y_i-f_{W,b}(x_i)\\|^2_2$$\n", - "\n", - "Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "bFIiMnL4dl-e", - "outputId": "6eae59dc-0632-4f53-eac8-c22a7c646a52" + { + "cell_type": "markdown", + "metadata": { + "id": "ZHkioicCiUbx" + }, + "source": [ + "We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees))." + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x shape: (20, 10) ; y shape: (20, 5)\n" - ] - } - ], - "source": [ - "# Set problem dimensions.\n", - "n_samples = 20\n", - "x_dim = 10\n", - "y_dim = 5\n", - "\n", - "# Generate random ground truth W and b.\n", - "key = random.PRNGKey(0)\n", - "k1, k2 = random.split(key)\n", - "W = random.normal(k1, (x_dim, y_dim))\n", - "b = random.normal(k2, (y_dim,))\n", - "# Store the parameters in a FrozenDict pytree.\n", - "true_params = freeze({'params': {'bias': b, 'kernel': W}})\n", - "\n", - "# Generate samples with additional noise.\n", - "key_sample, key_noise = random.split(k1)\n", - "x_samples = random.normal(key_sample, (n_samples, x_dim))\n", - "y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))\n", - "print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZHkioicCiUbx" + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "JqJaVc7BeNyT" + }, + "outputs": [], + "source": [ + "# Same as JAX version but using model.apply().\n", + "@jax.jit\n", + "def mse(params, x_batched, y_batched):\n", + " # Define the squared loss for a single pair (x,y)\n", + " def squared_error(x, y):\n", + " pred = model.apply(params, x)\n", + " return jnp.inner(y-pred, y-pred) / 2.0\n", + " # Vectorize the previous to compute the average of the loss on all samples.\n", + " return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)" + ] }, - "source": [ - "We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees))." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "JqJaVc7BeNyT" + { + "cell_type": "markdown", + "metadata": { + "id": "wGKru__mi15v" + }, + "source": [ + "And finally perform the gradient descent." + ] }, - "outputs": [], - "source": [ - "# Same as JAX version but using model.apply().\n", - "@jax.jit\n", - "def mse(params, x_batched, y_batched):\n", - " # Define the squared loss for a single pair (x,y)\n", - " def squared_error(x, y):\n", - " pred = model.apply(params, x)\n", - " return jnp.inner(y-pred, y-pred) / 2.0\n", - " # Vectorize the previous to compute the average of the loss on all samples.\n", - " return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wGKru__mi15v" + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "ePEl1ndse0Jq", + "outputId": "50d975b3-4706-4d8a-c4b8-2629ab8e3ac4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss for \"true\" W,b: 0.023639778\n", + "Loss step 0: 38.094772\n", + "Loss step 10: 0.44692168\n", + "Loss step 20: 0.10053458\n", + "Loss step 30: 0.035822745\n", + "Loss step 40: 0.018846875\n", + "Loss step 50: 0.013864839\n", + "Loss step 60: 0.012312559\n", + "Loss step 70: 0.011812928\n", + "Loss step 80: 0.011649306\n", + "Loss step 90: 0.011595251\n", + "Loss step 100: 0.0115773035\n" + ] + } + ], + "source": [ + "learning_rate = 0.3 # Gradient step size.\n", + "print('Loss for \"true\" W,b: ', mse(true_params, x_samples, y_samples))\n", + "loss_grad_fn = jax.value_and_grad(mse)\n", + "\n", + "@jax.jit\n", + "def update_params(params, learning_rate, grads):\n", + " params = jax.tree_util.tree_map(\n", + " lambda p, g: p - learning_rate * g, params, grads)\n", + " return params\n", + "\n", + "for i in range(101):\n", + " # Perform one gradient update.\n", + " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n", + " params = update_params(params, learning_rate, grads)\n", + " if i % 10 == 0:\n", + " print(f'Loss step {i}: ', loss_val)" + ] }, - "source": [ - "And finally perform the gradient descent." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "ePEl1ndse0Jq", - "outputId": "50d975b3-4706-4d8a-c4b8-2629ab8e3ac4" + { + "cell_type": "markdown", + "metadata": { + "id": "zqEnJ9Poyb6q" + }, + "source": [ + "### Optimizing with Optax\n", + "\n", + "Flax used to use its own `flax.optim` package for optimization, but with\n", + "[FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md)\n", + "this was deprecated in favor of\n", + "[Optax](https://github.com/deepmind/optax).\n", + "\n", + "Basic usage of Optax is straightforward:\n", + "\n", + "1. Choose an optimization method (e.g. `optax.adam`).\n", + "2. Create optimizer state from parameters (for the Adam optimizer, this state will contain the [momentum values](https://optax.readthedocs.io/en/latest/api.html#optax.adam)).\n", + "3. Compute the gradients of your loss with `jax.value_and_grad()`.\n", + "4. At every iteration, call the Optax `update` function to update the internal\n", + " optimizer state and create an update to the parameters. Then add the update\n", + " to the parameters with Optax's `apply_updates` method.\n", + "\n", + "Note that Optax can do a lot more: it's designed for composing simple gradient\n", + "transformations into more complex transformations that allows to implement a\n", + "wide range of optimizers. There is also support for changing optimizer\n", + "hyperparameters over time (\"schedules\"), applying different updates to different\n", + "parts of the parameter tree (\"masking\") and much more. For details please refer\n", + "to the\n", + "[official documentation](https://optax.readthedocs.io/en/latest/)." + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loss for \"true\" W,b: 0.023639778\n", - "Loss step 0: 38.094772\n", - "Loss step 10: 0.44692168\n", - "Loss step 20: 0.10053458\n", - "Loss step 30: 0.035822745\n", - "Loss step 40: 0.018846875\n", - "Loss step 50: 0.013864839\n", - "Loss step 60: 0.012312559\n", - "Loss step 70: 0.011812928\n", - "Loss step 80: 0.011649306\n", - "Loss step 90: 0.011595251\n", - "Loss step 100: 0.0115773035\n" - ] - } - ], - "source": [ - "learning_rate = 0.3 # Gradient step size.\n", - "print('Loss for \"true\" W,b: ', mse(true_params, x_samples, y_samples))\n", - "loss_grad_fn = jax.value_and_grad(mse)\n", - "\n", - "@jax.jit\n", - "def update_params(params, learning_rate, grads):\n", - " params = jax.tree_util.tree_map(\n", - " lambda p, g: p - learning_rate * g, params, grads)\n", - " return params\n", - "\n", - "for i in range(101):\n", - " # Perform one gradient update.\n", - " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n", - " params = update_params(params, learning_rate, grads)\n", - " if i % 10 == 0:\n", - " print(f'Loss step {i}: ', loss_val)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zqEnJ9Poyb6q" + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "Ce77uDJx1bUF" + }, + "outputs": [], + "source": [ + "import optax\n", + "tx = optax.adam(learning_rate=learning_rate)\n", + "opt_state = tx.init(params)\n", + "loss_grad_fn = jax.value_and_grad(mse)" + ] }, - "source": [ - "### Optimizing with Optax\n", - "\n", - "Flax used to use its own `flax.optim` package for optimization, but with\n", - "[FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md)\n", - "this was deprecated in favor of\n", - "[Optax](https://github.com/deepmind/optax).\n", - "\n", - "Basic usage of Optax is straightforward:\n", - "\n", - "1. Choose an optimization method (e.g. `optax.adam`).\n", - "2. Create optimizer state from parameters (for the Adam optimizer, this state will contain the [momentum values](https://optax.readthedocs.io/en/latest/api.html#optax.adam)).\n", - "3. Compute the gradients of your loss with `jax.value_and_grad()`.\n", - "4. At every iteration, call the Optax `update` function to update the internal\n", - " optimizer state and create an update to the parameters. Then add the update\n", - " to the parameters with Optax's `apply_updates` method.\n", - "\n", - "Note that Optax can do a lot more: it's designed for composing simple gradient\n", - "transformations into more complex transformations that allows to implement a\n", - "wide range of optimizers. There is also support for changing optimizer\n", - "hyperparameters over time (\"schedules\"), applying different updates to different\n", - "parts of the parameter tree (\"masking\") and much more. For details please refer\n", - "to the\n", - "[official documentation](https://optax.readthedocs.io/en/latest/)." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "Ce77uDJx1bUF" + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "PTSv0vx13xPO", + "outputId": "eec0c096-1d9e-4b3c-f8e5-942ee63828ec" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss step 0: 0.011576377\n", + "Loss step 10: 0.0115710115\n", + "Loss step 20: 0.011569244\n", + "Loss step 30: 0.011568661\n", + "Loss step 40: 0.011568454\n", + "Loss step 50: 0.011568379\n", + "Loss step 60: 0.011568358\n", + "Loss step 70: 0.01156836\n", + "Loss step 80: 0.01156835\n", + "Loss step 90: 0.011568353\n", + "Loss step 100: 0.011568348\n" + ] + } + ], + "source": [ + "for i in range(101):\n", + " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n", + " updates, opt_state = tx.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)\n", + " if i % 10 == 0:\n", + " print('Loss step {}: '.format(i), loss_val)" + ] }, - "outputs": [], - "source": [ - "import optax\n", - "tx = optax.adam(learning_rate=learning_rate)\n", - "opt_state = tx.init(params)\n", - "loss_grad_fn = jax.value_and_grad(mse)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "PTSv0vx13xPO", - "outputId": "eec0c096-1d9e-4b3c-f8e5-942ee63828ec" + { + "cell_type": "markdown", + "metadata": { + "id": "0eAPPwtpXYu7" + }, + "source": [ + "### Serializing the result\n", + "\n", + "Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that." + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loss step 0: 0.011576377\n", - "Loss step 10: 0.0115710115\n", - "Loss step 20: 0.011569244\n", - "Loss step 30: 0.011568661\n", - "Loss step 40: 0.011568454\n", - "Loss step 50: 0.011568379\n", - "Loss step 60: 0.011568358\n", - "Loss step 70: 0.01156836\n", - "Loss step 80: 0.01156835\n", - "Loss step 90: 0.011568353\n", - "Loss step 100: 0.011568348\n" - ] - } - ], - "source": [ - "for i in range(101):\n", - " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n", - " updates, opt_state = tx.update(grads, opt_state)\n", - " params = optax.apply_updates(params, updates)\n", - " if i % 10 == 0:\n", - " print('Loss step {}: '.format(i), loss_val)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0eAPPwtpXYu7" + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "BiUPRU93XnAZ", + "outputId": "b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dict output\n", + "{'params': {'bias': DeviceArray([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547], dtype=float32), 'kernel': DeviceArray([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285,\n", + " 0.34720102],\n", + " [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 ,\n", + " -0.10574618],\n", + " [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109,\n", + " -1.3132601 ],\n", + " [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 ,\n", + " 0.07996067],\n", + " [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774,\n", + " -2.239638 ],\n", + " [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 ,\n", + " 0.90352124],\n", + " [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 ,\n", + " 0.9286919 ],\n", + " [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117,\n", + " -1.202457 ],\n", + " [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 ,\n", + " -0.45790705],\n", + " [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863,\n", + " 0.17010891]], dtype=float32)}}\n", + "Bytes output\n", + "b'\\x81\\xa6params\\x82\\xa4bias\\xc7!\\x01\\x93\\x91\\x05\\xa7float32\\xc4\\x14\\x1d\\x1d\\xba\\xbf\\xc4\\xad\\x01\\xc0\\x81)\\x05@\\xdd.\\x9c?\\xa8\\x17\\x7f\\xbf\\xa6kernel\\xc7\\xd6\\x01\\x93\\x92\\n\\x05\\xa7float32\\xc4\\xc8\\x84]\\x81?\\xf0\\xb5B>`\\xb59=z^m\\xbfU\\xc4\\xb1>\\x00\\xb3\\xdd?\\xb8x}?\\xc7F\\x95?2(\\x8d?t\\x91\\xd8\\xbd\\x83\\xb7\\x99\\xbfr\\xa5\\x93>#u\\xb5?\\xdcA\\xf7=\\xe8\\x18\\xa8\\xbf;\\xe5\\x98\\xbf\\xd1}B\\xbe0h\\n=)\\x86\\xa8?k\\xc2\\xa3=\\xaaj\\x10>\\x91\\xd8\\xaf?\\xa9y\\xa8\\xbfc\\xb5\\x08?;V\\x0f\\xc0Av\\x10?ZHP?wD\\xa3>\\x022\\t?+Mg?\\xa0K\\xc2\\xbe\\xb1\\xd3\\xde?)\\x16\\x8a?\\x04\\x13\\x01\\xbf\\xc1\\xbem?\\xfdZx?Wn\\xa8\\xbf\\x940\\xac>\\x925O?\\x1c\\xea\\x99\\xbf\\x9e\\x89\\x82?\\x07\\xad\\x1e\\xbf\\xe2\\x87\\x8a?\\xdfU\\xeb\\xbf\\xcbr\\xea\\xbe\\xe9\\xd2$\\xbf\\xf4\\xb8\\xe9>\\x98\\t\\x91\\xbfm\\x81/\\xbf\\x081.>'\n" + ] + } + ], + "source": [ + "from flax import serialization\n", + "bytes_output = serialization.to_bytes(params)\n", + "dict_output = serialization.to_state_dict(params)\n", + "print('Dict output')\n", + "print(dict_output)\n", + "print('Bytes output')\n", + "print(bytes_output)" + ] }, - "source": [ - "### Serializing the result\n", - "\n", - "Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "BiUPRU93XnAZ", - "outputId": "b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c" + { + "cell_type": "markdown", + "metadata": { + "id": "eielPo2KZByd" + }, + "source": [ + "To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place.\n", + "\n", + "*The point of enforcing structure through template is to avoid users issues downstream, so you need to first have the right model that generates the parameters structure.*" + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Dict output\n", - "{'params': {'bias': DeviceArray([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547], dtype=float32), 'kernel': DeviceArray([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285,\n", - " 0.34720102],\n", - " [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 ,\n", - " -0.10574618],\n", - " [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109,\n", - " -1.3132601 ],\n", - " [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 ,\n", - " 0.07996067],\n", - " [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774,\n", - " -2.239638 ],\n", - " [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 ,\n", - " 0.90352124],\n", - " [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 ,\n", - " 0.9286919 ],\n", - " [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117,\n", - " -1.202457 ],\n", - " [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 ,\n", - " -0.45790705],\n", - " [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863,\n", - " 0.17010891]], dtype=float32)}}\n", - "Bytes output\n", - "b'\\x81\\xa6params\\x82\\xa4bias\\xc7!\\x01\\x93\\x91\\x05\\xa7float32\\xc4\\x14\\x1d\\x1d\\xba\\xbf\\xc4\\xad\\x01\\xc0\\x81)\\x05@\\xdd.\\x9c?\\xa8\\x17\\x7f\\xbf\\xa6kernel\\xc7\\xd6\\x01\\x93\\x92\\n\\x05\\xa7float32\\xc4\\xc8\\x84]\\x81?\\xf0\\xb5B>`\\xb59=z^m\\xbfU\\xc4\\xb1>\\x00\\xb3\\xdd?\\xb8x}?\\xc7F\\x95?2(\\x8d?t\\x91\\xd8\\xbd\\x83\\xb7\\x99\\xbfr\\xa5\\x93>#u\\xb5?\\xdcA\\xf7=\\xe8\\x18\\xa8\\xbf;\\xe5\\x98\\xbf\\xd1}B\\xbe0h\\n=)\\x86\\xa8?k\\xc2\\xa3=\\xaaj\\x10>\\x91\\xd8\\xaf?\\xa9y\\xa8\\xbfc\\xb5\\x08?;V\\x0f\\xc0Av\\x10?ZHP?wD\\xa3>\\x022\\t?+Mg?\\xa0K\\xc2\\xbe\\xb1\\xd3\\xde?)\\x16\\x8a?\\x04\\x13\\x01\\xbf\\xc1\\xbem?\\xfdZx?Wn\\xa8\\xbf\\x940\\xac>\\x925O?\\x1c\\xea\\x99\\xbf\\x9e\\x89\\x82?\\x07\\xad\\x1e\\xbf\\xe2\\x87\\x8a?\\xdfU\\xeb\\xbf\\xcbr\\xea\\xbe\\xe9\\xd2$\\xbf\\xf4\\xb8\\xe9>\\x98\\t\\x91\\xbfm\\x81/\\xbf\\x081.>'\n" - ] - } - ], - "source": [ - "from flax import serialization\n", - "bytes_output = serialization.to_bytes(params)\n", - "dict_output = serialization.to_state_dict(params)\n", - "print('Dict output')\n", - "print(dict_output)\n", - "print('Bytes output')\n", - "print(bytes_output)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eielPo2KZByd" + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "MOhoBDCOYYJ5", + "outputId": "13acc4e1-8757-4554-e2c8-d594ba6e67dc" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "FrozenDict({\n", + " params: {\n", + " bias: array([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547],\n", + " dtype=float32),\n", + " kernel: array([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285, 0.34720102],\n", + " [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 , -0.10574618],\n", + " [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109, -1.3132601 ],\n", + " [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 , 0.07996067],\n", + " [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774, -2.239638 ],\n", + " [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 , 0.90352124],\n", + " [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 , 0.9286919 ],\n", + " [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117, -1.202457 ],\n", + " [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 , -0.45790705],\n", + " [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863, 0.17010891]],\n", + " dtype=float32),\n", + " },\n", + "})" + ] + }, + "execution_count": 14, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "serialization.from_bytes(params, bytes_output)" + ] }, - "source": [ - "To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place.\n", - "\n", - "*The point of enforcing structure through template is to avoid users issues downstream, so you need to first have the right model that generates the parameters structure.*" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "MOhoBDCOYYJ5", - "outputId": "13acc4e1-8757-4554-e2c8-d594ba6e67dc" + { + "cell_type": "markdown", + "metadata": { + "id": "8mNu8nuOhDC5" + }, + "source": [ + "## Defining your own models\n", + "\n", + "Flax allows you to define your own models, which should be a bit more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class.\n", + "\n", + "*Keep in mind that we imported* `linen as nn` *and this only works with the new linen API*" + ] }, - "outputs": [ - { - "data": { - "text/plain": [ - "FrozenDict({\n", - " params: {\n", - " bias: array([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547],\n", - " dtype=float32),\n", - " kernel: array([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285, 0.34720102],\n", - " [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 , -0.10574618],\n", - " [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109, -1.3132601 ],\n", - " [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 , 0.07996067],\n", - " [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774, -2.239638 ],\n", - " [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 , 0.90352124],\n", - " [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 , 0.9286919 ],\n", - " [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117, -1.202457 ],\n", - " [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 , -0.45790705],\n", - " [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863, 0.17010891]],\n", - " dtype=float32),\n", - " },\n", - "})" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "serialization.from_bytes(params, bytes_output)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8mNu8nuOhDC5" + { + "cell_type": "markdown", + "metadata": { + "id": "1sllHAdRlpmQ" + }, + "source": [ + "### Module basics\n", + "\n", + "The base abstraction for models is the `nn.Module` class, and every type of predefined layers in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's take a look and start by defining a simple but custom multi-layer perceptron i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function." + ] }, - "source": [ - "## Defining your own models\n", - "\n", - "Flax allows you to define your own models, which should be a bit more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class.\n", - "\n", - "*Keep in mind that we imported* `linen as nn` *and this only works with the new linen API*" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1sllHAdRlpmQ" + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "vbfrfbkxgPhg", + "outputId": "b59c679c-d164-4fd6-92db-b50f0d310ec3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", + "output:\n", + " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", + " -1.7147182e-02]\n", + " [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n", + " -4.5417298e-02]\n", + " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", + " 0.0000000e+00]\n", + " [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n", + " -1.0110770e-03]]\n" + ] + } + ], + "source": [ + "class ExplicitMLP(nn.Module):\n", + " features: Sequence[int]\n", + "\n", + " def setup(self):\n", + " # we automatically know what to do with lists, dicts of submodules\n", + " self.layers = [nn.Dense(feat) for feat in self.features]\n", + " # for single submodules, we would just write:\n", + " # self.layer1 = nn.Dense(feat1)\n", + "\n", + " def __call__(self, inputs):\n", + " x = inputs\n", + " for i, lyr in enumerate(self.layers):\n", + " x = lyr(x)\n", + " if i != len(self.layers) - 1:\n", + " x = nn.relu(x)\n", + " return x\n", + "\n", + "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = ExplicitMLP(features=[3,4,5])\n", + "params = model.init(key2, x)\n", + "y = model.apply(params, x)\n", + "\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n", + "print('output:\\n', y)" + ] }, - "source": [ - "### Module basics\n", - "\n", - "The base abstraction for models is the `nn.Module` class, and every type of predefined layers in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's take a look and start by defining a simple but custom multi-layer perceptron i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "vbfrfbkxgPhg", - "outputId": "b59c679c-d164-4fd6-92db-b50f0d310ec3" + { + "cell_type": "markdown", + "metadata": { + "id": "DDITIjXitEZl" + }, + "source": [ + "As we can see, a `nn.Module` subclass is made of:\n", + "\n", + "* A collection of data fields (`nn.Module` are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`.\n", + "* A `setup()` method that is being called at the end of the `__postinit__` where you can register submodules, variables, parameters you will need in your model.\n", + "* A `__call__` function that returns the output of the model from a given input.\n", + "* The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one `layers_n` sub dict per layer, and each of those contain the parameters of the associated Dense layer. The layout is very explicit.\n", + "\n", + "*Note: lists are mostly managed as you would expect (WIP), there are corner cases you should be aware of as pointed out* [here](https://github.com/google/flax/issues/524)\n", + "\n", + "Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one, which is the one to call on an input:" + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "initialized parameter shapes:\n", - " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", - "output:\n", - " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", - " -1.7147182e-02]\n", - " [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n", - " -4.5417298e-02]\n", - " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", - " 0.0000000e+00]\n", - " [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n", - " -1.0110770e-03]]\n" - ] - } - ], - "source": [ - "class ExplicitMLP(nn.Module):\n", - " features: Sequence[int]\n", - "\n", - " def setup(self):\n", - " # we automatically know what to do with lists, dicts of submodules\n", - " self.layers = [nn.Dense(feat) for feat in self.features]\n", - " # for single submodules, we would just write:\n", - " # self.layer1 = nn.Dense(feat1)\n", - "\n", - " def __call__(self, inputs):\n", - " x = inputs\n", - " for i, lyr in enumerate(self.layers):\n", - " x = lyr(x)\n", - " if i != len(self.layers) - 1:\n", - " x = nn.relu(x)\n", - " return x\n", - "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", - "x = random.uniform(key1, (4,4))\n", - "\n", - "model = ExplicitMLP(features=[3,4,5])\n", - "params = model.init(key2, x)\n", - "y = model.apply(params, x)\n", - "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n", - "print('output:\\n', y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DDITIjXitEZl" + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "DEYrVA6dnaJu", + "outputId": "4af16ec5-b52a-43b0-fc47-1f8ab25e7058" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\"ExplicitMLP\" object has no attribute \"layers\"\n" + ] + } + ], + "source": [ + "try:\n", + " y = model(x) # Returns an error\n", + "except AttributeError as e:\n", + " print(e)" + ] }, - "source": [ - "As we can see, a `nn.Module` subclass is made of:\n", - "\n", - "* A collection of data fields (`nn.Module` are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`.\n", - "* A `setup()` method that is being called at the end of the `__postinit__` where you can register submodules, variables, parameters you will need in your model.\n", - "* A `__call__` function that returns the output of the model from a given input.\n", - "* The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one `layers_n` sub dict per layer, and each of those contain the parameters of the associated Dense layer. The layout is very explicit.\n", - "\n", - "*Note: lists are mostly managed as you would expect (WIP), there are corner cases you should be aware of as pointed out* [here](https://github.com/google/flax/issues/524)\n", - "\n", - "Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one, which is the one to call on an input:" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "DEYrVA6dnaJu", - "outputId": "4af16ec5-b52a-43b0-fc47-1f8ab25e7058" + { + "cell_type": "markdown", + "metadata": { + "id": "I__UrmShnaJu" + }, + "source": [ + "Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so:" + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\"ExplicitMLP\" object has no attribute \"layers\"\n" - ] - } - ], - "source": [ - "try:\n", - " y = model(x) # Returns an error\n", - "except AttributeError as e:\n", - " print(e)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I__UrmShnaJu" + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "ZTCbdpQ4suSK", + "outputId": "183a74ef-f54e-4848-99bf-fee4c174ba6d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", + "output:\n", + " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", + " -1.7147182e-02]\n", + " [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n", + " -4.5417298e-02]\n", + " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", + " 0.0000000e+00]\n", + " [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n", + " -1.0110770e-03]]\n" + ] + } + ], + "source": [ + "class SimpleMLP(nn.Module):\n", + " features: Sequence[int]\n", + "\n", + " @nn.compact\n", + " def __call__(self, inputs):\n", + " x = inputs\n", + " for i, feat in enumerate(self.features):\n", + " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", + " if i != len(self.features) - 1:\n", + " x = nn.relu(x)\n", + " # providing a name is optional though!\n", + " # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n", + " return x\n", + "\n", + "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = SimpleMLP(features=[3,4,5])\n", + "params = model.init(key2, x)\n", + "y = model.apply(params, x)\n", + "\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n", + "print('output:\\n', y)" + ] }, - "source": [ - "Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "ZTCbdpQ4suSK", - "outputId": "183a74ef-f54e-4848-99bf-fee4c174ba6d" + { + "cell_type": "markdown", + "metadata": { + "id": "es7YHjgexT-L" + }, + "source": [ + "There are, however, a few differences you should be aware of between the two declaration modes:\n", + "\n", + "* In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).\n", + "* If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated.\n", + "* The last initialization will be handled differently. See these notes for more details (TODO: add notes link)." + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "initialized parameter shapes:\n", - " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", - "output:\n", - " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", - " -1.7147182e-02]\n", - " [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n", - " -4.5417298e-02]\n", - " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", - " 0.0000000e+00]\n", - " [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n", - " -1.0110770e-03]]\n" - ] - } - ], - "source": [ - "class SimpleMLP(nn.Module):\n", - " features: Sequence[int]\n", - "\n", - " @nn.compact\n", - " def __call__(self, inputs):\n", - " x = inputs\n", - " for i, feat in enumerate(self.features):\n", - " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", - " if i != len(self.features) - 1:\n", - " x = nn.relu(x)\n", - " # providing a name is optional though!\n", - " # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n", - " return x\n", - "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", - "x = random.uniform(key1, (4,4))\n", - "\n", - "model = SimpleMLP(features=[3,4,5])\n", - "params = model.init(key2, x)\n", - "y = model.apply(params, x)\n", - "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n", - "print('output:\\n', y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "es7YHjgexT-L" + { + "cell_type": "markdown", + "metadata": { + "id": "-ykceROJyp7W" + }, + "source": [ + "### Module parameters\n", + "\n", + "In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new modules:" + ] }, - "source": [ - "There are, however, a few differences you should be aware of between the two declaration modes:\n", - "\n", - "* In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).\n", - "* If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated.\n", - "* The last initialization will be handled differently. See these notes for more details (TODO: add notes link)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-ykceROJyp7W" + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "id": "wK371Pt_vVfR", + "outputId": "83b5fea4-071e-4ea0-8fa8-610e69fb5fd5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameters:\n", + " FrozenDict({\n", + " params: {\n", + " kernel: DeviceArray([[ 0.6503669 , 0.86789787, 0.4604268 ],\n", + " [ 0.05673932, 0.9909285 , -0.63536596],\n", + " [ 0.76134115, -0.3250529 , -0.65221626],\n", + " [-0.82430327, 0.4150194 , 0.19405058]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", + " },\n", + "})\n", + "output:\n", + " [[ 0.5035518 1.8548558 -0.4270195 ]\n", + " [ 0.0279097 0.5589246 -0.43061772]\n", + " [ 0.3547128 1.5740999 -0.32865518]\n", + " [ 0.5264864 1.2928858 0.10089308]]\n" + ] + } + ], + "source": [ + "class SimpleDense(nn.Module):\n", + " features: int\n", + " kernel_init: Callable = nn.initializers.lecun_normal()\n", + " bias_init: Callable = nn.initializers.zeros_init()\n", + "\n", + " @nn.compact\n", + " def __call__(self, inputs):\n", + " kernel = self.param('kernel',\n", + " self.kernel_init, # Initialization function\n", + " (inputs.shape[-1], self.features)) # shape info.\n", + " y = lax.dot_general(inputs, kernel,\n", + " (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?\n", + " bias = self.param('bias', self.bias_init, (self.features,))\n", + " y = y + bias\n", + " return y\n", + "\n", + "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = SimpleDense(features=3)\n", + "params = model.init(key2, x)\n", + "y = model.apply(params, x)\n", + "\n", + "print('initialized parameters:\\n', params)\n", + "print('output:\\n', y)" + ] }, - "source": [ - "### Module parameters\n", - "\n", - "In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new modules:" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "wK371Pt_vVfR", - "outputId": "83b5fea4-071e-4ea0-8fa8-610e69fb5fd5" + { + "cell_type": "markdown", + "metadata": { + "id": "MKyhfzVpzC94" + }, + "source": [ + "Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)` :\n", + "\n", + "* `name` is simply the name of the parameter that will end up in the parameter structure.\n", + "* `init_fn` is a function with input `(PRNGKey, *init_args)` returning an Array, with `init_args` being the arguments needed to call the initialisation function.\n", + "* `init_args` are the arguments to provide to the initialization function.\n", + "\n", + "Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site." + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "initialized parameters:\n", - " FrozenDict({\n", - " params: {\n", - " kernel: DeviceArray([[ 0.6503669 , 0.86789787, 0.4604268 ],\n", - " [ 0.05673932, 0.9909285 , -0.63536596],\n", - " [ 0.76134115, -0.3250529 , -0.65221626],\n", - " [-0.82430327, 0.4150194 , 0.19405058]], dtype=float32),\n", - " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", - " },\n", - "})\n", - "output:\n", - " [[ 0.5035518 1.8548558 -0.4270195 ]\n", - " [ 0.0279097 0.5589246 -0.43061772]\n", - " [ 0.3547128 1.5740999 -0.32865518]\n", - " [ 0.5264864 1.2928858 0.10089308]]\n" - ] - } - ], - "source": [ - "class SimpleDense(nn.Module):\n", - " features: int\n", - " kernel_init: Callable = nn.initializers.lecun_normal()\n", - " bias_init: Callable = nn.initializers.zeros_init()\n", - "\n", - " @nn.compact\n", - " def __call__(self, inputs):\n", - " kernel = self.param('kernel',\n", - " self.kernel_init, # Initialization function\n", - " (inputs.shape[-1], self.features)) # shape info.\n", - " y = lax.dot_general(inputs, kernel,\n", - " (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?\n", - " bias = self.param('bias', self.bias_init, (self.features,))\n", - " y = y + bias\n", - " return y\n", - "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", - "x = random.uniform(key1, (4,4))\n", - "\n", - "model = SimpleDense(features=3)\n", - "params = model.init(key2, x)\n", - "y = model.apply(params, x)\n", - "\n", - "print('initialized parameters:\\n', params)\n", - "print('output:\\n', y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MKyhfzVpzC94" + { + "cell_type": "markdown", + "metadata": { + "id": "QmSpxyqLDr58" + }, + "source": [ + "### Variables and collections of variables\n", + "\n", + "As we've seen so far, working with models means working with:\n", + "\n", + "* A subclass of `nn.Module`;\n", + "* A pytree of parameters for the model (typically from `model.init()`);\n", + "\n", + "However this is not enough to cover everything that we would need for machine learning, especially neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). There is a way to declare variables beyond the parameters of the model with the `variable` method.\n", + "\n", + "For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py)." + ] }, - "source": [ - "Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)` :\n", - "\n", - "* `name` is simply the name of the parameter that will end up in the parameter structure.\n", - "* `init_fn` is a function with input `(PRNGKey, *init_args)` returning an Array, with `init_args` being the arguments needed to call the initialisation function.\n", - "* `init_args` are the arguments to provide to the initialization function.\n", - "\n", - "Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QmSpxyqLDr58" + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "id": "J6_tR-nPzB1i", + "outputId": "75465fd6-cdc8-497c-a3ec-7f709b5dde7a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized variables:\n", + " FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n", + " },\n", + " params: {\n", + " bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n", + " },\n", + "})\n", + "updated state:\n", + " FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", + " },\n", + "})\n" + ] + } + ], + "source": [ + "class BiasAdderWithRunningMean(nn.Module):\n", + " decay: float = 0.99\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " # easy pattern to detect if we're initializing via empty variable tree\n", + " is_initialized = self.has_variable('batch_stats', 'mean')\n", + " ra_mean = self.variable('batch_stats', 'mean',\n", + " lambda s: jnp.zeros(s),\n", + " x.shape[1:])\n", + " mean = ra_mean.value # This will either get the value or trigger init\n", + " bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])\n", + " if is_initialized:\n", + " ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)\n", + "\n", + " return x - ra_mean.value + bias\n", + "\n", + "\n", + "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "x = jnp.ones((10,5))\n", + "model = BiasAdderWithRunningMean()\n", + "variables = model.init(key1, x)\n", + "print('initialized variables:\\n', variables)\n", + "y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n", + "print('updated state:\\n', updated_state)" + ] }, - "source": [ - "### Variables and collections of variables\n", - "\n", - "As we've seen so far, working with models means working with:\n", - "\n", - "* A subclass of `nn.Module`;\n", - "* A pytree of parameters for the model (typically from `model.init()`);\n", - "\n", - "However this is not enough to cover everything that we would need for machine learning, especially neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). There is a way to declare variables beyond the parameters of the model with the `variable` method.\n", - "\n", - "For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py)." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "J6_tR-nPzB1i", - "outputId": "75465fd6-cdc8-497c-a3ec-7f709b5dde7a" + { + "cell_type": "markdown", + "metadata": { + "id": "5OHBbMJng3ic" + }, + "source": [ + "Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:" + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "initialized variables:\n", - " FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n", - " },\n", - " params: {\n", - " bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n", - " },\n", - "})\n", - "updated state:\n", - " FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", - " },\n", - "})\n" - ] - } - ], - "source": [ - "class BiasAdderWithRunningMean(nn.Module):\n", - " decay: float = 0.99\n", - "\n", - " @nn.compact\n", - " def __call__(self, x):\n", - " # easy pattern to detect if we're initializing via empty variable tree\n", - " is_initialized = self.has_variable('batch_stats', 'mean')\n", - " ra_mean = self.variable('batch_stats', 'mean',\n", - " lambda s: jnp.zeros(s),\n", - " x.shape[1:])\n", - " mean = ra_mean.value # This will either get the value or trigger init\n", - " bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])\n", - " if is_initialized:\n", - " ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)\n", - "\n", - " return x - ra_mean.value + bias\n", - "\n", - "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", - "x = jnp.ones((10,5))\n", - "model = BiasAdderWithRunningMean()\n", - "variables = model.init(key1, x)\n", - "print('initialized variables:\\n', variables)\n", - "y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n", - "print('updated state:\\n', updated_state)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5OHBbMJng3ic" + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "IbTsCAvZcdBy", + "outputId": "09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "updated state:\n", + " FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", + " },\n", + "})\n", + "updated state:\n", + " FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),\n", + " },\n", + "})\n", + "updated state:\n", + " FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),\n", + " },\n", + "})\n" + ] + } + ], + "source": [ + "for val in [1.0, 2.0, 3.0]:\n", + " x = val * jnp.ones((10,5))\n", + " y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n", + " old_state, params = variables.pop('params')\n", + " variables = freeze({'params': params, **updated_state})\n", + " print('updated state:\\n', updated_state) # Shows only the mutable part" + ] }, - "source": [ - "Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "id": "IbTsCAvZcdBy", - "outputId": "09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b" + { + "cell_type": "markdown", + "metadata": { + "id": "GuUSOSKegKIM" + }, + "source": [ + "From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables.\n", + "\n", + "*This example isn't doing anything and is only for demonstration purposes.*" + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "updated state:\n", - " FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", - " },\n", - "})\n", - "updated state:\n", - " FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),\n", - " },\n", - "})\n", - "updated state:\n", - " FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),\n", - " },\n", - "})\n" - ] - } - ], - "source": [ - "for val in [1.0, 2.0, 3.0]:\n", - " x = val * jnp.ones((10,5))\n", - " y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n", - " old_state, params = variables.pop('params')\n", - " variables = freeze({'params': params, **updated_state})\n", - " print('updated state:\\n', updated_state) # Shows only the mutable part" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GuUSOSKegKIM" + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "id": "TUgAbUPpnaJw", + "outputId": "0906fbab-b866-4956-d231-b1374415d448" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Updated state: FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", + " },\n", + "})\n", + "Updated state: FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),\n", + " },\n", + "})\n", + "Updated state: FrozenDict({\n", + " batch_stats: {\n", + " mean: DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),\n", + " },\n", + "})\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "\n", + "@partial(jax.jit, static_argnums=(0, 1))\n", + "def update_step(tx, apply_fn, x, opt_state, params, state):\n", + "\n", + " def loss(params):\n", + " y, updated_state = apply_fn({'params': params, **state},\n", + " x, mutable=list(state.keys()))\n", + " l = ((x - y) ** 2).sum()\n", + " return l, updated_state\n", + "\n", + " (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)\n", + " updates, opt_state = tx.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)\n", + " return opt_state, params, state\n", + "\n", + "x = jnp.ones((10,5))\n", + "variables = model.init(random.PRNGKey(0), x)\n", + "state, params = variables.pop('params')\n", + "del variables\n", + "tx = optax.sgd(learning_rate=0.02)\n", + "opt_state = tx.init(params)\n", + "\n", + "for _ in range(3):\n", + " opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)\n", + " print('Updated state: ', state)" + ] }, - "source": [ - "From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables.\n", - "\n", - "*This example isn't doing anything and is only for demonstration purposes.*" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "id": "TUgAbUPpnaJw", - "outputId": "0906fbab-b866-4956-d231-b1374415d448" + { + "cell_type": "markdown", + "metadata": { + "id": "eWUmx5EjtWge" + }, + "source": [ + "Note that the above function has a quite verbose signature and it would not actually\n", + "work with `jax.jit()` because the function arguments are not \"valid JAX types\".\n", + "\n", + "Flax provides a handy wrapper - `TrainState` - that simplifies the above code. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more." + ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Updated state: FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", - " },\n", - "})\n", - "Updated state: FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),\n", - " },\n", - "})\n", - "Updated state: FrozenDict({\n", - " batch_stats: {\n", - " mean: DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),\n", - " },\n", - "})\n" - ] - } - ], - "source": [ - "from functools import partial\n", - "\n", - "@partial(jax.jit, static_argnums=(0, 1))\n", - "def update_step(tx, apply_fn, x, opt_state, params, state):\n", - "\n", - " def loss(params):\n", - " y, updated_state = apply_fn({'params': params, **state},\n", - " x, mutable=list(state.keys()))\n", - " l = ((x - y) ** 2).sum()\n", - " return l, updated_state\n", - "\n", - " (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)\n", - " updates, opt_state = tx.update(grads, opt_state)\n", - " params = optax.apply_updates(params, updates)\n", - " return opt_state, params, state\n", - "\n", - "x = jnp.ones((10,5))\n", - "variables = model.init(random.PRNGKey(0), x)\n", - "state, params = variables.pop('params')\n", - "del variables\n", - "tx = optax.sgd(learning_rate=0.02)\n", - "opt_state = tx.init(params)\n", - "\n", - "for _ in range(3):\n", - " opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)\n", - " print('Updated state: ', state)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eWUmx5EjtWge" + { + "cell_type": "markdown", + "metadata": { + "id": "_GL0PsCwnaJw" + }, + "source": [ + "### Exporting to Tensorflow's SavedModel with jax2tf\n", + "\n", + "JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" }, - "source": [ - "Note that the above function has a quite verbose signature and it would not actually\n", - "work with `jax.jit()` because the function arguments are not \"valid JAX types\".\n", - "\n", - "Flax provides a handy wrapper - `TrainState` - that simplifies the above code. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more." - ] + "language_info": { + "name": "python", + "version": "3.8.15" + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "_GL0PsCwnaJw" - }, - "source": [ - "### Exporting to Tensorflow's SavedModel with jax2tf\n", - "\n", - "JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax." - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + "nbformat": 4, + "nbformat_minor": 0 + } \ No newline at end of file diff --git a/docs/guides/flax_basics.md b/docs/guides/flax_basics.md index 686177df28..1a1889d7ec 100644 --- a/docs/guides/flax_basics.md +++ b/docs/guides/flax_basics.md @@ -96,22 +96,6 @@ The result is what we expect: bias and kernel parameters of the correct size. Un * Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`. * The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`. -+++ {"id": "3yL9mKk7naJn"} - -The output shows that the parameters are stored in a `FrozenDict` instance, which helps deal with the functional nature of JAX by preventing any mutation of the underlying dict and making the user aware of it. Read more about it in the [`flax.core.frozen_dict.FrozenDict` API docs](https://flax.readthedocs.io/en/latest/api_reference/flax.core.frozen_dict.html#flax.core.frozen_dict.FrozenDict). - -As a consequence, the following doesn't work: - -```{code-cell} -:id: HtOFWeiynaJo -:outputId: 689b4230-2a3d-4823-d103-2858e6debc4d - -try: - params['new_key'] = jnp.ones((2,2)) -except ValueError as e: - print("Error: ", e) -``` - +++ {"id": "M1qo9M3_naJo"} To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input: @@ -546,4 +530,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C ### Exporting to Tensorflow's SavedModel with jax2tf -JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax. +JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax. \ No newline at end of file diff --git a/flax/configurations.py b/flax/configurations.py index d56e6a8022..a6e0ac2470 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -109,11 +109,12 @@ def temp_flip_flag(var_name: str, var_value: bool): default=False, help=("When adopting outside modules, don't clobber existing names.")) -#TODO(marcuschiam): remove this feature flag once regular dict migration is complete +# TODO(marcuschiam): remove this feature flag once regular dict migration is complete flax_return_frozendict = define_bool_state( name='return_frozendict', - default=True, - help=('Whether to return FrozenDicts when calling init or apply.')) + default=False, + help='Whether to return FrozenDicts when calling init or apply.', +) flax_fix_rng = define_bool_state( name ='fix_rng_separator', diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index c07fc84ec6..1794e35478 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -164,7 +164,6 @@ def false_fn(scope, x): self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1}) np.testing.assert_allclose(y1, -y2) - @temp_flip_flag('return_frozendict', False) def test_switch(self): def f(scope, x, index): scope.variable('state', 'a_count', lambda: 0) diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 5ed4ede7c7..e16346c93b 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -144,7 +144,6 @@ def test_causal_mask_1d(self): np.testing.assert_allclose(mask_1d, mask_1d_simple,) @parameterized.parameters([((5,), (1,)), ((6, 5), (2,))]) - @temp_flip_flag('return_frozendict', False) def test_decoding(self, spatial_shape, attn_dims): bs = 2 num_heads = 3 diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index e48122c558..dd14f9e216 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -1098,7 +1098,6 @@ def test(self): A().test() self.assertFalse(setup_called) - @temp_flip_flag('return_frozendict', False) def test_module_pass_as_attr(self): class A(nn.Module): @@ -1129,7 +1128,6 @@ def __call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_module_pass_in_closure(self): a = nn.Dense(2) @@ -1154,7 +1152,6 @@ def __call__(self, x): self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) self.assertIsNone(a.name) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption(self): class Encoder(nn.Module): @@ -1210,7 +1207,6 @@ def __call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption_pytree(self): class A(nn.Module): @@ -1254,7 +1250,6 @@ def __call__(self, c, x): lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7), counters, ref_counters))) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption_sharing(self): dense = functools.partial(nn.Dense, use_bias=False) @@ -1305,7 +1300,6 @@ def __call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_toplevel_named_submodule_adoption(self): dense = functools.partial(nn.Dense, use_bias=False) @@ -1360,7 +1354,6 @@ def __call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_pytree_adoption_sharing(self): class A(nn.Module): diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index ee1a04bde5..6fb7c97b3d 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -807,7 +807,6 @@ def __call__(self, x): y3 = Ctrafo(a2, b).apply(p2, x) np.testing.assert_allclose(y1, y3, atol=1e-7) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption_pytree_transform(self): class A(nn.Module): @nn.compact @@ -852,7 +851,6 @@ def __call__(self, c, x): cntrs, ref_cntrs) )) - @temp_flip_flag('return_frozendict', False) def test_partially_applied_module_constructor_transform(self): k = random.PRNGKey(0) x = jnp.ones((3,4,4)) @@ -870,7 +868,6 @@ def test_partially_applied_module_constructor_transform(self): } self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_partial_module_method(self): k = random.PRNGKey(0) x = jnp.ones((3,4,4)) @@ -1505,7 +1502,6 @@ def false_fn(mdl, x): return nn.cond(pred, true_fn, false_fn, self, x) - @temp_flip_flag('return_frozendict', False) def test_switch(self): class Foo(nn.Module): @nn.compact @@ -1540,7 +1536,6 @@ def c_fn(mdl, x): self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1}) np.testing.assert_allclose(y1, y3) - @temp_flip_flag('return_frozendict', False) def test_switch_multihead(self): class Foo(nn.Module): def setup(self) -> None: diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py index 9f7049f3a9..f20924ad22 100644 --- a/tests/linen/summary_test.py +++ b/tests/linen/summary_test.py @@ -181,7 +181,6 @@ def test_module_summary(self): row.counted_variables, ) - @temp_flip_flag('return_frozendict', False) def test_module_summary_with_depth(self): """ This test creates a Table using `module_summary` set the `depth` argument to `1`, @@ -240,7 +239,6 @@ def test_module_summary_with_depth(self): self.assertEqual(table[3].module_variables, table[3].counted_variables) - @temp_flip_flag('return_frozendict', False) def test_tabulate(self): """ This test creates a string representation of a Module using `Module.tabulate` @@ -323,7 +321,6 @@ def test_tabulate_with_method(self): self.assertIn("(block_method)", module_repr) self.assertIn("(cnn_method)", module_repr) - @temp_flip_flag('return_frozendict', False) def test_tabulate_function(self): """ This test creates a string representation of a Module using `Module.tabulate` @@ -370,7 +367,6 @@ def test_tabulate_function(self): self.assertIn("79.4 KB", lines[-3]) - @temp_flip_flag('return_frozendict', False) def test_lifted_transform(self): class LSTM(nn.Module): features: int @@ -406,7 +402,6 @@ def __call__(self, x): self.assertIn("ScanLSTM/ii", lines[13]) self.assertIn("Dense", lines[13]) - @temp_flip_flag('return_frozendict', False) def test_lifted_transform_no_rename(self): class LSTM(nn.Module): features: int @@ -442,7 +437,6 @@ def __call__(self, x): self.assertIn("ScanLSTMCell_0/ii", lines[13]) self.assertIn("Dense", lines[13]) - @temp_flip_flag('return_frozendict', False) def test_module_reuse(self): class ConvBlock(nn.Module): @nn.compact @@ -524,7 +518,6 @@ def __call__(self): self.assertIn('x: 3.141592', lines[7]) self.assertIn('4.141592', lines[7]) - @temp_flip_flag('return_frozendict', False) def test_partitioned_params(self): class Classifier(nn.Module):