diff --git a/docs_nnx/guides/index.rst b/docs_nnx/guides/index.rst index c47a1b9f46..0bb012983d 100644 --- a/docs_nnx/guides/index.rst +++ b/docs_nnx/guides/index.rst @@ -5,6 +5,7 @@ Guides :maxdepth: 2 filters_guide + randomness linen_to_nnx bridge_guide surgery diff --git a/docs_nnx/guides/randomness.ipynb b/docs_nnx/guides/randomness.ipynb new file mode 100644 index 0000000000..a14b4ea806 --- /dev/null +++ b/docs_nnx/guides/randomness.ipynb @@ -0,0 +1,555 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Randomness\n", + "\n", + "* intro: randomness as state\n", + "* Rngs, RngStream, and RngState\n", + " * Bullet point list explaining each\n", + " * Example of nnx.display to show the structure\n", + " * Generating keys\n", + " * stream names and default stream\n", + "* Filtering random state\n", + "* Reseeding\n", + "* Splitting Rngs\n", + "* Transforms" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Introduction" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import jax\n", + "from jax import random, numpy as jnp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rngs, RngStream, and RngState\n", + "\n", + "* **Rngs**:\n", + "* **RngStream**:\n", + "* **RngState**:\n", + " * **RngKey**:\n", + " * **RngCount**:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "rngs = nnx.Rngs(params=0, dropout=1)\n", + "nnx.display(rngs)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "params_key = rngs.params()\n", + "dropout_key = rngs.dropout()\n", + "\n", + "nnx.display(rngs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Default stream" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "rngs = nnx.Rngs(0, params=1)\n", + "\n", + "key1 = rngs.params() # call params\n", + "key2 = rngs.dropout() # fallback to default\n", + "key3 = rngs() # call default directly\n", + "\n", + "nnx.display(rngs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Filtering random state" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class Foo(nnx.Module):\n", + " def __init__(self, rngs: nnx.Rngs):\n", + " self.rngs = rngs # could be nested inside other Modules\n", + "\n", + "model = Foo(nnx.Rngs(params=0, dropout=1))\n", + "\n", + "rng_state = nnx.state(model, nnx.RngState) # all random state\n", + "key_state = nnx.state(model, nnx.RngKey) # only keys\n", + "count_state = nnx.state(model, nnx.RngCount) # only counts\n", + "rng_params_state = nnx.state(model, 'params') # only params\n", + "rng_dropout_state = nnx.state(model, 'dropout') # only dropout\n", + "params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys\n", + "\n", + "nnx.display(params_key_state)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reseeding" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class Model(nnx.Module):\n", + " def __init__(self, rngs: nnx.Rngs):\n", + " self.linear = nnx.Linear(20, 10, rngs=rngs)\n", + " self.drop = nnx.Dropout(0.1, rngs=rngs)\n", + "\n", + " def __call__(self, x):\n", + " return nnx.relu(self.drop(self.linear(x)))\n", + "\n", + "block = Model(nnx.Rngs(params=0, dropout=1))\n", + "x = jnp.ones((1, 20))\n", + "\n", + "y1 = block(x)\n", + "y2 = block(x)\n", + "\n", + "nnx.reseed(block, dropout=1) # reset dropout RngState\n", + "y3 = block(x)\n", + "\n", + "assert not jnp.allclose(y1, y2) # different\n", + "assert jnp.allclose(y1, y3) # same" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Splitting Rngs" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "split_key = random.split(random.key(0), 5)\n", + "rngs = nnx.Rngs(params=0, dropout=split_key)\n", + "nnx.display(rngs)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fold_in accepts a single key, but was given a key array of shape (5,) != (). Use jax.vmap for batching.\n" + ] + } + ], + "source": [ + "try:\n", + " rngs.dropout()\n", + "except ValueError as e:\n", + " print(e)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inside:\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "rngs = nnx.Rngs(params=0, dropout=1)\n", + "\n", + "@nnx.split_rngs(splits=5, only='dropout')\n", + "def f(rngs: nnx.Rngs):\n", + " print('Inside:')\n", + " nnx.display(rngs)\n", + " # rngs.dropout() # ValueError: fold_in accepts a single key...\n", + "\n", + "f(rngs)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Outside:\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print('Outside:')\n", + "nnx.display(rngs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Transforms\n", + "\n", + "### Data-parallel dropout" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class Model(nnx.Module):\n", + " def __init__(self, rngs: nnx.Rngs):\n", + " self.linear = nnx.Linear(20, 10, rngs=rngs)\n", + " self.drop = nnx.Dropout(0.1, rngs=rngs)\n", + "\n", + " def __call__(self, x) -> jax.Array:\n", + " return nnx.relu(self.drop(self.linear(x)))\n", + "\n", + "model = Model(nnx.Rngs(params=0, dropout=1))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 16, 10)\n" + ] + } + ], + "source": [ + "num_devices = jax.local_device_count()\n", + "x = jnp.ones((num_devices, 16, 20))\n", + "state_axes = nnx.StateAxes({'dropout': 0, ...: None})\n", + "\n", + "@nnx.split_rngs(splits=num_devices, only='dropout')\n", + "@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)\n", + "def forward(model: Model, x: jnp.ndarray):\n", + " return model(x)\n", + "\n", + "y = forward(model, x)\n", + "print(y.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Recurrent dropout" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "class Count(nnx.Variable): pass\n", + "\n", + "class RNNCell(nnx.Module):\n", + " def __init__(self, din, dout, rngs):\n", + " self.linear = nnx.Linear(dout + din, dout, rngs=rngs)\n", + " self.drop = nnx.Dropout(0.1, rngs=rngs)\n", + " self.dout = dout\n", + " self.count = Count(jnp.array(0, jnp.uint32))\n", + "\n", + " def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]:\n", + " carry = self.drop(carry) # recurrent dropout\n", + " x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1)))\n", + " self.count += 1\n", + " return x, x\n", + "\n", + " def initial_state(self, batch_size: int):\n", + " return jnp.zeros((batch_size, self.dout))\n", + "\n", + "cell = RNNCell(8, 16, nnx.Rngs(params=0, dropout=1))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (4, 20, 16)\n", + "cell.count.value = Array(20, dtype=uint32)\n" + ] + } + ], + "source": [ + "@nnx.jit\n", + "def rnn_forward(cell: RNNCell, x: jax.Array):\n", + " carry = cell.initial_state(batch_size=x.shape[0])\n", + "\n", + " # broadcast 'dropout' RNG state to have the same mask on every step\n", + " state_axes = nnx.StateAxes({'dropout': None, ...: nnx.Carry})\n", + " @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))\n", + " def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:\n", + " return cell(carry, x)\n", + "\n", + " _, y = unroll(cell, carry, x)\n", + " return y\n", + "\n", + "x = jnp.ones((4, 20, 8))\n", + "y = rnn_forward(cell, x)\n", + "\n", + "print(f'{y.shape = }')\n", + "print(f'{cell.count.value = }')" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs_nnx/guides/randomness.md b/docs_nnx/guides/randomness.md new file mode 100644 index 0000000000..37aa28016d --- /dev/null +++ b/docs_nnx/guides/randomness.md @@ -0,0 +1,214 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Randomness + +* intro: randomness as state +* Rngs, RngStream, and RngState + * Bullet point list explaining each + * Example of nnx.display to show the structure + * Generating keys + * stream names and default stream +* Filtering random state +* Reseeding +* Splitting Rngs +* Transforms + ++++ + +# Introduction + +```{code-cell} ipython3 +from flax import nnx +import jax +from jax import random, numpy as jnp +``` + +## Rngs, RngStream, and RngState + +* **Rngs**: +* **RngStream**: +* **RngState**: + * **RngKey**: + * **RngCount**: + +```{code-cell} ipython3 +rngs = nnx.Rngs(params=0, dropout=1) +nnx.display(rngs) +``` + +```{code-cell} ipython3 +params_key = rngs.params() +dropout_key = rngs.dropout() + +nnx.display(rngs) +``` + +### Default stream + +```{code-cell} ipython3 +rngs = nnx.Rngs(0, params=1) + +key1 = rngs.params() # call params +key2 = rngs.dropout() # fallback to default +key3 = rngs() # call default directly + +nnx.display(rngs) +``` + +## Filtering random state + +```{code-cell} ipython3 +class Foo(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.rngs = rngs # could be nested inside other Modules + +model = Foo(nnx.Rngs(params=0, dropout=1)) + +rng_state = nnx.state(model, nnx.RngState) # all random state +key_state = nnx.state(model, nnx.RngKey) # only keys +count_state = nnx.state(model, nnx.RngCount) # only counts +rng_params_state = nnx.state(model, 'params') # only params +rng_dropout_state = nnx.state(model, 'dropout') # only dropout +params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys + +nnx.display(params_key_state) +``` + +## Reseeding + +```{code-cell} ipython3 +class Model(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(20, 10, rngs=rngs) + self.drop = nnx.Dropout(0.1, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.drop(self.linear(x))) + +block = Model(nnx.Rngs(params=0, dropout=1)) +x = jnp.ones((1, 20)) + +y1 = block(x) +y2 = block(x) + +nnx.reseed(block, dropout=1) # reset dropout RngState +y3 = block(x) + +assert not jnp.allclose(y1, y2) # different +assert jnp.allclose(y1, y3) # same +``` + +## Splitting Rngs + +```{code-cell} ipython3 +split_key = random.split(random.key(0), 5) +rngs = nnx.Rngs(params=0, dropout=split_key) +nnx.display(rngs) +``` + +```{code-cell} ipython3 +try: + rngs.dropout() +except ValueError as e: + print(e) +``` + +```{code-cell} ipython3 +rngs = nnx.Rngs(params=0, dropout=1) + +@nnx.split_rngs(splits=5, only='dropout') +def f(rngs: nnx.Rngs): + print('Inside:') + nnx.display(rngs) + # rngs.dropout() # ValueError: fold_in accepts a single key... + +f(rngs) +``` + +```{code-cell} ipython3 +print('Outside:') +nnx.display(rngs) +``` + +## Transforms + +### Data-parallel dropout + +```{code-cell} ipython3 +class Model(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(20, 10, rngs=rngs) + self.drop = nnx.Dropout(0.1, rngs=rngs) + + def __call__(self, x) -> jax.Array: + return nnx.relu(self.drop(self.linear(x))) + +model = Model(nnx.Rngs(params=0, dropout=1)) +``` + +```{code-cell} ipython3 +num_devices = jax.local_device_count() +x = jnp.ones((num_devices, 16, 20)) +state_axes = nnx.StateAxes({'dropout': 0, ...: None}) + +@nnx.split_rngs(splits=num_devices, only='dropout') +@nnx.pmap(in_axes=(state_axes, 0), out_axes=0) +def forward(model: Model, x: jnp.ndarray): + return model(x) + +y = forward(model, x) +print(y.shape) +``` + +### Recurrent dropout + +```{code-cell} ipython3 +class Count(nnx.Variable): pass + +class RNNCell(nnx.Module): + def __init__(self, din, dout, rngs): + self.linear = nnx.Linear(dout + din, dout, rngs=rngs) + self.drop = nnx.Dropout(0.1, rngs=rngs) + self.dout = dout + self.count = Count(jnp.array(0, jnp.uint32)) + + def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]: + carry = self.drop(carry) # recurrent dropout + x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1))) + self.count += 1 + return x, x + + def initial_state(self, batch_size: int): + return jnp.zeros((batch_size, self.dout)) + +cell = RNNCell(8, 16, nnx.Rngs(params=0, dropout=1)) +``` + +```{code-cell} ipython3 +@nnx.jit +def rnn_forward(cell: RNNCell, x: jax.Array): + carry = cell.initial_state(batch_size=x.shape[0]) + + # broadcast 'dropout' RNG state to have the same mask on every step + state_axes = nnx.StateAxes({'dropout': None, ...: nnx.Carry}) + @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1)) + def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]: + return cell(carry, x) + + _, y = unroll(cell, carry, x) + return y + +x = jnp.ones((4, 20, 8)) +y = rnn_forward(cell, x) + +print(f'{y.shape = }') +print(f'{cell.count.value = }') +``` diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index c169a91fa1..4b3a40b842 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -901,9 +901,6 @@ def _scan_merge_out( broadcast_states = deque[State]() if isinstance(prefix, StateAxes): vectorized_states = deque(x.states) - assert len(prefix.axes) == len(vectorized_states) + len( - carry_states - ) + len(broadcast_states) for axis in prefix.axes: if isinstance(axis, int): state = vectorized_states.popleft() @@ -913,7 +910,10 @@ def _scan_merge_out( states.append(broadcast_states.popleft()) else: # axis is Carry states.append(carry_states.popleft()) - assert not vectorized_states and not carry_states and not broadcast_states + assert not carry_states and not broadcast_states + assert not vectorized_states or ( + len(vectorized_states) == 1 and not vectorized_states[0] + ) elif isinstance(prefix, int): state = jax.tree.map(lambda x: jnp.moveaxis(x, 0, prefix), x.state) states.extend((state, *carry_states, *broadcast_states)) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 824e7b6b0e..b91854dcaa 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -1614,6 +1614,41 @@ def f(_, rngs: nnx.Rngs): assert jnp.equal(dropout_keys[0], dropout_keys[1]) assert jnp.equal(dropout_keys[1], dropout_keys[2]) + def test_rnn_example(self): + class RNNCell(nnx.Module): + def __init__(self, input_size, hidden_size, rngs): + self.linear = nnx.Linear( + hidden_size + input_size, hidden_size, rngs=rngs + ) + self.drop = nnx.Dropout(0.1, rngs=rngs) + self.hidden_size = hidden_size + + def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]: + carry = self.drop(carry) # recurrent dropout + x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1))) + return x, x + + def initial_state(self, batch_size: int): + return jnp.zeros((batch_size, self.hidden_size)) + + cell = RNNCell(20, 20, nnx.Rngs(params=0, dropout=1)) + + state_axes = nnx.StateAxes({'dropout': None, ...: nnx.Carry}) + + def rnn_forward(cell: RNNCell, x: jax.Array): + carry = cell.initial_state(x.shape[0]) + + @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1)) + def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]: + return cell(carry, x) + + _, y = unroll(cell, carry, x) + return y + + x = jnp.ones((16, 10, 20)) + y = rnn_forward(cell, x) + print(y.shape) + class TestRemat(absltest.TestCase): def test_basic_remat(self):