add lifted transform example
cgarciae committed Aug 11, 2023
1 parent 44f5e72 commit 1d63c16
Showing 7 changed files with 445 additions and 93 deletions.
37 changes: 37 additions & 0 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_linen.py
@@ -0,0 +1,37 @@
import flax.linen as nn
import jax
import jax.numpy as jnp


class Linear(nn.Module):
  din: int
  dout: int

  def setup(self):
    din, dout = self.din, self.dout
    # an RNG is only available (and needed) while initializing; afterwards the
    # variables already exist and their init functions are not called again
    key = self.make_rng("params") if self.is_initializing() else None
    self.w = self.variable("params", "w", jax.random.uniform, key, (din, dout))
    self.b = self.variable("params", "b", jnp.zeros, (dout,))
    # non-differentiable state lives in its own "counts" collection
    self.count = self.variable("counts", "count", lambda: 0)

  def __call__(self, x) -> jax.Array:
    if not self.is_initializing():
      self.count.value += 1
    return x @ self.w.value + self.b.value


model = Linear(din=5, dout=2)
x = jnp.ones((1, 5))
variables = model.init(jax.random.PRNGKey(0), x)
params, counts = variables["params"], variables["counts"]
y, updates = model.apply(
    {"params": params, "counts": counts}, x, mutable=["counts"]
)
counts = updates["counts"]

print("\n Linen")
bounded_model = model.bind({"params": params, "counts": counts})
print(f"{bounded_model.count.value = }")
print(f"{bounded_model.w.value = }")
print(f"{bounded_model.b.value = }")
print(f"{bounded_model = }")
33 changes: 33 additions & 0 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_nnx.py
@@ -0,0 +1,33 @@
from flax.experimental import nnx
import jax
import jax.numpy as jnp


class Count(nnx.Variable):
  pass


class Linear(nnx.Module):

  def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
    self.din = din
    self.dout = dout
    key = ctx.make_rng("params")  # RNGs are pulled explicitly from the Context
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(0)  # track the number of calls

  def __call__(self, x) -> jax.Array:
    self.count += 1  # plain attribute mutation, no `mutable` plumbing needed
    return x @ self.w + self.b


model = Linear(din=5, dout=2, ctx=nnx.context(0))
x = jnp.ones((1, 5))
y = model(x)

print("\n NNX")
print(f"{model.count = }")
print(f"{model.w = }")
print(f"{model.b = }")
print(f"{model = }")
55 changes: 55 additions & 0 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_linen.py
@@ -0,0 +1,55 @@
from ex1_linen import Linear, nn, jnp, jax

# ----------------
# example begin
# ----------------
import numpy as np

X = np.random.uniform(size=(1000, 1))
Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1))

model = Linear(din=1, dout=1)
x = jnp.ones((1, 1))
variables = model.init(jax.random.PRNGKey(0), x)
params, counts = variables["params"], variables["counts"]


@jax.jit
def train_step(params, counts, x, y):
  def loss_fn(params):
    y_pred, updates = model.apply(
        {"params": params, "counts": counts}, x, mutable=["counts"]
    )
    loss = jnp.mean((y_pred - y) ** 2)
    return loss, updates["counts"]

  # has_aux=True returns the updated counts alongside the gradients
  grads, counts = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params, counts


@jax.jit
def eval_step(params, counts, x, y):
  y_pred, updates = model.apply(
      {"params": params, "counts": counts}, x, mutable=["counts"]
  )
  loss = jnp.mean((y_pred - y) ** 2)
  return loss, updates["counts"]


for step in range(501):
  idx = np.random.randint(0, 1000, size=(32,))
  x, y = X[idx], Y[idx]

  params, counts = train_step(params, counts, x, y)

  if step % 100 == 0:
    loss, counts = eval_step(params, counts, X, Y)
    print(f"Step {step}: loss={loss:.4f}")


bounded_model = model.bind({"params": params, "counts": counts})
print(f"\n{bounded_model.w.value = }")
print(f"{bounded_model.b.value = }")
print(f"{bounded_model.count.value = }")
45 changes: 45 additions & 0 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_nnx.py
@@ -0,0 +1,45 @@
from ex1_nnx import Linear, nnx, jnp, jax

# ----------------
# example begin
# ----------------
import numpy as np

X = np.random.uniform(size=(1000, 1))
Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1))

model = Linear(1, 1, ctx=nnx.context(0))


@nnx.jit
def train_step(model: Linear, x, y):
  def loss_fn(model: Linear):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  # differentiate with respect to the Param state only
  grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
  params = model.filter(nnx.Param)
  params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  model.update_state(params)  # write the updated params back into the module


@nnx.jit
def eval_step(model: Linear, x, y):
  y_pred = model(x)
  return jnp.mean((y_pred - y) ** 2)


for step in range(501):
  idx = np.random.randint(0, 1000, size=(32,))
  x, y = X[idx], Y[idx]

  train_step(model, x, y)

  if step % 100 == 0:
    loss = eval_step(model, X, Y)
    print(f"Step {step}: loss={loss:.4f}")

print(f"\n{model.w = }")
print(f"{model.b = }")
print(f"{model.count = }")
63 changes: 63 additions & 0 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex3_nnx.py
@@ -0,0 +1,63 @@
from flax.experimental import nnx
import jax


class Block(nnx.Module):

  def __init__(self, dim: int, *, ctx: nnx.Context):
    self.linear = nnx.Linear(dim, dim, ctx=ctx)
    self.bn = nnx.BatchNorm(dim, ctx=ctx)
    self.dropout = nnx.Dropout(0.5)

  def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array:
    x = self.linear(x)
    x = self.bn(x, ctx=ctx)
    x = self.dropout(x, ctx=ctx)
    x = jax.nn.gelu(x)
    return x


from functools import partial


class ScanMLP(nnx.Module):

  def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context):
    self.n_layers = n_layers
    # partition Context and split the `params` key
    keys, ctxdef = ctx.partition()
    params_key = jax.random.split(keys["params"], n_layers)

    @partial(jax.vmap, out_axes=((0, None), None))
    def create_block(params_key):
      # merge back Context using the sliced `params` key
      ctx = ctxdef.merge({"params": params_key})
      # create Block instance and return its partitions
      return Block(dim, ctx=ctx).partition(nnx.Param, nnx.BatchStat)

    # call vmap over create_block, passing the split `params` key
    (params, batch_stats), moduledef = create_block(params_key)
    # merge to get a lifted Block instance
    self.layers = moduledef.merge(params, batch_stats)

  def __call__(self, x: jax.Array, *, ctx: nnx.Context):
    keys, ctxdef = ctx.partition()
    dropout_key = jax.random.split(keys["dropout"], self.n_layers)
    (params, batch_stats), moduledef = self.layers.partition(
        nnx.Param, nnx.BatchStat
    )

    def scan_fn(
        carry: tuple[jax.Array, nnx.State], inputs: tuple[nnx.State, jax.Array]
    ):
      (x, batch_stats), (params, dropout_key) = carry, inputs
      module = moduledef.merge(params, batch_stats)
      x = module(x, ctx=ctxdef.merge({"dropout": dropout_key}))
      params, _ = module.partition(nnx.Param)
      return (x, batch_stats), params

    (x, batch_stats), params = jax.lax.scan(
        scan_fn, (x, batch_stats), (params, dropout_key)
    )
    self.layers.update_state((params, batch_stats))
    return x
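A minimal usage sketch, not part of the committed files: it builds a small ScanMLP the same way ex1_nnx.py builds its model and inspects the stacked parameters. The forward-pass line assumes `nnx.context` accepts a `dropout` key by keyword and that the Dropout/BatchNorm defaults allow a plain training-mode call; treat it as illustrative rather than as the definitive API.

# Usage sketch for ScanMLP (assumptions noted above).
import jax
import jax.numpy as jnp

model = ScanMLP(8, n_layers=4, ctx=nnx.context(0))
# parameters of all layers are stacked along a leading `n_layers` axis
print(jax.tree_map(jnp.shape, model.layers.filter(nnx.Param)))

x = jnp.ones((2, 8))
y = model(x, ctx=nnx.context(dropout=1))  # assumed way to supply the dropout key
print(y.shape)  # (2, 8)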
