Renaming library to `jax_scalify` and decorator to `scalify`. (#109)
Aligning naming & branding, and getting rid of `autoscale`.
balancap authored Jun 13, 2024
1 parent 8d94ca0 commit faf1b02
Showing 41 changed files with 171 additions and 186 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -1,6 +1,6 @@
# JAX Scaled Arithmetics
# JAX Scalify: end-to-end scaled Arithmetics

**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of
**JAX Scalify** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of
deep neural networks in low precision (BF16, FP16, FP8).

Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. These works have usually focused on ad-hoc approaches around scaling of matmuls (and sometimes reduction operations). The JSA library adopts a more systematic approach by transforming the full computational graph into a `ScaledArray` graph, i.e. a graph in which every operation takes `ScaledArray` inputs and returns `ScaledArray` outputs, where the latter is a simple datastructure:
@@ -14,14 +14,14 @@ class ScaledArray:
return data * scale
```
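For context on the datastructure the diff truncates above, here is a minimal self-contained sketch; only the `data * scale` line is taken from the snippet, while the field types and the `to_array` helper are assumptions rather than the library's exact definition:
```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class ScaledArray:
    # Low-precision payload (e.g. BF16/FP16/FP8) plus a separately stored scale.
    data: jnp.ndarray
    scale: jnp.ndarray

    def to_array(self) -> jnp.ndarray:
        # The tensor represented is always the product of payload and scale.
        return self.data * self.scale
```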

A typical JAX training loop requires just a few modifications to take advantage of `autoscale`:
A typical JAX training loop requires just a few modifications to take advantage of `scalify`:
```python
import jax_scaled_arithmetics as jsa
import jax_scalify as jsa

params = jsa.as_scaled_array(params)

@jit
@jsa.autoscale
@jsa.scalify
def update(params, batch):
grads = grad(loss)(params, batch)
return opt_update(params, grads)
@@ -30,7 +30,7 @@ for batch in batches:
batch = jsa.as_scaled_array(batch)
params = update(params, batch)
```
In other words: model parameters and micro-batch are converted to `ScaledArray` objects, and the decorator `jsa.autoscale` properly transforms the graph into a scaled arithmetics graph (see the [MNIST examples](./experiments/mnist/) for more details).
In other words: model parameters and micro-batch are converted to `ScaledArray` objects, and the decorator `jsa.scalify` properly transforms the graph into a scaled arithmetics graph (see the [MNIST examples](./experiments/mnist/) for more details).
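As a quick illustration of that round trip, a minimal sketch using the names exported by this commit (`scalify`, `as_scaled_array`, `asarray`); exact signatures and defaults are assumptions:
```python
import numpy as np

import jax_scalify as jsa


@jsa.scalify
def fn(a, b):
    return a + b


# Plain arrays in, `ScaledArray` objects through the traced graph.
sa = jsa.as_scaled_array(np.array([1.0, 2.0], np.float32))
sb = jsa.as_scaled_array(np.array([3.0, 4.0], np.float32))

sout = fn(sa, sb)        # scaled output, with the scale propagated explicitly
out = jsa.asarray(sout)  # convert back to a plain array when required
```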

There are multiple benefits to this systematic approach:

@@ -46,7 +46,7 @@ There are multiple benefits to this systematic approach:

The JSA library can be easily installed in a Python virtual environment:
```bash
git clone git@github.com:graphcore-research/jax-scaled-arithmetics.git
git clone git@github.com:graphcore-research/jax-scalify.git
pip install -e ./
```
The main dependencies are `numpy`, `jax` and `chex` libraries.
4 changes: 2 additions & 2 deletions docs/design.md
@@ -1,6 +1,6 @@
<!-- pandoc ./docs/stable-scaled-arithmetics.md --pdf-engine=xelatex -o ./docs/stable-scaled-arithmetics.pdf -V geometry:margin=2cm -->

# AutoScale: stable scaled arithmetics
# Scalify: stable scaled arithmetics

## Introduction

@@ -63,7 +63,7 @@ Summary of how it would compare to existing methods (credits to @thecharlieblake
| ~ perfect scale at initialisation | We can use the "unit scaling rules" to achieve approximately unit variance for the first fwd & bwd pass. Without this, our tensors immediately require re-scaling. Starting in the ideal range may mean few re-scalings are required. | Apply appropriate scaling factors to the output of our operations (as determined in the unit scaling work). | -->


## AutoScale: JAX scaled arithmetics
## JAX Scalify: end-to-end scaled arithmetics

We focus here on a JAX implementation, but it should be possible to adopt a similar approach in PyTorch (with tracing or Dynamo?). Modern ML frameworks expose their IR at the Python level, allowing users to perform complex transforms on the computational graph without any modification to the C++ backend (XLA, ...).
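To make the point about Python-level IR concrete, a small hedged illustration using standard JAX only: it prints the jaxpr that an interpreter such as `scalify` would traverse and rewrite primitive by primitive (the function `fn` below is purely illustrative):
```python
import jax
import jax.numpy as jnp


def fn(a, b):
    return jnp.maximum(a + b, 0.0)


# `make_jaxpr` exposes the computational graph as Python data; a graph transform
# can walk these equations and substitute scaled variants of each primitive.
print(jax.make_jaxpr(fn)(jnp.ones(4), jnp.ones(4)))
```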

2 changes: 1 addition & 1 deletion docs/operators.md
@@ -1,6 +1,6 @@
# JAX Scaled Operators coverage

Summary of JAX LAX operators supported in `autoscale` graph transformation.
Summary of JAX LAX operators supported in `scalify` graph transformation.

## [JAX LAX operations](https://jax.readthedocs.io/en/latest/jax.lax.html)

12 changes: 6 additions & 6 deletions examples/autoscale-quickstart.ipynb
@@ -20,7 +20,7 @@
"source": [
"import numpy as np\n",
"import jax\n",
"import jax_scaled_arithmetics as jsa"
"import jax_scalify as jsa"
]
},
{
@@ -38,8 +38,8 @@
"metadata": {},
"outputs": [],
"source": [
"# `autoscale` interpreter is tracing the graph, adding scale propagation where necessary.\n",
"@jsa.autoscale\n",
"# `scalify` interpreter is tracing the graph, adding scale propagation where necessary.\n",
"@jsa.scalify\n",
"def fn(a, b):\n",
" return a + b"
]
@@ -127,13 +127,13 @@
}
],
"source": [
"# Running `fn` on scaled arrays triggers `autoscale` graph transformation\n",
"# Running `fn` on scaled arrays triggers `scalify` graph transformation\n",
"sout = fn(sa, sb)\n",
"# NOTE: by default, scale propagation is using power-of-2.\n",
"print(\"SCALED OUTPUT:\", sout)\n",
"\n",
"# To choose a different scale rounding:\n",
"with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n",
"with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n",
" print(\"No scale rounding:\", fn(sa, sb))"
]
},
@@ -228,7 +228,7 @@
"# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes\n",
"sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(1))\n",
"\n",
"@jsa.autoscale\n",
"@jsa.scalify\n",
"def cast_fn(v):\n",
" return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n",
"\n",
14 changes: 7 additions & 7 deletions experiments/mnist/cifar_training.py
@@ -27,7 +27,7 @@
import numpy.random as npr
from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa
import jax_scalify as jsa


def logsumexp(a, axis=None, keepdims=False):
@@ -88,8 +88,8 @@ def accuracy(params, batch):
if __name__ == "__main__":
width = 2048
lr = 1e-4
use_autoscale = True
autoscale = jsa.autoscale if use_autoscale else lambda f: f
use_scalify = True
scalify = jsa.scalify if use_scalify else lambda f: f

layer_sizes = [3072, width, width, 10]
param_scale = 1.0
@@ -116,12 +116,12 @@ def data_stream():
batches = data_stream()
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
if use_autoscale:
if use_scalify:
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@autoscale
@scalify
def update(params, batch):
grads = grad(loss)(params, batch)
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
@@ -131,11 +131,11 @@ def update(params, batch):
for _ in range(num_batches):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
if use_autoscale:
if use_scalify:
batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale))
batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)

epoch_time = time.time() - start_time
4 changes: 2 additions & 2 deletions experiments/mnist/flax_example/train.py
@@ -32,7 +32,7 @@
from flax.metrics import tensorboard
from flax.training import train_state

import jax_scaled_arithmetics as jsa
import jax_scalify as jsa


class CNN(nn.Module):
@@ -75,7 +75,7 @@ def update_model(state, grads):


@jax.jit
@jsa.autoscale
@jsa.scalify
def apply_and_update_model(state, batch_images, batch_labels):
# Jitting together forward + backward + update.
grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
4 changes: 2 additions & 2 deletions experiments/mnist/mnist_classifier.py
@@ -31,7 +31,7 @@
from jax.example_libraries import optimizers, stax
from jax.example_libraries.stax import Dense, LogSoftmax, Relu

import jax_scaled_arithmetics as jsa
import jax_scalify as jsa


def loss(params, batch):
@@ -75,7 +75,7 @@ def data_stream():
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
@jsa.autoscale
@jsa.scalify
def update(i, opt_state, batch):
params = get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
6 changes: 3 additions & 3 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -27,7 +27,7 @@
import numpy.random as npr
from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa
import jax_scalify as jsa

# from jax.scipy.special import logsumexp

@@ -107,7 +107,7 @@ def data_stream():
params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@jsa.autoscale
@jsa.scalify
def update(params, batch):
grads = grad(loss)(params, batch)
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
@@ -120,7 +120,7 @@ def update(params, batch):
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)

epoch_time = time.time() - start_time
6 changes: 3 additions & 3 deletions experiments/mnist/mnist_classifier_from_scratch_fp8.py
@@ -28,7 +28,7 @@
import numpy.random as npr
from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa
import jax_scalify as jsa

# from functools import partial

@@ -134,7 +134,7 @@ def data_stream():
params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@jsa.autoscale
@jsa.scalify
def update(params, batch):
grads = grad(loss)(params, batch)
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
@@ -147,7 +147,7 @@ def update(params, batch):
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)

epoch_time = time.time() - start_time
14 changes: 7 additions & 7 deletions experiments/mnist/optax_cifar_training.py
@@ -28,7 +28,7 @@
import optax
from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa
import jax_scalify as jsa


def logsumexp(a, axis=None, keepdims=False):
@@ -92,9 +92,9 @@ def accuracy(params, batch):
if __name__ == "__main__":
width = 256
lr = 1e-3
use_autoscale = False
use_scalify = False
training_dtype = np.float32
autoscale = jsa.autoscale if use_autoscale else lambda f: f
scalify = jsa.scalify if use_scalify else lambda f: f

layer_sizes = [3072, width, width, 10]
param_scale = 1.0
@@ -123,13 +123,13 @@ def data_stream():
optimizer = optax.adam(learning_rate=lr, eps=1e-5)
opt_state = optimizer.init(params)

if use_autoscale:
if use_scalify:
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))

params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@autoscale
@scalify
def update(params, batch, opt_state):
grads = grad(loss)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state)
@@ -141,11 +141,11 @@ def update(params, batch, opt_state):
for _ in range(num_batches):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
if use_autoscale:
if use_scalify:
batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale))
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params, opt_state = update(params, batch, opt_state)

epoch_time = time.time() - start_time
@@ -2,12 +2,12 @@
from . import core, lax, ops
from ._version import __version__
from .core import ( # noqa: F401
AutoScaleConfig,
Pow2RoundMode,
ScaledArray,
ScalifyConfig,
as_scaled_array,
asarray,
autoscale,
debug_callback,
scaled_array,
scalify,
)
File renamed without changes.
@@ -15,13 +15,13 @@
)
from .debug import debug_callback # noqa: F401
from .interpreters import ( # noqa: F401
AutoScaleConfig,
ScaledPrimitiveType,
autoscale,
ScalifyConfig,
find_registered_scaled_op,
get_autoscale_config,
get_scalify_config,
register_scaled_lax_op,
register_scaled_op,
scalify,
)
from .pow2 import Pow2RoundMode, pow2_decompose, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
File renamed without changes.
File renamed without changes.
