Skip to content

Commit

Permalink
[nnx] add tabulate
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 9, 2025
1 parent e2134af commit 4352578
Show file tree
Hide file tree
Showing 21 changed files with 1,200 additions and 332 deletions.
206 changes: 100 additions & 106 deletions docs_nnx/mnist_tutorial.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs_nnx/mnist_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class CNN(nnx.Module):
# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))
# Visualize it.
nnx.display(model)
print(nnx.tabulate(model))
```

### Run the model
Expand All @@ -112,7 +112,7 @@ Let's put the CNN model to the test! Here, you’ll perform a forward pass with
import jax.numpy as jnp # JAX NumPy
y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)
y
```

## 4. Create the optimizer and define some metrics
Expand Down
129 changes: 88 additions & 41 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

15 changes: 2 additions & 13 deletions docs_nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,7 @@ jupytext:

Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.

In this guide you will learn about:

- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.
- Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).
- Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.
- Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.
- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.
- [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.
- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.
- [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).
- [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`
- Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.
To begin, install Flax with `pip` and import necessary dependencies:

## Setup

Expand Down Expand Up @@ -106,7 +95,7 @@ to handle them, as demonstrated in later sections of this guide.

Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.

The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer.

```{code-cell} ipython3
class MLP(nnx.Module):
Expand Down
18 changes: 11 additions & 7 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@
LogicalNames,
)

try:
from IPython import get_ipython

in_ipython = get_ipython() is not None
except ImportError:
in_ipython = False


class _ValueRepresentation(ABC):
"""A class that represents a value in the summary table."""
Expand Down Expand Up @@ -242,11 +249,6 @@ def tabulate(
Total Parameters: 50 (200 B)
**Note**: rows order in the table does not represent execution order,
instead it aligns with the order of keys in `variables` which are sorted
alphabetically.
**Note**: `vjp_flops` returns `0` if the module is not differentiable.
Args:
Expand All @@ -267,7 +269,9 @@ def tabulate(
mutable.
console_kwargs: An optional dictionary with additional keyword arguments
that are passed to `rich.console.Console` when rendering the table.
Default arguments are `{'force_terminal': True, 'force_jupyter': False}`.
Default arguments are ``'force_terminal': True``, and ``'force_jupyter'``
is set to ``True`` if the code is running in a Jupyter notebook, otherwise
it is set to ``False``.
table_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table` constructor.
column_kwargs: An optional dictionary with additional keyword arguments that
Expand Down Expand Up @@ -564,7 +568,7 @@ def _render_table(
non_params_cols: list[str],
) -> str:
"""A function that renders a Table to a string representation using rich."""
console_kwargs = {'force_terminal': True, 'force_jupyter': False}
console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython}
if console_extras is not None:
console_kwargs.update(console_extras)

Expand Down
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
from .summary import tabulate as tabulate
from . import traversals as traversals
4 changes: 3 additions & 1 deletion flax/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate:
else:
raise TypeError(f'Invalid collection filter: {filter:!r}. ')

def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
def filters_to_predicates(
filters: tp.Sequence[Filter],
) -> tuple[Predicate, ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
Expand Down
12 changes: 4 additions & 8 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import typing_extensions as tpe

from flax.nnx import filterlib, reprlib
from flax.nnx import filterlib, reprlib, visualization
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
Expand Down Expand Up @@ -248,8 +248,7 @@ def __nnx_repr__(self):
yield reprlib.Attr('index', self.index)

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
return visualization.render_object_constructor(
object_type=type(self),
attributes={'type': self.type, 'index': self.index},
path=path,
Expand All @@ -272,9 +271,7 @@ def __nnx_repr__(self):
yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]

return treescope.repr_lib.render_object_constructor(
return visualization.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
Expand Down Expand Up @@ -353,8 +350,7 @@ def __nnx_repr__(self):
)

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
return visualization.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
Expand Down
17 changes: 0 additions & 17 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,23 +403,6 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
flatten_func=partial(_module_flatten, with_keys=False),
)

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=treescope.formatting_util.color_from_string(
type(self).__qualname__
)
)

# -------------------------
# Pytree Definition
# -------------------------
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flax.nnx.module import Module, first_from


@dataclasses.dataclass
@dataclasses.dataclass(repr=False)
class Dropout(Module):
"""Create a dropout layer.
Expand Down
Loading

0 comments on commit 4352578

Please sign in to comment.