diff --git a/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_linen.py b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_linen.py new file mode 100644 index 0000000000..a5a2908be5 --- /dev/null +++ b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_linen.py @@ -0,0 +1,37 @@ +import flax.linen as nn +import jax +import jax.numpy as jnp + + +class Linear(nn.Module): + din: int + dout: int + + def setup(self): + din, dout = self.din, self.dout + key = self.make_rng("params") if self.is_initializing() else None + self.w = self.variable("params", "w", jax.random.uniform, key, (din, dout)) + self.b = self.variable("params", "b", jnp.zeros, (dout,)) + self.count = self.variable("counts", "count", lambda: 0) + + def __call__(self, x) -> jax.Array: + if not self.is_initializing(): + self.count.value += 1 + return x @ self.w.value + self.b.value + + +model = Linear(din=5, dout=2) +x = jnp.ones((1, 5)) +variables = model.init(jax.random.PRNGKey(0), x) +params, counts = variables["params"], variables["counts"] +y, updates = model.apply( + {"params": params, "counts": counts}, x, mutable=["counts"] +) +counts = updates["counts"] + +print("\n Linen") +bounded_model = model.bind({"params": params, "counts": counts}) +print(f"{bounded_model.count.value = }") +print(f"{bounded_model.w.value = }") +print(f"{bounded_model.b.value = }") +print(f"{bounded_model = }") diff --git a/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_nnx.py b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_nnx.py new file mode 100644 index 0000000000..5e04a38db7 --- /dev/null +++ b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_nnx.py @@ -0,0 +1,33 @@ +from flax.experimental import nnx +import jax +import jax.numpy as jnp + + +class Count(nnx.Variable): + pass + + +class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + self.din = din + self.dout = dout + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = Count(0) # track the number of calls + + def __call__(self, x) -> jax.Array: + self.count += 1 + return x @ self.w + self.b + + +model = Linear(din=5, dout=2, ctx=nnx.context(0)) +x = jnp.ones((1, 5)) +y = model(x) + +print("\n NNX") +print(f"{model.count = }") +print(f"{model.w = }") +print(f"{model.b = }") +print(f"{model = }") diff --git a/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_linen.py b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_linen.py new file mode 100644 index 0000000000..91344e7b48 --- /dev/null +++ b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_linen.py @@ -0,0 +1,55 @@ +from ex1_linen import Linear, nn, jnp, jax + +# ---------------- +# example begin +# ---------------- +import numpy as np + +X = np.random.uniform(size=(1000, 1)) +Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1)) + +model = Linear(din=1, dout=1) +x = jnp.ones((1, 1)) +variables = model.init(jax.random.PRNGKey(0), x) +params, counts = variables["params"], variables["counts"] + + +@jax.jit +def train_step(params, counts, x, y): + def loss_fn(params): + y_pred, updates = model.apply( + {"params": params, "counts": counts}, x, mutable=["counts"] + ) + loss = jnp.mean((y_pred - y) ** 2) + return loss, updates["counts"] + + grads, counts = jax.grad(loss_fn, has_aux=True)(params) + params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads) + return params, counts + + +@jax.jit +def eval_step(params, counts, x, y): + y_pred, updates = model.apply( + {"params": params, 
"counts": counts}, x, mutable=["counts"] + ) + loss = jnp.mean((y_pred - y) ** 2) + return loss, updates["counts"] + + +for step in range(501): + idx = np.random.randint(0, 1000, size=(32,)) + x, y = X[idx], Y[idx] + + params, counts = train_step(params, counts, x, y) + + if step % 100 == 0: + loss, counts = eval_step(params, counts, X, Y) + + print(f"Step {step}: loss={loss:.4f}") + + +bounded_model = model.bind({"params": params, "counts": counts}) +print(f"\n{bounded_model.w.value = }") +print(f"{bounded_model.b.value = }") +print(f"{bounded_model.count.value = }") diff --git a/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_nnx.py b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_nnx.py new file mode 100644 index 0000000000..a1c3ed8fde --- /dev/null +++ b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_nnx.py @@ -0,0 +1,45 @@ +from ex1_nnx import Linear, nnx, jnp, jax + +# ---------------- +# example begin +# ---------------- +import numpy as np + +X = np.random.uniform(size=(1000, 1)) +Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1)) + +model = Linear(1, 1, ctx=nnx.context(0)) + + +@nnx.jit +def train_step(model: Linear, x, y): + def loss_fn(model: Linear): + y_pred = model(x) + return jnp.mean((y_pred - y) ** 2) + + grads = nnx.grad(loss_fn, wrt=nnx.Param)(model) + params = model.filter(nnx.Param) + params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads) + model.update_state(params) + + +@nnx.jit +def eval_step(model: Linear, x, y): + y_pred = model(x) + return jnp.mean((y_pred - y) ** 2) + + +for step in range(501): + idx = np.random.randint(0, 1000, size=(32,)) + x, y = X[idx], Y[idx] + + train_step(model, x, y) + + if step % 100 == 0: + loss = eval_step(model, X, Y) + + print(f"Step {step}: loss={loss:.4f}") + +print(f"\n{model.w = }") +print(f"{model.b = }") +print(f"{model.count = }") diff --git a/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex3_nnx.py b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex3_nnx.py new file mode 100644 index 0000000000..2e530dc58f --- /dev/null +++ b/docs/experimental/nnx/[REMOVE]why_nnx_examples/ex3_nnx.py @@ -0,0 +1,63 @@ +from flax.experimental import nnx +import jax + + +class Block(nnx.Module): + + def __init__(self, dim: int, *, ctx: nnx.Context): + self.linear = nnx.Linear(dim, dim, ctx=ctx) + self.bn = nnx.BatchNorm(dim, ctx=ctx) + self.dropout = nnx.Dropout(0.5) + + def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: + x = self.linear(x) + x = self.bn(x, ctx=ctx) + x = self.dropout(x, ctx=ctx) + x = jax.nn.gelu(x) + return x + + +from functools import partial + + +class ScanMLP(nnx.Module): + + def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context): + self.n_layers = n_layers + # partition Context and split the `params` key + keys, ctxdef = ctx.partition() + params_key = jax.random.split(keys["params"], n_layers) + + @partial(jax.vmap, out_axes=((0, None), None)) + def create_block(params_key): + # merge back Context using the sliced `params` key + ctx = ctxdef.merge({"params": params_key}) + # create Block instance and return its partitions + return Block(dim, ctx=ctx).partition(nnx.Param, nnx.BatchStat) + + # call vmap over create_block, passing the split `params` key + (params, batch_stats), moduledef = create_block(params_key) + # merge to get a lifted Block instance + self.layers = moduledef.merge(params, batch_stats) + + def __call__(self, x: jax.Array, *, ctx: nnx.Context): + keys, ctxdef = ctx.partition() + dropout_key = jax.random.split(keys["dropout"], self.n_layers) + 
(params, batch_stats), moduledef = self.layers.partition( + nnx.Param, nnx.BatchStat + ) + + def scan_fn( + carry: tuple[jax.Array, nnx.State], inputs: tuple[nnx.State, jax.Array] + ): + (x, batch_stats), (params, dropout_key) = carry, inputs + module = moduledef.merge(params, batch_stats) + x = module(x, ctx=ctxdef.merge({"dropout": dropout_key})) + params, _ = module.partition(nnx.Param) + return (x, batch_stats), params + + (x, batch_stats), params = jax.lax.scan( + scan_fn, (x, batch_stats), (params, dropout_key) + ) + self.layers.update_state((params, batch_stats)) + return x diff --git a/docs/experimental/nnx/why_nnx.rst b/docs/experimental/nnx/why_nnx.rst index ac836a5c01..5b7180cd3c 100644 --- a/docs/experimental/nnx/why_nnx.rst +++ b/docs/experimental/nnx/why_nnx.rst @@ -13,6 +13,9 @@ However, Linen's power has come at a cost: Flax NNX is an attempt to keep the features that made Linen great while simplifying the API and making it more Pythonic. +## Add roadmap / what this means / set expectations. +* Mention its experimental, etc. + NNX is Pythonic --------------- @@ -69,7 +72,8 @@ NNX is Pythonic self.count = self.variable("counts", "count", lambda: 0) def __call__(self, x) -> jax.Array: - self.count.value += 1 + if not self.is_initializing(): + self.count.value += 1 return x @ self.w.value + self.b.value model = Linear(din=5, dout=2) @@ -107,7 +111,7 @@ NNX is Pythonic .. tab-item:: NNX :sync: NNX - .. code-block:: python + .. code-block:: console model.count = 1 model.w = Array([[0.0779959 , 0.8061936 ], @@ -125,7 +129,7 @@ NNX is Pythonic .. tab-item:: Linen :sync: Linen - .. code-block:: python + .. code-block:: console bounded_model.count.value = 1 bounded_model.w.value = Array([[0.76684463, 0.51083136], @@ -146,110 +150,237 @@ NNX is friendly for beginners * Example of training in eager mode -```python -import numpy as np +.. 
codediff:: + :title_left: NNX + :title_right: Linen + :sync: + + import numpy as np + + X = np.random.uniform(size=(1000, 1)) + Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1)) -X = np.random.uniform(size=(1000, 1)) -Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1)) + model = Linear(1, 1, ctx=nnx.context(0)) -model = Linear(1, 1, ctx=nnx.context(0)) + @nnx.jit + def train_step(model: Linear, x, y): + def loss_fn(model: Linear): + y_pred = model(x) + return jnp.mean((y_pred - y) ** 2) -for step in range(500): - idx = np.random.randint(0, 1000, size=(32,)) - x, y = X[idx], Y[idx] + grads = nnx.grad(loss_fn, wrt=nnx.Param)(model) + params = model.filter(nnx.Param) + params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads) + model.update_state(params) - def loss_fn(model: Linear): + @nnx.jit + def eval_step(model: Linear, x, y): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) - loss, grads = nnx.value_and_grad(loss_fn, wrt=nnx.Param)(model) + for step in range(501): + idx = np.random.randint(0, 1000, size=(32,)) + x, y = X[idx], Y[idx] - params = model.filter(nnx.Param) - params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads) - model.update_state(params) + train_step(model, x, y) - if step % 100 == 0: - y_pred = model(X) - loss = np.mean((y_pred - Y) ** 2) - print(f"Step {step}: loss={loss:.4f}") + if step % 100 == 0: + loss = eval_step(model, X, Y) -print(f"\n{model.w = }") -print(f"{model.b = }") -``` + print(f"Step {step}: loss={loss:.4f}") + + print(f"\n{model.w = }") + print(f"{model.b = }") + print(f"{model.count = }") + + --- + + import numpy as np + + X = np.random.uniform(size=(1000, 1)) + Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1)) + + model = Linear(din=1, dout=1) + x = jnp.ones((1, 1)) + variables = model.init(jax.random.PRNGKey(0), x) + params, counts = variables["params"], variables["counts"] + + @jax.jit + def train_step(params, counts, x, y): + def loss_fn(params): + y_pred, updates = model.apply( + {"params": params, "counts": counts}, x, mutable=["counts"] + ) + loss = jnp.mean((y_pred - y) ** 2) + return loss, updates["counts"] + + grads, counts = jax.grad(loss_fn, has_aux=True)(params) + params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads) + return params, counts + + @jax.jit + def eval_step(params, counts, x, y): + y_pred, updates = model.apply( + {"params": params, "counts": counts}, x, mutable=["counts"] + ) + loss = jnp.mean((y_pred - y) ** 2) + return loss, updates["counts"] + + for step in range(501): + idx = np.random.randint(0, 1000, size=(32,)) + x, y = X[idx], Y[idx] + + params, counts = train_step(params, counts, x, y) + + if step % 100 == 0: + loss, counts = eval_step(params, counts, X, Y) + + print(f"Step {step}: loss={loss:.4f}") + + bounded_model = model.bind({"params": params, "counts": counts}) + print(f"\n{bounded_model.w.value = }") + print(f"{bounded_model.b.value = }") + print(f"{bounded_model.count.value = }") + +**Output:** + +.. tab-set:: + + .. tab-item:: NNX + :sync: NNX + .. code-block:: console + + Step 0: loss=0.2701 + Step 100: loss=0.0100 + Step 200: loss=0.0097 + Step 300: loss=0.0097 + Step 400: loss=0.0097 + Step 500: loss=0.0096 + + model.w = Array([[0.80577844]], dtype=float32) + model.b = Array([0.39571917], dtype=float32) + model.count = Array(507, dtype=int32, weak_type=True) + + + .. tab-item:: Linen + :sync: Linen + + .. 
code-block:: console + + Step 0: loss=0.3775 + Step 100: loss=0.0105 + Step 200: loss=0.0098 + Step 300: loss=0.0100 + Step 400: loss=0.0098 + Step 500: loss=0.0098 + + bounded_model.w.value = Array([[0.7771205]], dtype=float32) + bounded_model.b.value = Array([0.41074392], dtype=float32) + bounded_model.count.value = Array(507, dtype=int32, weak_type=True) + + + +NNX is friendly for advanced users +---------------------------------- -## NNX is friendly for advanced users * Example of manual scan over layer -```python -class Block(nnx.Module): - - def __init__(self, dim: int, *, ctx: nnx.Context): - self.linear = nnx.Linear(dim, dim, ctx=ctx) - self.bn = nnx.BatchNorm(dim, ctx=ctx) - self.dropout = nnx.Dropout(0.5) - - def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: - x = self.linear(x) - x = self.bn(x, ctx=ctx) - x = self.dropout(x, ctx=ctx) - x = jax.nn.gelu(x) - return x -``` +.. codediff:: + :title_left: NNX + :title_right: Linen + :sync: -```python -from functools import partial + class Block(nnx.Module): + def __init__(self, dim: int, *, ctx: nnx.Context): + self.linear = nnx.Linear(dim, dim, ctx=ctx) + self.bn = nnx.BatchNorm(dim, ctx=ctx) + self.dropout = nnx.Dropout(0.5) -class ScanMLP(nnx.Module): + def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: + x = self.linear(x) + x = self.bn(x, ctx=ctx) + x = self.dropout(x, ctx=ctx) + x = jax.nn.gelu(x) + return x - def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context): - self.n_layers = n_layers - keys, ctxdef = ctx.partition() - params_key = jax.random.split(keys["params"], n_layers) + --- - @partial(jax.vmap, out_axes=(0, None, None)) - def create_block(params_key): - ctx = ctxdef.merge({"params": params_key}) - (params, batch_stats), moduledef = Block(dim, ctx=ctx).partition( - nnx.Param, nnx.BatchStat - ) - return params, batch_stats, moduledef + class Block(nn.Module): + dim: int - params, batch_stats, moduledef = create_block(params_key) - self.layers = moduledef.merge(params, batch_stats) + def setup(self): + self.linear = nn.Dense(dim) + self.bn = nn.BatchNorm() + self.dropout = nn.Dropout(0.5) - def __call__(self, x: jax.Array, *, ctx: nnx.Context): - keys, ctxdef = ctx.partition() - dropout_key = jax.random.split(keys["dropout"], self.n_layers) - (params, batch_stats), moduledef = self.layers.partition( - nnx.Param, nnx.BatchStat - ) + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = jax.nn.gelu(x) + return x - def scan_fn( - carry: tuple[jax.Array, nnx.State], inputs: tuple[nnx.State, jax.Array] - ): - (x, batch_stats), (params, dropout_key) = carry, inputs - module = moduledef.merge(params, batch_stats) - x = module(x, ctx=ctxdef.merge({"dropout": dropout_key})) - params, _ = module.partition(nnx.Param) - return (x, batch_stats), params - - (x, batch_stats), params = jax.lax.scan( - scan_fn, (x, batch_stats), (params, dropout_key) - ) - self.layers.update_state((params, batch_stats)) - return x -``` +.. 
codediff:: + :title_left: NNX + :title_right: Linen + :sync: + from functools import partial + + class ScanMLP(nnx.Module): + + def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context): + self.n_layers = n_layers + # partition Context and split the `params` key + keys, ctxdef = ctx.partition() + params_key = jax.random.split(keys["params"], n_layers) + + @partial(jax.vmap, out_axes=(0, None, None)) + def create_block(params_key): + # merge back Context using the sliced `params` key + ctx = ctxdef.merge({"params": params_key}) + # create Block instance and return its partition + (params, batch_stats), moduledef = Block(dim, ctx=ctx).partition( + nnx.Param, nnx.BatchStat + ) + return params, batch_stats, moduledef + + params, batch_stats, moduledef = create_block(params_key) + self.layers = moduledef.merge(params, batch_stats) + + def __call__(self, x: jax.Array, *, ctx: nnx.Context): + keys, ctxdef = ctx.partition() + dropout_key = jax.random.split(keys["dropout"], self.n_layers) + (params, batch_stats), moduledef = self.layers.partition( + nnx.Param, nnx.BatchStat + ) -## Parameter surgery is intuitive -* Simple parameter surgery example + def scan_fn( + carry: tuple[jax.Array, nnx.State], inputs: tuple[nnx.State, jax.Array] + ): + (x, batch_stats), (params, dropout_key) = carry, inputs + module = moduledef.merge(params, batch_stats) + x = module(x, ctx=ctxdef.merge({"dropout": dropout_key})) + params, _ = module.partition(nnx.Param) + return (x, batch_stats), params + + (x, batch_stats), params = jax.lax.scan( + scan_fn, (x, batch_stats), (params, dropout_key) + ) + self.layers.update_state((params, batch_stats)) + return x + + --- -## + # 😅 +Parameter surgery is intuitive +-------------------------------- +* Simple parameter surgery example ```python def load_pretrained_model(): diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 481174df85..f1f7c51e40 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -604,7 +604,7 @@ def scan_apply( for axes_state, axis in zip(axes_states, options.variable_axes.values()) ) # transpose axes arg - axes_arg = tree_map_upto_left( + axes_arg = jax.tree_map( lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), node), options.in_axes, axes_arg, @@ -772,7 +772,7 @@ def scan_fn( for axes_state, axis in zip(axes_states, options.variable_axes.values()) ) # transpose axes arg - out = tree_map_upto_left( + out = jax.tree_map( lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, 0, axis), node), options.out_axes, out, @@ -1022,15 +1022,3 @@ def wrapper( return remat_apply(options, f, module, args, ctx) return wrapper # type: ignore - - -def tree_map_upto_left( - f: tp.Callable[[tp.Any, tp.Any], tp.Any], left: tp.Any, right: tp.Any -) -> tp.Any: - leaves_left, treedef = jtu.tree_flatten(left) - leaves_right = treedef.flatten_up_to(right) - - return treedef.unflatten( - f(left_leaf, right_leaf) - for left_leaf, right_leaf in zip(leaves_left, leaves_right) - )
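The final hunk above drops the hand-rolled `tree_map_upto_left` helper in favour of plain `jax.tree_map`. That swap relies on documented `jax.tree_map` behaviour: the first tree supplies the reference structure, and the remaining trees only need to contain it as a prefix (internally it uses `treedef.flatten_up_to`, which is exactly what the deleted helper re-implemented). Below is a minimal standalone sketch of that prefix-mapping pattern, assuming only documented `jax.tree_map` semantics; the `in_axes`/`args` names are illustrative stand-ins, not the transform's real internals.

```python
import jax
import jax.numpy as jnp

# A prefix tree of axes: one int leaf per top-level argument
# (illustrative stand-in for something like options.in_axes).
in_axes = (0, 1)

# The argument tree is deeper than the axis tree (the second entry is a dict).
args = (
    jnp.ones((4, 3)),             # batched along axis 0
    {"w": jnp.ones((3, 4, 2))},   # batched along axis 1
)

# jax.tree_map pairs each axis leaf with the corresponding *sub-tree* of
# `args` (prefix matching), and the inner tree_map then moves that axis to
# the front of every array in the sub-tree -- the same pairing the removed
# tree_map_upto_left helper produced via treedef.flatten_up_to.
moved = jax.tree_map(
    lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), node),
    in_axes,
    args,
)

print(jax.tree_map(jnp.shape, moved))  # ((4, 3), {'w': (4, 3, 2)})
```

Because `jax.tree_map` already performs this prefix flattening, replacing the helper in `scan_apply` and `scan_fn` should be behaviour-preserving for prefix-shaped `in_axes`/`out_axes` specs.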