Commit: showing 7 changed files with 445 additions and 93 deletions.
docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_linen.py (37 additions, 0 deletions)
import flax.linen as nn
import jax
import jax.numpy as jnp


class Linear(nn.Module):
    din: int
    dout: int

    def setup(self):
        din, dout = self.din, self.dout
        # Only draw an RNG key while initializing; on later applies the
        # variables already exist and the init fns are not called.
        key = self.make_rng("params") if self.is_initializing() else None
        self.w = self.variable("params", "w", jax.random.uniform, key, (din, dout))
        self.b = self.variable("params", "b", jnp.zeros, (dout,))
        self.count = self.variable("counts", "count", lambda: 0)

    def __call__(self, x) -> jax.Array:
        if not self.is_initializing():
            self.count.value += 1  # track the number of calls
        return x @ self.w.value + self.b.value


model = Linear(din=5, dout=2)
x = jnp.ones((1, 5))
variables = model.init(jax.random.PRNGKey(0), x)
params, counts = variables["params"], variables["counts"]
# "counts" is mutated by __call__, so it must be declared mutable and the
# updated collection threaded back in from the returned updates.
y, updates = model.apply(
    {"params": params, "counts": counts}, x, mutable=["counts"]
)
counts = updates["counts"]

print("\n Linen")
bounded_model = model.bind({"params": params, "counts": counts})
print(f"{bounded_model.count.value = }")
print(f"{bounded_model.w.value = }")
print(f"{bounded_model.b.value = }")
print(f"{bounded_model = }")
docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_nnx.py (33 additions)
from flax.experimental import nnx
import jax
import jax.numpy as jnp


class Count(nnx.Variable):
    pass


class Linear(nnx.Module):

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        self.din = din
        self.dout = dout
        key = ctx.make_rng("params")
        self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))
        self.count = Count(0)  # track the number of calls

    def __call__(self, x) -> jax.Array:
        self.count += 1  # state lives on the module and is mutated in place
        return x @ self.w + self.b


model = Linear(din=5, dout=2, ctx=nnx.context(0))
x = jnp.ones((1, 5))
y = model(x)

print("\n NNX")
print(f"{model.count = }")
print(f"{model.w = }")
print(f"{model.b = }")
print(f"{model = }")
docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_linen.py (55 additions, 0 deletions)
from ex1_linen import Linear, nn, jnp, jax

# ----------------
# example begin
# ----------------
import numpy as np

# toy regression data: y = 0.8 x + 0.4 plus gaussian noise
X = np.random.uniform(size=(1000, 1))
Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1))

model = Linear(din=1, dout=1)
x = jnp.ones((1, 1))
variables = model.init(jax.random.PRNGKey(0), x)
params, counts = variables["params"], variables["counts"]


@jax.jit
def train_step(params, counts, x, y):
    def loss_fn(params):
        y_pred, updates = model.apply(
            {"params": params, "counts": counts}, x, mutable=["counts"]
        )
        loss = jnp.mean((y_pred - y) ** 2)
        return loss, updates["counts"]

    # with has_aux=True, jax.grad returns (grads, aux); aux is the updated counts
    grads, counts = jax.grad(loss_fn, has_aux=True)(params)
    # plain SGD step
    params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, counts


@jax.jit
def eval_step(params, counts, x, y):
    y_pred, updates = model.apply(
        {"params": params, "counts": counts}, x, mutable=["counts"]
    )
    loss = jnp.mean((y_pred - y) ** 2)
    return loss, updates["counts"]


for step in range(501):
    idx = np.random.randint(0, 1000, size=(32,))
    x, y = X[idx], Y[idx]

    params, counts = train_step(params, counts, x, y)

    if step % 100 == 0:
        loss, counts = eval_step(params, counts, X, Y)
        print(f"Step {step}: loss={loss:.4f}")


bounded_model = model.bind({"params": params, "counts": counts})
print(f"\n{bounded_model.w.value = }")
print(f"{bounded_model.b.value = }")
print(f"{bounded_model.count.value = }")
docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_nnx.py (45 additions)
from ex1_nnx import Linear, nnx, jnp, jax

# ----------------
# example begin
# ----------------
import numpy as np

# same toy regression data as the Linen version
X = np.random.uniform(size=(1000, 1))
Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1))

model = Linear(1, 1, ctx=nnx.context(0))


@nnx.jit
def train_step(model: Linear, x, y):
    def loss_fn(model: Linear):
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)

    # differentiate only w.r.t. the module's Param state
    grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
    params = model.filter(nnx.Param)
    params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    model.update_state(params)  # write the new params back into the module


@nnx.jit
def eval_step(model: Linear, x, y):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)


for step in range(501):
    idx = np.random.randint(0, 1000, size=(32,))
    x, y = X[idx], Y[idx]

    train_step(model, x, y)

    if step % 100 == 0:
        loss = eval_step(model, X, Y)
        print(f"Step {step}: loss={loss:.4f}")

print(f"\n{model.w = }")
print(f"{model.b = }")
print(f"{model.count = }")
docs/experimental/nnx/[REMOVE]why_nnx_examples/… (63 additions)
from functools import partial

import jax

from flax.experimental import nnx


class Block(nnx.Module):

    def __init__(self, dim: int, *, ctx: nnx.Context):
        self.linear = nnx.Linear(dim, dim, ctx=ctx)
        self.bn = nnx.BatchNorm(dim, ctx=ctx)
        self.dropout = nnx.Dropout(0.5)

    def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array:
        x = self.linear(x)
        x = self.bn(x, ctx=ctx)
        x = self.dropout(x, ctx=ctx)
        x = jax.nn.gelu(x)
        return x


class ScanMLP(nnx.Module):

    def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context):
        self.n_layers = n_layers
        # partition the Context and split the `params` key, one per layer
        keys, ctxdef = ctx.partition()
        params_key = jax.random.split(keys["params"], n_layers)

        @partial(jax.vmap, out_axes=((0, None), None))
        def create_block(params_key):
            # merge back a Context using this layer's slice of the `params` key
            ctx = ctxdef.merge({"params": params_key})
            # create a Block instance and return its partitions
            return Block(dim, ctx=ctx).partition(nnx.Param, nnx.BatchStat)

        # vmap over create_block with the split `params` keys, so the Param
        # state comes out with a leading layer axis
        (params, batch_stats), moduledef = create_block(params_key)
        # merge to get a lifted Block instance
        self.layers = moduledef.merge(params, batch_stats)

    def __call__(self, x: jax.Array, *, ctx: nnx.Context):
        keys, ctxdef = ctx.partition()
        dropout_key = jax.random.split(keys["dropout"], self.n_layers)
        (params, batch_stats), moduledef = self.layers.partition(
            nnx.Param, nnx.BatchStat
        )

        def scan_fn(
            carry: tuple[jax.Array, nnx.State], inputs: tuple[nnx.State, jax.Array]
        ):
            (x, batch_stats), (params, dropout_key) = carry, inputs
            # rebuild one layer from this step's slice of the Param state
            module = moduledef.merge(params, batch_stats)
            x = module(x, ctx=ctxdef.merge({"dropout": dropout_key}))
            params, _ = module.partition(nnx.Param)
            return (x, batch_stats), params

        (x, batch_stats), params = jax.lax.scan(
            scan_fn, (x, batch_stats), (params, dropout_key)
        )
        # write the scanned state back into the lifted module
        self.layers.update_state((params, batch_stats))
        return x
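The commit does not show ScanMLP being driven. A hypothetical usage sketch follows; it assumes nnx.context accepts per-stream integer seeds (ScanMLP.__init__ consumes a "params" key and __call__ a "dropout" key), and depending on the nnx version Dropout may also require a deterministic flag, so treat this as illustrative rather than confirmed API.

    import jax.numpy as jnp

    # assumed: nnx.context(0) seeds the default "params" stream, as in ex1_nnx.py
    model = ScanMLP(10, n_layers=5, ctx=nnx.context(0))
    x = jnp.ones((3, 10))
    # assumed: a keyword seed creates the "dropout" stream that __call__ splits
    y = model(x, ctx=nnx.context(dropout=1))
    print(y.shape)  # (3, 10)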