diff --git a/docs_nnx/guides/index.rst b/docs_nnx/guides/index.rst
index c47a1b9f46..0bb012983d 100644
--- a/docs_nnx/guides/index.rst
+++ b/docs_nnx/guides/index.rst
@@ -5,6 +5,7 @@ Guides
:maxdepth: 2
filters_guide
+ randomness
linen_to_nnx
bridge_guide
surgery
diff --git a/docs_nnx/guides/randomness.ipynb b/docs_nnx/guides/randomness.ipynb
new file mode 100644
index 0000000000..a14b4ea806
--- /dev/null
+++ b/docs_nnx/guides/randomness.ipynb
@@ -0,0 +1,555 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Randomness\n",
+ "\n",
+ "* intro: randomness as state\n",
+ "* Rngs, RngStream, and RngState\n",
+ " * Bullet point list explaining each\n",
+ " * Example of nnx.display to show the structure\n",
+ " * Generating keys\n",
+ " * stream names and default stream\n",
+ "* Filtering random state\n",
+ "* Reseeding\n",
+ "* Splitting Rngs\n",
+ "* Transforms"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Introduction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from flax import nnx\n",
+ "import jax\n",
+ "from jax import random, numpy as jnp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Rngs, RngStream, and RngState\n",
+ "\n",
+ "* **Rngs**:\n",
+ "* **RngStream**:\n",
+ "* **RngState**:\n",
+ " * **RngKey**:\n",
+ " * **RngCount**:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "rngs = nnx.Rngs(params=0, dropout=1)\n",
+ "nnx.display(rngs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "params_key = rngs.params()\n",
+ "dropout_key = rngs.dropout()\n",
+ "\n",
+ "nnx.display(rngs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Default stream"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "rngs = nnx.Rngs(0, params=1)\n",
+ "\n",
+ "key1 = rngs.params() # call params\n",
+ "key2 = rngs.dropout() # fallback to default\n",
+ "key3 = rngs() # call default directly\n",
+ "\n",
+ "nnx.display(rngs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Filtering random state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "class Foo(nnx.Module):\n",
+ " def __init__(self, rngs: nnx.Rngs):\n",
+ " self.rngs = rngs # could be nested inside other Modules\n",
+ "\n",
+ "model = Foo(nnx.Rngs(params=0, dropout=1))\n",
+ "\n",
+ "rng_state = nnx.state(model, nnx.RngState) # all random state\n",
+ "key_state = nnx.state(model, nnx.RngKey) # only keys\n",
+ "count_state = nnx.state(model, nnx.RngCount) # only counts\n",
+ "rng_params_state = nnx.state(model, 'params') # only params\n",
+ "rng_dropout_state = nnx.state(model, 'dropout') # only dropout\n",
+ "params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys\n",
+ "\n",
+ "nnx.display(params_key_state)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Reseeding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Model(nnx.Module):\n",
+ " def __init__(self, rngs: nnx.Rngs):\n",
+ " self.linear = nnx.Linear(20, 10, rngs=rngs)\n",
+ " self.drop = nnx.Dropout(0.1, rngs=rngs)\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " return nnx.relu(self.drop(self.linear(x)))\n",
+ "\n",
+ "block = Model(nnx.Rngs(params=0, dropout=1))\n",
+ "x = jnp.ones((1, 20))\n",
+ "\n",
+ "y1 = block(x)\n",
+ "y2 = block(x)\n",
+ "\n",
+ "nnx.reseed(block, dropout=1) # reset dropout RngState\n",
+ "y3 = block(x)\n",
+ "\n",
+ "assert not jnp.allclose(y1, y2) # different\n",
+ "assert jnp.allclose(y1, y3) # same"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Splitting Rngs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "split_key = random.split(random.key(0), 5)\n",
+ "rngs = nnx.Rngs(params=0, dropout=split_key)\n",
+ "nnx.display(rngs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "fold_in accepts a single key, but was given a key array of shape (5,) != (). Use jax.vmap for batching.\n"
+ ]
+ }
+ ],
+ "source": [
+ "try:\n",
+ " rngs.dropout()\n",
+ "except ValueError as e:\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Inside:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "rngs = nnx.Rngs(params=0, dropout=1)\n",
+ "\n",
+ "@nnx.split_rngs(splits=5, only='dropout')\n",
+ "def f(rngs: nnx.Rngs):\n",
+ " print('Inside:')\n",
+ " nnx.display(rngs)\n",
+ " # rngs.dropout() # ValueError: fold_in accepts a single key...\n",
+ "\n",
+ "f(rngs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Outside:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "print('Outside:')\n",
+ "nnx.display(rngs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Transforms\n",
+ "\n",
+ "### Data-parallel dropout"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Model(nnx.Module):\n",
+ " def __init__(self, rngs: nnx.Rngs):\n",
+ " self.linear = nnx.Linear(20, 10, rngs=rngs)\n",
+ " self.drop = nnx.Dropout(0.1, rngs=rngs)\n",
+ "\n",
+ " def __call__(self, x) -> jax.Array:\n",
+ " return nnx.relu(self.drop(self.linear(x)))\n",
+ "\n",
+ "model = Model(nnx.Rngs(params=0, dropout=1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1, 16, 10)\n"
+ ]
+ }
+ ],
+ "source": [
+ "num_devices = jax.local_device_count()\n",
+ "x = jnp.ones((num_devices, 16, 20))\n",
+ "state_axes = nnx.StateAxes({'dropout': 0, ...: None})\n",
+ "\n",
+ "@nnx.split_rngs(splits=num_devices, only='dropout')\n",
+ "@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)\n",
+ "def forward(model: Model, x: jnp.ndarray):\n",
+ " return model(x)\n",
+ "\n",
+ "y = forward(model, x)\n",
+ "print(y.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Recurrent dropout"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Count(nnx.Variable): pass\n",
+ "\n",
+ "class RNNCell(nnx.Module):\n",
+ " def __init__(self, din, dout, rngs):\n",
+ " self.linear = nnx.Linear(dout + din, dout, rngs=rngs)\n",
+ " self.drop = nnx.Dropout(0.1, rngs=rngs)\n",
+ " self.dout = dout\n",
+ " self.count = Count(jnp.array(0, jnp.uint32))\n",
+ "\n",
+ " def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]:\n",
+ " carry = self.drop(carry) # recurrent dropout\n",
+ " x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1)))\n",
+ " self.count += 1\n",
+ " return x, x\n",
+ "\n",
+ " def initial_state(self, batch_size: int):\n",
+ " return jnp.zeros((batch_size, self.dout))\n",
+ "\n",
+ "cell = RNNCell(8, 16, nnx.Rngs(params=0, dropout=1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "y.shape = (4, 20, 16)\n",
+ "cell.count.value = Array(20, dtype=uint32)\n"
+ ]
+ }
+ ],
+ "source": [
+ "@nnx.jit\n",
+ "def rnn_forward(cell: RNNCell, x: jax.Array):\n",
+ " carry = cell.initial_state(batch_size=x.shape[0])\n",
+ "\n",
+ " # broadcast 'dropout' RNG state to have the same mask on every step\n",
+ " state_axes = nnx.StateAxes({'dropout': None, ...: nnx.Carry})\n",
+ " @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))\n",
+ " def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:\n",
+ " return cell(carry, x)\n",
+ "\n",
+ " _, y = unroll(cell, carry, x)\n",
+ " return y\n",
+ "\n",
+ "x = jnp.ones((4, 20, 8))\n",
+ "y = rnn_forward(cell, x)\n",
+ "\n",
+ "print(f'{y.shape = }')\n",
+ "print(f'{cell.count.value = }')"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "formats": "ipynb,md:myst"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs_nnx/guides/randomness.md b/docs_nnx/guides/randomness.md
new file mode 100644
index 0000000000..37aa28016d
--- /dev/null
+++ b/docs_nnx/guides/randomness.md
@@ -0,0 +1,214 @@
+---
+jupytext:
+ formats: ipynb,md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.13.8
+---
+
+# Randomness
+
+* intro: randomness as state
+* Rngs, RngStream, and RngState
+ * Bullet point list explaining each
+ * Example of nnx.display to show the structure
+ * Generating keys
+ * stream names and default stream
+* Filtering random state
+* Reseeding
+* Splitting Rngs
+* Transforms
+
++++
+
+# Introduction
+
+```{code-cell} ipython3
+from flax import nnx
+import jax
+from jax import random, numpy as jnp
+```
+
+## Rngs, RngStream, and RngState
+
+* **Rngs**:
+* **RngStream**:
+* **RngState**:
+ * **RngKey**:
+ * **RngCount**:
+
+```{code-cell} ipython3
+rngs = nnx.Rngs(params=0, dropout=1)
+nnx.display(rngs)
+```
+
+```{code-cell} ipython3
+params_key = rngs.params()
+dropout_key = rngs.dropout()
+
+nnx.display(rngs)
+```
+
+### Default stream
+
+```{code-cell} ipython3
+rngs = nnx.Rngs(0, params=1)
+
+key1 = rngs.params() # call params
+key2 = rngs.dropout() # fallback to default
+key3 = rngs() # call default directly
+
+nnx.display(rngs)
+```
+
+## Filtering random state
+
+```{code-cell} ipython3
+class Foo(nnx.Module):
+ def __init__(self, rngs: nnx.Rngs):
+ self.rngs = rngs # could be nested inside other Modules
+
+model = Foo(nnx.Rngs(params=0, dropout=1))
+
+rng_state = nnx.state(model, nnx.RngState) # all random state
+key_state = nnx.state(model, nnx.RngKey) # only keys
+count_state = nnx.state(model, nnx.RngCount) # only counts
+rng_params_state = nnx.state(model, 'params') # only params
+rng_dropout_state = nnx.state(model, 'dropout') # only dropout
+params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys
+
+nnx.display(params_key_state)
+```
+
+## Reseeding
+
+```{code-cell} ipython3
+class Model(nnx.Module):
+ def __init__(self, rngs: nnx.Rngs):
+ self.linear = nnx.Linear(20, 10, rngs=rngs)
+ self.drop = nnx.Dropout(0.1, rngs=rngs)
+
+ def __call__(self, x):
+ return nnx.relu(self.drop(self.linear(x)))
+
+block = Model(nnx.Rngs(params=0, dropout=1))
+x = jnp.ones((1, 20))
+
+y1 = block(x)
+y2 = block(x)
+
+nnx.reseed(block, dropout=1) # reset dropout RngState
+y3 = block(x)
+
+assert not jnp.allclose(y1, y2) # different
+assert jnp.allclose(y1, y3) # same
+```
+
+## Splitting Rngs
+
+```{code-cell} ipython3
+split_key = random.split(random.key(0), 5)
+rngs = nnx.Rngs(params=0, dropout=split_key)
+nnx.display(rngs)
+```
+
+```{code-cell} ipython3
+try:
+ rngs.dropout()
+except ValueError as e:
+ print(e)
+```
+
+```{code-cell} ipython3
+rngs = nnx.Rngs(params=0, dropout=1)
+
+@nnx.split_rngs(splits=5, only='dropout')
+def f(rngs: nnx.Rngs):
+ print('Inside:')
+ nnx.display(rngs)
+ # rngs.dropout() # ValueError: fold_in accepts a single key...
+
+f(rngs)
+```
+
+```{code-cell} ipython3
+print('Outside:')
+nnx.display(rngs)
+```
+
+## Transforms
+
+### Data-parallel dropout
+
+```{code-cell} ipython3
+class Model(nnx.Module):
+ def __init__(self, rngs: nnx.Rngs):
+ self.linear = nnx.Linear(20, 10, rngs=rngs)
+ self.drop = nnx.Dropout(0.1, rngs=rngs)
+
+ def __call__(self, x) -> jax.Array:
+ return nnx.relu(self.drop(self.linear(x)))
+
+model = Model(nnx.Rngs(params=0, dropout=1))
+```
+
+```{code-cell} ipython3
+num_devices = jax.local_device_count()
+x = jnp.ones((num_devices, 16, 20))
+state_axes = nnx.StateAxes({'dropout': 0, ...: None})
+
+@nnx.split_rngs(splits=num_devices, only='dropout')
+@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
+def forward(model: Model, x: jnp.ndarray):
+ return model(x)
+
+y = forward(model, x)
+print(y.shape)
+```
+
+### Recurrent dropout
+
+```{code-cell} ipython3
+class Count(nnx.Variable): pass
+
+class RNNCell(nnx.Module):
+ def __init__(self, din, dout, rngs):
+ self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
+ self.drop = nnx.Dropout(0.1, rngs=rngs)
+ self.dout = dout
+ self.count = Count(jnp.array(0, jnp.uint32))
+
+ def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]:
+ carry = self.drop(carry) # recurrent dropout
+ x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1)))
+ self.count += 1
+ return x, x
+
+ def initial_state(self, batch_size: int):
+ return jnp.zeros((batch_size, self.dout))
+
+cell = RNNCell(8, 16, nnx.Rngs(params=0, dropout=1))
+```
+
+```{code-cell} ipython3
+@nnx.jit
+def rnn_forward(cell: RNNCell, x: jax.Array):
+ carry = cell.initial_state(batch_size=x.shape[0])
+
+ # broadcast 'dropout' RNG state to have the same mask on every step
+ state_axes = nnx.StateAxes({'dropout': None, ...: nnx.Carry})
+ @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
+ def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:
+ return cell(carry, x)
+
+ _, y = unroll(cell, carry, x)
+ return y
+
+x = jnp.ones((4, 20, 8))
+y = rnn_forward(cell, x)
+
+print(f'{y.shape = }')
+print(f'{cell.count.value = }')
+```
diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py
index c169a91fa1..4b3a40b842 100644
--- a/flax/nnx/transforms/iteration.py
+++ b/flax/nnx/transforms/iteration.py
@@ -901,9 +901,6 @@ def _scan_merge_out(
broadcast_states = deque[State]()
if isinstance(prefix, StateAxes):
vectorized_states = deque(x.states)
- assert len(prefix.axes) == len(vectorized_states) + len(
- carry_states
- ) + len(broadcast_states)
for axis in prefix.axes:
if isinstance(axis, int):
state = vectorized_states.popleft()
@@ -913,7 +910,10 @@ def _scan_merge_out(
states.append(broadcast_states.popleft())
else: # axis is Carry
states.append(carry_states.popleft())
- assert not vectorized_states and not carry_states and not broadcast_states
+ assert not carry_states and not broadcast_states
+ assert not vectorized_states or (
+ len(vectorized_states) == 1 and not vectorized_states[0]
+ )
elif isinstance(prefix, int):
state = jax.tree.map(lambda x: jnp.moveaxis(x, 0, prefix), x.state)
states.extend((state, *carry_states, *broadcast_states))
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 824e7b6b0e..b91854dcaa 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -1614,6 +1614,41 @@ def f(_, rngs: nnx.Rngs):
assert jnp.equal(dropout_keys[0], dropout_keys[1])
assert jnp.equal(dropout_keys[1], dropout_keys[2])
+ def test_rnn_example(self):
+ class RNNCell(nnx.Module):
+ def __init__(self, input_size, hidden_size, rngs):
+ self.linear = nnx.Linear(
+ hidden_size + input_size, hidden_size, rngs=rngs
+ )
+ self.drop = nnx.Dropout(0.1, rngs=rngs)
+ self.hidden_size = hidden_size
+
+ def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]:
+ carry = self.drop(carry) # recurrent dropout
+ x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1)))
+ return x, x
+
+ def initial_state(self, batch_size: int):
+ return jnp.zeros((batch_size, self.hidden_size))
+
+ cell = RNNCell(20, 20, nnx.Rngs(params=0, dropout=1))
+
+ state_axes = nnx.StateAxes({'dropout': None, ...: nnx.Carry})
+
+ def rnn_forward(cell: RNNCell, x: jax.Array):
+ carry = cell.initial_state(x.shape[0])
+
+ @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
+ def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:
+ return cell(carry, x)
+
+ _, y = unroll(cell, carry, x)
+ return y
+
+ x = jnp.ones((16, 10, 20))
+ y = rnn_forward(cell, x)
+ print(y.shape)
+
class TestRemat(absltest.TestCase):
def test_basic_remat(self):