diff --git a/README.md b/README.md index dff4424..761df44 100644 --- a/README.md +++ b/README.md @@ -1,76 +1,116 @@ # JAX Scalify: end-to-end scaled arithmetic -**JAX Scalify** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of +**JAX Scalify** is a library implementing end-to-end scale propation and scaled arithmetic, allowing easy training and inference of deep neural networks in low precision (BF16, FP16, FP8). -Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Usually, these works have focused on ad-hoc approaches around scaling of matmuls (and sometimes reduction operations). The JSA library is adopting a more systematic approach by transforming the full computational graph into a `ScaledArray` graph, i.e. every operation taking `ScaledArray` inputs and returning `ScaledArray`, where the latter is a simple datastructure: +Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works focus on ad-hoc approaches around scaling of matrix multiplications (and sometimes reduction operations). `Scalify` is adopting a more systematic approach with end-to-end scale propagation, i.e. transforming the full computational graph into a `ScaledArray` graph where every operation has `ScaledArray` inputs and returns `ScaledArray`: + ```python @dataclass class ScaledArray: + # Main data component, in low precision. data: Array + # Scale, usually scalar, in FP32 or E8M0. scale: Array - def to_array(self) -> Array: - return data * scale + def __array__(self) -> Array: + # Tensor represented as a `ScaledArray`. + return data * scale.astype(self.data.dtype) ``` -A typical JAX training loop requires just a few modifications to take advantage of `scalify`: +The main benefits of the `scalify` approach are: + +* Agnostic to neural-net model definition; +* Decoupling scaling from low-precision, reducing the computational overhead of dynamic rescaling; +* FP8 matrix multiplications and reductions as simple as a cast; +* Out-of-the-box support of FP16 (scaled) master weights and optimizer state; +* Composable with JAX ecosystem: [Flax](https://github.com/google/flax), [Optax](https://github.com/google-deepmind/optax), ... + +## Scalify training loop example + +A typical JAX training loop just requires a couple of modifications to take advantage of `scalify`. More specifically: + +* Represent input and state as `ScaledArray` using the `as_scaled_array` method (or variations of it); +* End-to-end scale propagation in `update` training method using `scalify` decorator; +* (Optionally) add `dynamic_rescale` calls to improve low-precision accuracy and stability; + + +The following (simplified) example presents how to `scalify` can be incorporated into a JAX training loop. ```python import jax_scalify as jsa -params = jsa.as_scaled_array(params) - -@jit +# Scalify transform on FWD + BWD + optimizer. +# Propagating scale in the computational graph. @jsa.scalify -def update(params, batch): - grads = grad(loss)(params, batch) - return opt_update(params, grads) - -for batch in batches: - batch = jsa.as_scaled_array(batch) - params = update(params, batch) +def update(state, data, labels): + # Forward and backward pass on the NN model. + loss, grads = + jax.grad(model)(state, data, labels) + # Optimizer applied on scaled state. + state = optimizer.apply(state, grads) + return loss, state + +# Model + optimizer state. +state = (model.init(...), optimizer.init(...)) +# Transform state to scaled array(s) +sc_state = jsa.as_scaled_array(state) + +for (data, labels) in dataset: + # If necessary (e.g. images), scale input data. + data = jsa.as_scaled_array(data) + # State update, with full scale propagation. + sc_state = update(sc_state, data, labels) + # Optional dynamic rescaling of state. + sc_state = jsa.ops.dynamic_rescale_l2(sc_state) ``` -In other words: model parameters and micro-batch are converted to `ScaledArray` objects, and the decorator `jsa.scalify` properly transforms the graph into a scaled arithmetics graph (see the [MNIST examples](./experiments/mnist/) for more details). +As presented in the code above, the model state is represented as a JAX PyTree of `ScaledArray`, propagated end-to-end through the model (forward and backward passes) as well as the optimizer. -There are multiple benefits to this systematic approach: -* The model definition is unchanged (i.e. compared to unit scaling); -* The dynamic rescaling logic can be moved to optimizer update phase, simplifying the model definition and state; -* Clean implementation, as a JAX interpreter, similarly to `grad`, `vmap`, ... -* Generalize to different quantization paradigms: `int8` per channel, `MX` block scaling, per tensor scaling; -* FP16 training is more stable? -* FP8 support out of the box? +A full collection of examples is available: +* [Scalify quickstart notebook](./examples/scalify-quickstart.ipynb): basics of `ScaledArray` and `scalify` transform; +* [MNIST FP16 training example](./experiments/mnist/mnist_classifier_from_scratch.py): adapting JAX MNIST example to `scalify`; +* [MNIST FP8 training example](./experiments/mnist/mnist_classifier_from_scratch.py): easy FP8 support in `scalify`; +* [CIFAR10 training](./experiments/mnist/cifar_training.py): `scalify` CIFAR10 training, with Optax optimizer integration; ## Installation -JSA library can be easily installed in Python virtual environnment: +JAX Scalify can be directly installed from the github repository in Python virtual environment: ```bash -git clone git@github.com:graphcore-research/jax-scalify.git -pip install -e ./ +pip install git+https://github.com/graphcore-research/jax-scalify.git@main ``` -The main dependencies are `numpy`, `jax` and `chex` libraries. -**Note:** it is compatible with [experimental JAX on IPU](https://github.com/graphcore-research/jax-experimental), which can be installed in a Graphcore Poplar Python environnment: +Alternatively, for a local development setup: ```bash -pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html -``` -Here are the common JAX libraries compatible with IPU: -```bash -pip install chex==0.1.6 flax==0.6.4 equinox==0.7.0 jaxtyping==0.2.8s +git clone git@github.com:graphcore-research/jax-scalify.git +pip install -e ./ ``` +The major dependencies are `numpy`, `jax` and `chex` libraries. + ## Documentation -* [Draft Scaled Arithmetics design document](docs/design.md); -* [Scaled operators coverage](docs/operators.md) +* [(Draft) Scaled Arithmetics design document](docs/design.md); +* [Operators coverage in `scalify`](docs/operators.md) ## Development -Running `pre-commit` and `pytest`: +Running `pre-commit` and `pytest` on the JAX Scalify repository: ```bash pip install pre-commit pre-commit run --all-files pytest -v ./tests ``` Python wheel can be built with the usual command `python -m build`. + +## Graphcore IPU support + + +JAX Scalify v0.1 is compatible with [experimental JAX on IPU](https://github.com/graphcore-research/jax-experimental), which can be installed in a Graphcore Poplar Python environnment: +```bash +pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html +``` +Here are the common JAX libraries compatible with IPU: +```bash +pip install chex==0.1.6 flax==0.6.4 equinox==0.7.0 jaxtyping==0.2.8 +``` diff --git a/docs/alpha_stable.md b/docs/alpha_stable.md deleted file mode 100644 index 5c23d62..0000000 --- a/docs/alpha_stable.md +++ /dev/null @@ -1,19 +0,0 @@ -# Alpha-stable distribution modelling - -For any $alpha\in (0, 2]$, a centered alpha stable random variable $X$ has the following characteristic function -$$exp(-|st|^\alpha)$$ -where $s$ is the scaling of $X$. If $\alpha=2$, it corresponds to a centered Gaussian distribution. - -If $X$ and $Y$ are independent alpha stable r.v., the sum $Z$ will satisfy: -$$s_z^\alpha = s_x^\alpha + s_y^\alpha$$ - - -Some questions: -* Is an alpha stable model more robust, i.e. better thanks to heavy tails representing outliers? -* $alpha=1$, i.e. Cauchy distribution has the nice aspect of being very simple! - - -## References - -* https://en.wikipedia.org/wiki/Stable_distribution -* https://www.sciencedirect.com/topics/mathematics/stable-distribution diff --git a/pyproject.toml b/pyproject.toml index 3d7c492..a88e1a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,10 +49,10 @@ target-version = ['py38', 'py39', 'py310'] [tool.isort] line_length = 120 -known_first_party = "tessellate_ipu" +known_first_party = "jax_scalify" [tool.mypy] -python_version = "3.8" +python_version = "3.10" plugins = ["numpy.typing.mypy_plugin"] # Config heavily inspired by Pydantic! show_error_codes = true