
Commit e3772b2
Merge pull request #4223 from google:fix-transforms-guide
PiperOrigin-RevId: 678136334
Flax Authors committed Sep 24, 2024
2 parents b2277ab + 357f166 commit e3772b2
Showing 2 changed files with 44 additions and 44 deletions.
44 changes: 22 additions & 22 deletions docs_nnx/guides/transforms.ipynb
@@ -90,7 +90,7 @@
" def __init__(self, kernel: jax.Array, bias: jax.Array):\n",
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
")\n",
@@ -101,10 +101,10 @@
" assert x.ndim == 1, 'Batch dimensions not allowed'\n",
" return x @ weights.kernel + weights.bias\n",
"\n",
"y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(self, x)\n",
"y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
"\n",
"print(f'{y.shape = }')\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
@@ -158,8 +158,8 @@
" )\n",
"\n",
"seeds = jnp.arange(10)\n",
"self = nnx.vmap(create_weights)(seeds)\n",
"nnx.display(self)"
"weights = nnx.vmap(create_weights)(seeds)\n",
"nnx.display(weights)"
]
},
{
@@ -276,7 +276,7 @@
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
" self.count = Count(count)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
" count=jnp.arange(10),\n",
@@ -290,9 +290,9 @@
" return x @ weights.kernel + weights.bias\n",
"\n",
"\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(self, x)\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
"\n",
"self.count"
"weights.count"
]
},
{
@@ -353,7 +353,7 @@
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
" self.count = Count(count)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
" count=jnp.arange(10),\n",
@@ -370,9 +370,9 @@
" weights.new_param = weights.kernel # share reference\n",
" return y\n",
"\n",
"y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(self, x)\n",
"y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
"\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
@@ -433,7 +433,7 @@
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
" self.count = Count(count)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
" count=jnp.array(0),\n",
@@ -448,9 +448,9 @@
" return x @ weights.kernel + weights.bias\n",
"\n",
"state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(self, x)\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)\n",
"\n",
"self.count"
"weights.count"
]
},
{
@@ -517,7 +517,7 @@
" self.count = Count(count)\n",
" self.rngs = nnx.Rngs(noise=seed)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (2, 3)),\n",
" bias=jnp.zeros((3,)),\n",
" count=jnp.array(0),\n",
@@ -533,11 +533,11 @@
" return y + random.normal(weights.rngs.noise(), y.shape)\n",
"\n",
"state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})\n",
"y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)\n",
"y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)\n",
"y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n",
"y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n",
"\n",
"print(jnp.allclose(y1, y2))\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
@@ -589,7 +589,7 @@
}
],
"source": [
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (2, 3)),\n",
" bias=jnp.zeros((3,)),\n",
" count=jnp.array(0),\n",
@@ -608,11 +608,11 @@
" y = x @ weights.kernel + weights.bias\n",
" return y + random.normal(weights.rngs.noise(), y.shape)\n",
"\n",
"y1 = noisy_vector_dot(self, x)\n",
"y2 = noisy_vector_dot(self, x)\n",
"y1 = noisy_vector_dot(weights, x)\n",
"y2 = noisy_vector_dot(weights, x)\n",
"\n",
"print(jnp.allclose(y1, y2))\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
44 changes: 22 additions & 22 deletions docs_nnx/guides/transforms.md
@@ -44,7 +44,7 @@ class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
-self = Weights(
+weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
)
@@ -55,10 +55,10 @@ def vector_dot(weights: Weights, x: jax.Array):
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ weights.kernel + weights.bias
-y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(self, x)
+y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)
print(f'{y.shape = }')
-nnx.display(self)
+nnx.display(weights)
```

Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a pytree of arrays. Prefix patterns are also allowed; `in_axes=(0, 0)` would also have worked in this case.
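As a quick sketch of the prefix form (reusing the `vector_dot`, `weights`, and `x` defined in the cells above), the tuple spells out one axis spec per argument and is equivalent to the scalar `in_axes=0`:

```python
# Tuple prefix: one entry per argument, both mapped over axis 0.
y_prefix = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)
```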
@@ -75,8 +75,8 @@ def create_weights(seed: jax.Array):
)
seeds = jnp.arange(10)
-self = nnx.vmap(create_weights)(seeds)
-nnx.display(self)
+weights = nnx.vmap(create_weights)(seeds)
+nnx.display(weights)
```

## Transforming Methods
@@ -120,7 +120,7 @@ class Weights(nnx.Module):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
-self = Weights(
+weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
@@ -134,9 +134,9 @@ def stateful_vector_dot(weights: Weights, x: jax.Array):
return x @ weights.kernel + weights.bias
-y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(self, x)
+y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)
-self.count
+weights.count
```

After running `stateful_vector_dot` once, we verify that the `count` attribute was correctly updated. Because `Weights` is vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by 1 inside the transformation. Most importantly, the updates were propagated to the original `Weights` object outside the transformation. Nice!
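A minimal sketch of that check (assuming the vectorized `weights` from this example, and that `Count` is an `nnx.Variable` whose array is exposed via `.value`):

```python
# Every one of the 10 vectorized counters was incremented exactly once,
# and the update is visible on the original object outside the transform.
assert jnp.array_equal(weights.count.value, jnp.arange(10) + 1)
```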
@@ -156,7 +156,7 @@ class Weights(nnx.Module):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
-self = Weights(
+weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
@@ -173,9 +173,9 @@ def crazy_vector_dot(weights: Weights, x: jax.Array):
weights.new_param = weights.kernel # share reference
return y
-y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(self, x)
+y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)
-nnx.display(self)
+nnx.display(weights)
```

> With great power comes great responsibility.
@@ -207,7 +207,7 @@ class Weights(nnx.Module):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
-self = Weights(
+weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.array(0),
@@ -222,9 +222,9 @@ def stateful_vector_dot(weights: Weights, x: jax.Array):
return x @ weights.kernel + weights.bias
state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
-y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(self, x)
+y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)
-self.count
+weights.count
```

Here `count` is now a scalar since it is not being vectorized. Also, note that `StateAxes` can only be used directly on Flax NNX objects; it cannot be used as a prefix for a pytree of objects.
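As a sketch of the filter-to-axis mapping (reusing `weights`, `x`, and `stateful_vector_dot` from above), an `...` entry can act as a catch-all filter for any state the earlier entries didn't match:

```python
# Vectorize Params over axis 0 and broadcast all remaining state
# (here just Count), mirroring the explicit mapping above.
state_axes = nnx.StateAxes({nnx.Param: 0, ...: None})
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)
```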
@@ -243,7 +243,7 @@ class Weights(nnx.Module):
self.count = Count(count)
self.rngs = nnx.Rngs(noise=seed)
-self = Weights(
+weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
@@ -259,19 +259,19 @@ def noisy_vector_dot(weights: Weights, x: jax.Array):
return y + random.normal(weights.rngs.noise(), y.shape)
state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
-y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)
-y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)
+y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
+y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
print(jnp.allclose(y1, y2))
-nnx.display(self)
+nnx.display(weights)
```

Because the `Rngs` state is updated in place and automatically propagated by `nnx.vmap`, we get a different result every time `noisy_vector_dot` is called.
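A minimal sketch of the mechanism (assuming the single-device `weights` above): each call to a stream such as `weights.rngs.noise()` advances an internal counter in place, so successive keys differ:

```python
# Two successive draws from the same stream yield distinct keys;
# jax.random.key_data exposes the raw key arrays for comparison.
k1 = weights.rngs.noise()
k2 = weights.rngs.noise()
print(jnp.array_equal(random.key_data(k1), random.key_data(k2)))  # False
```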

In the example above, we manually split the random state during construction. This is fine since it makes the intention clear, but it also doesn't let us use `Rngs` outside of `vmap` because its state is always split. To solve this, we pass an unsplit seed and use the `nnx.split_rngs` decorator before `vmap` to split the `RngState` right before each call to the function and then "lower" it back so it's usable.

```{code-cell} ipython3
-self = Weights(
+weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
@@ -290,11 +290,11 @@ def noisy_vector_dot(weights: Weights, x: jax.Array):
y = x @ weights.kernel + weights.bias
return y + random.normal(weights.rngs.noise(), y.shape)
-y1 = noisy_vector_dot(self, x)
-y2 = noisy_vector_dot(self, x)
+y1 = noisy_vector_dot(weights, x)
+y2 = noisy_vector_dot(weights, x)
print(jnp.allclose(y1, y2))
-nnx.display(self)
+nnx.display(weights)
```

## Consistent aliasing
