Commit

add flax examples
cgarciae committed Aug 7, 2023
1 parent d680016 commit 44f5e72
Showing 1 changed file: docs/experimental/nnx/why_nnx.rst (173 changes: 124 additions & 49 deletions)

Why NNX?
========

Flax Linen is currently the most flexible and powerful way to write neural networks in JAX. The main features that have made it so popular are `State collections <https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections>`__, `RNG handling <https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences>`__, `Collection-aware lifted transformations <https://flax.readthedocs.io/en/latest/developer_notes/lift.html>`__, and `Leaf metadata <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning>`__.

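To give a flavor of the most abstract of these features, a collection-aware lifted transformation maps a Module while tracking its variable collections and RNG streams. The following is a minimal sketch for illustration only (it assumes ``nn.vmap`` over ``nn.Dense`` to build a small ensemble; it is not one of this page's examples):

.. code-block:: python

  import flax.linen as nn
  import jax
  import jax.numpy as jnp

  # Lift vmap over a Dense layer: the "params" collection receives a new
  # leading axis of size 3, and the init RNG is split per ensemble member.
  EnsembleDense = nn.vmap(
    nn.Dense,
    variable_axes={"params": 0},
    split_rngs={"params": True},
    in_axes=None,   # broadcast the input to every member
    out_axes=0,     # stack the member outputs
    axis_size=3,
  )

  x = jnp.ones((4, 8))
  variables = EnsembleDense(features=2).init(jax.random.PRNGKey(0), x)
  y = EnsembleDense(features=2).apply(variables, x)  # y.shape == (3, 4, 2)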

However, Linen's power has come at a cost:

* The ``init`` and ``apply`` APIs require a learning curve (on top of JAX's learning curve); a short sketch below illustrates this pattern.
* The Module's dataclass and ``compact`` semantics drift away from regular Python semantics and have a very complex internal implementation.
* It is not easy to integrate pre-trained models into bigger models, as the Module structure is separate from the ``params`` structure.
* The implementation of the lifted transformations is very complex.

Flax NNX is an attempt to keep the features that made Linen great while simplifying the API and making it more Pythonic.
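
For context on the first and third points, this is roughly what the two-step ``init``/``apply`` pattern looks like in Linen. It is a minimal sketch with a toy ``MLP`` module (names and shapes are illustrative): the parameters live in a pytree that is separate from the Module instance and must be threaded through every call.

.. code-block:: python

  import flax.linen as nn
  import jax
  import jax.numpy as jnp

  class MLP(nn.Module):  # toy module, for illustration only
    features: int

    @nn.compact
    def __call__(self, x):
      return nn.Dense(self.features)(x)

  model = MLP(features=2)
  x = jnp.ones((1, 5))
  variables = model.init(jax.random.PRNGKey(0), x)  # state is created outside the instance...
  y = model.apply(variables, x)                     # ...and passed back in explicitly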


NNX is Pythonic
---------------

* Example of building a Module

.. codediff::
  :title_left: NNX
  :title_right: Linen
  :sync:

  from flax.experimental import nnx
  import jax
  import jax.numpy as jnp


  class Count(nnx.Variable): pass


  class Linear(nnx.Module):

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
      self.din, self.dout = din, dout
      key = ctx.make_rng("params")
      self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
      self.b = nnx.Param(jnp.zeros((dout,)))
      self.count = Count(0)  # track the number of calls

    def __call__(self, x) -> jax.Array:
      self.count += 1
      return x @ self.w + self.b

  model = Linear(din=5, dout=2, ctx=nnx.context(0))
  x = jnp.ones((1, 5))
  y = model(x)

  ---

  import flax.linen as nn
  import jax
  import jax.numpy as jnp


  class Linear(nn.Module):
    din: int
    dout: int

    def setup(self):
      din, dout = self.din, self.dout
      key = self.make_rng("params") if self.is_initializing() else None
      self.w = self.variable("params", "w", jax.random.uniform, key, (din, dout))
      self.b = self.variable("params", "b", jnp.zeros, (dout,))
      self.count = self.variable("counts", "count", lambda: 0)

    def __call__(self, x) -> jax.Array:
      self.count.value += 1
      return x @ self.w.value + self.b.value

  model = Linear(din=5, dout=2)
  x = jnp.ones((1, 5))
  variables = model.init(jax.random.PRNGKey(0), x)
  params, counts = variables["params"], variables["counts"]
  y, updates = model.apply(
    {"params": params, "counts": counts}, x, mutable=["counts"]
  )
  counts = updates["counts"]

.. codediff::
  :title_left: NNX
  :title_right: Linen
  :sync:

  print(f"{model.count = }")
  print(f"{model.w = }")
  print(f"{model.b = }")
  print(f"{model = }")

  ---

  bounded_model = model.bind({"params": params, "counts": counts})

  print(f"{bounded_model.count.value = }")
  print(f"{bounded_model.w.value = }")
  print(f"{bounded_model.b.value = }")
  print(f"{bounded_model = }")

**Output:**

.. tab-set::

  .. tab-item:: NNX
    :sync: NNX

    .. code-block:: python

      model.count = 1
      model.w = Array([[0.0779959 , 0.8061936 ],
             [0.05617034, 0.55959475],
             [0.3948189 , 0.5856023 ],
             [0.82162833, 0.27394366],
             [0.07696676, 0.8982161 ]], dtype=float32)
      model.b = Array([0., 0.], dtype=float32)
      model = Linear(
        din=5,
        dout=2
      )

  .. tab-item:: Linen
    :sync: Linen

    .. code-block:: python

      bounded_model.count.value = 1
      bounded_model.w.value = Array([[0.76684463, 0.51083136],
             [0.3042251 , 0.77967715],
             [0.20216525, 0.03781104],
             [0.68387973, 0.9263613 ],
             [0.47634053, 0.7418159 ]], dtype=float32)
      bounded_model.b.value = Array([0., 0.], dtype=float32)
      bounded_model = Linear(
        # attributes
        din = 5
        dout = 2
      )
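
Since the NNX ``model`` above is a plain Python object that owns its state, calling it again mutates that state in place. A small follow-up, building on the example above (the count value follows from the output shown):

.. code-block:: python

  y = model(x)                # call the same object a second time
  print(f"{model.count = }")  # model.count = 2: the Count variable was updated in place
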
NNX is friendly for beginners
-----------------------------

* Example of training in eager mode

import numpy as np