
Commit

Update README.md with PyPi install instructions and paper link. (#129)
balancap committed Jul 17, 2024
1 parent 257760a commit 4276eea
Showing 2 changed files with 32 additions and 337 deletions.
61 changes: 32 additions & 29 deletions README.md
# JAX Scalify: end-to-end scaled arithmetic

[![tests](https://github.com/graphcore-research/jax-scalify/actions/workflows/tests.yaml/badge.svg)](https://github.com/graphcore-research/jax-scalify/actions/workflows/tests-public.yaml)
![PyPI version](https://img.shields.io/pypi/v/jax-scalify)
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/graphcore-research/jax-scalify/blob/main/LICENSE)
[![GitHub Repo stars](https://img.shields.io/github/stars/graphcore-research/jax-scalify)](https://github.com/graphcore-research/jax-scalify/stargazers)
<!-- [![codecov](https://codecov.io/gh/jax-scalify/branch/main/graph/badge.svg?token=bHOkKY5Fze)](https://codecov.io/gh/jax-scalify) -->

[**Installation**](#installation)
| [**Quickstart**](#quickstart)
| [**Documentation**](#documentation)

**📣 Scalify** has been accepted to [**ICML 2024 workshop WANT**](https://openreview.net/forum?id=4IWCHWlb6K)! 📣

**JAX Scalify** is a library implementing end-to-end scale propagation and scaled arithmetic, allowing easy training and inference of
deep neural networks in low precision (BF16, FP16, FP8).

Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works focus on ad-hoc approaches around the scaling of matrix multiplications (and sometimes reduction operations). `Scalify` adopts a more systematic approach with end-to-end scale propagation, i.e. it transforms the full computational graph into a `ScaledArray` graph where every operation takes `ScaledArray` inputs and returns `ScaledArray` outputs:
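As an illustration of the idea (a minimal sketch, not the library's actual `ScaledArray` implementation), a scaled array can be thought of as a pair `(data, scale)` representing the tensor `data * scale`, with every scaled operation explicitly propagating the `scale` term:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ScaledArray:
    """Toy scaled tensor: represents the values `data * scale`."""
    data: jax.Array   # low-precision payload, e.g. FP16/FP8
    scale: jax.Array  # scalar scale factor, typically FP32


def scaled_add(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # Pick a common output scale and rescale both payloads onto it,
    # keeping the payloads well-normalized in low precision.
    out_scale = jnp.maximum(a.scale, b.scale)
    ra = (a.scale / out_scale).astype(a.data.dtype)
    rb = (b.scale / out_scale).astype(b.data.dtype)
    return ScaledArray(a.data * ra + b.data * rb, out_scale)


a = ScaledArray(jnp.array([1.0, 2.0], jnp.float16), jnp.float32(4.0))
b = ScaledArray(jnp.array([3.0, 4.0], jnp.float16), jnp.float32(1.0))
c = scaled_add(a, b)
print(c.data, c.scale)  # payload stays FP16, scale carried separately
```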
The main benefits of the `scalify` approach are:
* Out-of-the-box support of FP16 (scaled) master weights and optimizer state;
* Composable with JAX ecosystem: [Flax](https://github.com/google/flax), [Optax](https://github.com/google-deepmind/optax), ...

## Installation

JAX Scalify can be installed directly from PyPI:
```bash
pip install jax-scalify
```
Please follow the [JAX documentation](https://github.com/google/jax/blob/main/README.md#installation) for a proper JAX installation on GPU/TPU.

The latest version of JAX Scalify is available directly from GitHub:
```bash
pip install git+https://github.com/graphcore-research/jax-scalify.git
```

## Quickstart

A typical JAX training loop just requires a couple of modifications to take advantage of `scalify`. More specifically:

A full collection of examples is available:
* [MNIST FP8 training example](./examples/mnist/mnist_classifier_from_scratch_fp8.py): easy FP8 support in `scalify`;
* [MNIST Flax example](./examples/mnist/mnist_classifier_mlp_flax.py): `scalify` Flax training, with Optax optimizer integration;

## Documentation

* [**Scalify ICML 2024 workshop WANT paper**](https://openreview.net/forum?id=4IWCHWlb6K)
* [(Draft) Scaled Arithmetics design document](docs/design.md)
* [Operators coverage in JAX `scalify`](docs/operators.md)

## Development

For a local development setup, we recommend an interactive install:
```bash
git clone git@github.com:graphcore-research/jax-scalify.git
pip install -e ./
```
The major dependencies are the `numpy`, `jax` and `chex` libraries.


Running `pre-commit` and `pytest` on the JAX Scalify repository:
```bash
pre-commit run --all-files
pytest -v ./tests
```
A Python wheel can be built with the usual command `python -m build`.

## Graphcore IPU support


JAX Scalify v0.1 is compatible with [experimental JAX on IPU](https://github.com/graphcore-research/jax-experimental), which can be installed in a Graphcore Poplar Python environment:
```bash
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
```
Here are the common JAX libraries compatible with IPU:
```bash
pip install chex==0.1.6 flax==0.6.4 equinox==0.7.0 jaxtyping==0.2.8
```
