
basic interpreter for scaled ops and scaled arrays #8

Merged
12 commits merged on Nov 8, 2023
2 changes: 2 additions & 0 deletions jax_scaled_arithmetics/__init__.py
@@ -1,2 +1,4 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import lax
from ._version import __version__
from .core import ScaledArray, autoscale # noqa: F401
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
@@ -1,2 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import ScaledArray # noqa: F401
from .interpreters import autoscale, register_scaled_op # noqa: F401
4 changes: 4 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
@@ -78,3 +78,7 @@ def to_array(self, dtype: DTypeLike = None) -> GenericArray:
    def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]:
        """Numpy array interface support."""
        return np.asarray(self.to_array(dtype))

    @property
    def aval(self):
        # Descaled value (data * scale); `autoscale` traces on these plain arrays.
        return self.data * self.scale
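
(Context: a ScaledArray(data, scale) represents the array data * scale, so e.g. ScaledArray(jnp.array([2.0, 4.0]), jnp.array(0.5)) stands for [1.0, 2.0]; these values are illustrative.)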
59 changes: 59 additions & 0 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -0,0 +1,59 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps
from typing import Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray

_scaled_ops_registry = {}


def register_scaled_op(lax_func, scaled_func):
    _scaled_ops_registry[lax_func] = scaled_func


def autoscale(fun):
    @wraps(fun)
    def wrapped(*args, **kwargs):
        aval_args = safe_map(lambda x: x.aval, args)
        # get jaxpr of unscaled graph
        closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs)
        # convert to scaled graph
        out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return out

    return wrapped


def autoscale_jaxpr(jaxpr, consts, *args):
    env: Dict[core.Var, ScaledArray] = {}

    def read(var):
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)

    for eqn in jaxpr.eqns:
        invals = safe_map(read, eqn.invars)
        if eqn.primitive not in _scaled_ops_registry:
            raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
        outvals = _scaled_ops_registry[eqn.primitive](*invals)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        safe_map(write, eqn.outvars, outvals)

    outvals = safe_map(read, jaxpr.outvars)
    if len(outvals) == 1:
        return outvals[0]
    else:
        return outvals
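
For context, a minimal sketch (not part of the diff) of the kind of jaxpr this interpreter walks; the function f and its inputs are illustrative:

import jax
import jax.numpy as jnp

def f(x, y):
    return x * y

# Tracing f on plain arrays yields a jaxpr with a single `mul` equation;
# autoscale_jaxpr looks up its primitive (lax.mul_p) in _scaled_ops_registry.
print(jax.make_jaxpr(f)(jnp.ones(2), jnp.ones(2)))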
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/lax/__init__.py
@@ -1 +1,2 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .scaled_ops import * # noqa: F401, F403
15 changes: 15 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -0,0 +1,15 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from jax import lax

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import ScaledArray


def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    # (A.data * A.scale) * (B.data * B.scale) == (A.data * B.data) * (A.scale * B.scale)
    return ScaledArray(A.data * B.data, A.scale * B.scale)


core.register_scaled_op(lax.mul_p, scaled_mul_p)

__all__ = ["scaled_mul_p"]
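
Registering further primitives follows the same pattern. A hypothetical sketch for lax.neg_p (not part of this PR):

from jax import lax

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import ScaledArray


def scaled_neg_p(A: ScaledArray) -> ScaledArray:
    # Negate the data; the scale is unaffected.
    return ScaledArray(-A.data, A.scale)


core.register_scaled_op(lax.neg_p, scaled_neg_p)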
54 changes: 54 additions & 0 deletions tests/interpreters/test_interpreter.py
@@ -0,0 +1,54 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import chex
import jax
import jax.numpy as jnp

from jax_scaled_arithmetics.core import ScaledArray, autoscale


class AutoScaleInterpreterTests(chex.TestCase):
    def test__identity(self):
Review comment (Contributor): Not a blocker: we can at some point use @chex.variants(with_jit=True, without_jit=True) to run the same unit test with jax.jit added or not. It helps to have consistent checks that our logic works in eager/interpreter mode as well as in jitting mode.
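
As a sketch of the reviewer's suggestion (hypothetical test, not in this PR), chex variants run the same test body with and without jax.jit:

import chex
import jax.numpy as jnp


class VariantsExample(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test__double(self):
        def double(x):
            return x * 2

        # self.variant wraps the function, jitted or not depending on the variant.
        fn = self.variant(double)
        chex.assert_trees_all_close(fn(jnp.array([1.0])), jnp.array([2.0]))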

        def func(x):
            return x

        asfunc = autoscale(func)

        scale = jnp.array(1.0)
        inputs = jnp.array([1.0, 2.0])
        expected = jnp.array([1.0, 2.0])

        scaled_inputs = ScaledArray(inputs, scale)
        scaled_outputs = asfunc(scaled_inputs)

        assert jnp.allclose(scaled_outputs.aval, expected)

        jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

        # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray

        assert jaxpr.invars[0].aval.shape == inputs.shape
Review comment (Contributor): Maybe add a comment here to explain why the jaxpr has a doubled number of inputs/outputs, due to the scale tensor. Not that obvious if you're not a JAX expert!
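
(Each ScaledArray is traced as two leaves, a data tensor and a scalar scale, so the jaxpr carries two invars and two outvars per logical array; the shape assertions below check exactly that.)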

        assert jaxpr.invars[1].aval.shape == ()

        assert jaxpr.outvars[0].aval.shape == expected.shape
        assert jaxpr.outvars[1].aval.shape == ()

    def test__mul(self):
        def func(x, y):
            return x * y

        asfunc = autoscale(func)

        x_in = jnp.array([-2.0, 2.0])
        x_scale = jnp.array(0.5)
        x = ScaledArray(x_in, x_scale)

        y_in = jnp.array([1.5, 1.5])
        y_scale = jnp.array(2.0)
        y = ScaledArray(y_in, y_scale)

        # Descaled: x is [-1.0, 1.0] and y is [3.0, 3.0], so x * y is [-3.0, 3.0].
        expected = jnp.array([-3.0, 3.0])

        out = asfunc(x, y)

        assert jnp.allclose(out.aval, expected)