Improve JAX Scalify main README.md (#112)
Additional explanations + proper examples.
balancap committed Jun 17, 2024
1 parent 2157679 commit b3ed77c
Showing 3 changed files with 78 additions and 57 deletions.
112 changes: 76 additions & 36 deletions README.md
@@ -1,76 +1,116 @@
# JAX Scalify: end-to-end scaled arithmetic

**JAX Scalify** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of
**JAX Scalify** is a library implementing end-to-end scale propagation and scaled arithmetic, allowing easy training and inference of
deep neural networks in low precision (BF16, FP16, FP8).

Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Usually, these works have focused on ad-hoc approaches around scaling of matmuls (and sometimes reduction operations). The JSA library is adopting a more systematic approach by transforming the full computational graph into a `ScaledArray` graph, i.e. every operation taking `ScaledArray` inputs and returning `ScaledArray`, where the latter is a simple datastructure:
Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works focus on ad-hoc approaches around scaling of matrix multiplications (and sometimes reduction operations). `Scalify` adopts a more systematic approach with end-to-end scale propagation, i.e. transforming the full computational graph into a `ScaledArray` graph, where every operation takes `ScaledArray` inputs and returns `ScaledArray`:

```python
@dataclass
class ScaledArray:
    # Main data component, in low precision.
    data: Array
    # Scale, usually scalar, in FP32 or E8M0.
    scale: Array

    def to_array(self) -> Array:
        return self.data * self.scale

    def __array__(self) -> Array:
        # Tensor represented as a `ScaledArray`.
        return self.data * self.scale.astype(self.data.dtype)
```
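
To make the scale-propagation idea concrete, here is a minimal hand-written sketch (a toy illustration, not the library's implementation, and ignoring any extra propagation rules `scalify` may apply) of how a single matmul can take `ScaledArray`-style inputs and return a `ScaledArray`-style output, keeping the payload in FP16 while the scales combine separately:

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass

@dataclass
class ToyScaledArray:
    data: jax.Array   # low-precision payload (e.g. FP16/FP8)
    scale: jax.Array  # scalar scale (e.g. FP32)

    def to_array(self) -> jax.Array:
        return self.data * self.scale.astype(self.data.dtype)

def scaled_matmul(a: ToyScaledArray, b: ToyScaledArray) -> ToyScaledArray:
    # (a.data * a.scale) @ (b.data * b.scale) == (a.data @ b.data) * (a.scale * b.scale),
    # so the low-precision matmul and the scale bookkeeping are fully decoupled.
    return ToyScaledArray(a.data @ b.data, a.scale * b.scale)

a = ToyScaledArray(jnp.full((2, 2), 0.5, jnp.float16), jnp.float32(128.0))
b = ToyScaledArray(jnp.full((2, 2), 0.25, jnp.float16), jnp.float32(4.0))
print(scaled_matmul(a, b).to_array())  # same values as a.to_array() @ b.to_array()
```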

A typical JAX training loop requires just a few modifications to take advantage of `scalify`:
The main benefits of the `scalify` approach are:

* Agnostic to the neural-net model definition;
* Decoupling scaling from low-precision formats, reducing the computational overhead of dynamic rescaling;
* FP8 matrix multiplications and reductions as simple as a cast (see the sketch after this list);
* Out-of-the-box support of FP16 (scaled) master weights and optimizer state;
* Composable with the JAX ecosystem: [Flax](https://github.com/google/flax), [Optax](https://github.com/google-deepmind/optax), ...
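
As an illustration of the "FP8 as a cast" point above, here is a hedged sketch. It assumes that, inside a `@jsa.scalify`-decorated function, a plain dtype cast of the operands is enough to run the matmul payload in FP8 (with the scale kept in higher precision), and that `jsa.as_scaled_array` accepts plain JAX arrays; the exact casting helper used by the bundled FP8 example may differ.

```python
import jax.numpy as jnp
import jax_scalify as jsa

@jsa.scalify
def fp8_matmul(a, b):
    # Cast inputs to FP8 E4M3 before the matmul (assumed to act on the payload only).
    a8 = a.astype(jnp.float8_e4m3fn)
    b8 = b.astype(jnp.float8_e4m3fn)
    return a8 @ b8

a = jsa.as_scaled_array(jnp.ones((16, 16), jnp.float16))
b = jsa.as_scaled_array(jnp.ones((16, 16), jnp.float16))
out = fp8_matmul(a, b)  # expected to be a `ScaledArray` again
```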

## Scalify training loop example

A typical JAX training loop just requires a couple of modifications to take advantage of `scalify`. More specifically:

* Represent inputs and state as `ScaledArray` using the `as_scaled_array` method (or variations of it);
* Propagate scales end-to-end in the `update` training method using the `scalify` decorator;
* (Optionally) add `dynamic_rescale` calls to improve low-precision accuracy and stability.


The following (simplified) example presents how `scalify` can be incorporated into a JAX training loop.
```python
import jax_scalify as jsa

params = jsa.as_scaled_array(params)

@jit
# Scalify transform on FWD + BWD + optimizer.
# Propagating scale in the computational graph.
@jsa.scalify
def update(params, batch):
    grads = grad(loss)(params, batch)
    return opt_update(params, grads)

for batch in batches:
    batch = jsa.as_scaled_array(batch)
    params = update(params, batch)
def update(state, data, labels):
    # Forward and backward pass on the NN model.
    loss, grads = jax.value_and_grad(model)(state, data, labels)
    # Optimizer applied on scaled state.
    state = optimizer.apply(state, grads)
    return loss, state

# Model + optimizer state.
state = (model.init(...), optimizer.init(...))
# Transform state to scaled array(s).
sc_state = jsa.as_scaled_array(state)

for (data, labels) in dataset:
    # If necessary (e.g. images), scale input data.
    data = jsa.as_scaled_array(data)
    # State update, with full scale propagation.
    sc_state = update(sc_state, data, labels)
    # Optional dynamic rescaling of state.
    sc_state = jsa.ops.dynamic_rescale_l2(sc_state)
```
In other words: model parameters and micro-batch are converted to `ScaledArray` objects, and the decorator `jsa.scalify` properly transforms the graph into a scaled arithmetics graph (see the [MNIST examples](./experiments/mnist/) for more details).
As presented in the code above, the model state is represented as a JAX PyTree of `ScaledArray`, propagated end-to-end through the model (forward and backward passes) as well as the optimizer.
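
To give some intuition for the optional dynamic rescaling step, the hand-written sketch below illustrates the idea on a single tensor; it is not `jsa.ops.dynamic_rescale_l2` itself (which operates on whole PyTrees of `ScaledArray` and may use different statistics), but shows how current payload statistics can be folded back into the scale without changing the represented tensor:

```python
import jax.numpy as jnp

def dynamic_rescale_l2_sketch(data, scale):
    # Compute the RMS of the payload in FP32, fold it into the scale and
    # renormalize the payload: `data * scale` is unchanged (up to rounding),
    # but `data` is back in a well-conditioned, unit-ish range.
    rms = jnp.sqrt(jnp.mean(jnp.square(data.astype(jnp.float32))))
    new_scale = scale * rms.astype(scale.dtype)
    new_data = (data.astype(jnp.float32) / rms).astype(data.dtype)
    return new_data, new_scale

data = jnp.full((8,), 0.01, jnp.float16)   # poorly scaled FP16 payload
scale = jnp.float32(512.0)
data, scale = dynamic_rescale_l2_sketch(data, scale)
```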

There are multiple benefits to this systematic approach:

* The model definition is unchanged (i.e. compared to unit scaling);
* The dynamic rescaling logic can be moved to optimizer update phase, simplifying the model definition and state;
* Clean implementation, as a JAX interpreter, similarly to `grad`, `vmap`, ...
* Generalize to different quantization paradigms: `int8` per channel, `MX` block scaling, per tensor scaling;
* FP16 training is more stable?
* FP8 support out of the box?
A full collection of examples is available:
* [Scalify quickstart notebook](./examples/scalify-quickstart.ipynb): basics of `ScaledArray` and `scalify` transform;
* [MNIST FP16 training example](./experiments/mnist/mnist_classifier_from_scratch.py): adapting JAX MNIST example to `scalify`;
* [MNIST FP8 training example](./experiments/mnist/mnist_classifier_from_scratch.py): easy FP8 support in `scalify`;
* [CIFAR10 training](./experiments/mnist/cifar_training.py): `scalify` CIFAR10 training, with Optax optimizer integration.


## Installation

JSA library can be easily installed in Python virtual environnment:
JAX Scalify can be directly installed from the GitHub repository in a Python virtual environment:
```bash
git clone git@github.com:graphcore-research/jax-scalify.git
pip install -e ./
pip install git+https://github.com/graphcore-research/jax-scalify.git@main
```
The main dependencies are `numpy`, `jax` and `chex` libraries.

**Note:** it is compatible with [experimental JAX on IPU](https://github.com/graphcore-research/jax-experimental), which can be installed in a Graphcore Poplar Python environment:
```bash
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
```
Here are the common JAX libraries compatible with IPU:
```bash
pip install chex==0.1.6 flax==0.6.4 equinox==0.7.0 jaxtyping==0.2.8
```

Alternatively, for a local development setup:
```bash
git clone git@github.com:graphcore-research/jax-scalify.git
pip install -e ./
```
The major dependencies are `numpy`, `jax` and `chex` libraries.


## Documentation

* [Draft Scaled Arithmetics design document](docs/design.md);
* [Scaled operators coverage](docs/operators.md)
* [(Draft) Scaled Arithmetics design document](docs/design.md);
* [Operator coverage in `scalify`](docs/operators.md)

## Development

Running `pre-commit` and `pytest`:
Running `pre-commit` and `pytest` on the JAX Scalify repository:
```bash
pip install pre-commit
pre-commit run --all-files
pytest -v ./tests
```
A Python wheel can be built with the usual command `python -m build`.

## Graphcore IPU support


JAX Scalify v0.1 is compatible with [experimental JAX on IPU](https://github.com/graphcore-research/jax-experimental), which can be installed in a Graphcore Poplar Python environment:
```bash
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
```
Here are the common JAX libraries compatible with IPU:
```bash
pip install chex==0.1.6 flax==0.6.4 equinox==0.7.0 jaxtyping==0.2.8
```
19 changes: 0 additions & 19 deletions docs/alpha_stable.md

This file was deleted.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -49,10 +49,10 @@ target-version = ['py38', 'py39', 'py310']

[tool.isort]
line_length = 120
known_first_party = "tessellate_ipu"
known_first_party = "jax_scalify"

[tool.mypy]
python_version = "3.8"
python_version = "3.10"
plugins = ["numpy.typing.mypy_plugin"]
# Config heavily inspired by Pydantic!
show_error_codes = true
