diff --git a/jax_scaled_arithmetics/__init__.py b/jax_scaled_arithmetics/__init__.py
index 803aee4..5a42b7e 100644
--- a/jax_scaled_arithmetics/__init__.py
+++ b/jax_scaled_arithmetics/__init__.py
@@ -1,2 +1,4 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from . import lax
 from ._version import __version__
+from .core import ScaledArray, autoscale  # noqa: F401
diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py
index 61e13fa..d6f4c6a 100644
--- a/jax_scaled_arithmetics/core/__init__.py
+++ b/jax_scaled_arithmetics/core/__init__.py
@@ -1,2 +1,3 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 from .datatype import ScaledArray  # noqa: F401
+from .interpreters import autoscale, register_scaled_op  # noqa: F401
diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py
index 759e2ba..cfcc518 100644
--- a/jax_scaled_arithmetics/core/datatype.py
+++ b/jax_scaled_arithmetics/core/datatype.py
@@ -78,3 +78,8 @@ def to_array(self, dtype: DTypeLike = None) -> GenericArray:
     def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]:
         """Numpy array interface support."""
         return np.asarray(self.to_array(dtype))
+
+    @property
+    def aval(self):
+        """Return `self.data * self.scale`."""
+        return self.data * self.scale
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
new file mode 100644
index 0000000..503cb1d
--- /dev/null
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+
+from functools import wraps
+from typing import Any, Callable, Dict
+
+import jax
+from jax import core
+from jax._src.util import safe_map
+
+from .datatype import ScaledArray
+
+# Maps a JAX primitive to the function implementing it for ScaledArray inputs.
+_scaled_ops_registry: Dict[core.Primitive, Callable[..., Any]] = {}
+
+
+def register_scaled_op(lax_func: core.Primitive, scaled_func: Callable[..., Any]) -> None:
+    """Register `scaled_func` as the ScaledArray implementation of the primitive `lax_func`.
+
+    `scaled_func` is called with the equation inputs as positional arguments and the
+    equation parameters as keyword arguments. Registering the same primitive again
+    replaces the previous implementation.
+    """
+    _scaled_ops_registry[lax_func] = scaled_func
+
+
+def autoscale(fun):
+    """Wrap `fun` so that it is evaluated with the registered scaled operations.
+
+    The wrapper traces `fun` into a jaxpr using the `aval` attribute of every positional
+    argument, then interprets that jaxpr on the original arguments.
+    """
+
+    @wraps(fun)
+    def wrapped(*args, **kwargs):
+        aval_args = safe_map(lambda x: x.aval, args)
+        # get jaxpr of unscaled graph
+        # NOTE(review): kwargs only reach the tracing call below, not `autoscale_jaxpr`;
+        # confirm keyword arguments are handled as intended.
+        closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs)
+        # convert to scaled graph
+        return autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
+
+    return wrapped
+
+
+def autoscale_jaxpr(jaxpr, consts, *args):
+    """Interpret `jaxpr` equation by equation using the registered scaled operations.
+
+    Args:
+        jaxpr: Jaxpr to evaluate.
+        consts: Values bound to `jaxpr.constvars`.
+        *args: Values bound to `jaxpr.invars`.
+
+    Returns:
+        The single output value if the jaxpr has exactly one output, otherwise the
+        list of output values.
+
+    Raises:
+        NotImplementedError: If an equation uses a primitive with no registered
+            scaled implementation.
+    """
+    env: Dict[core.Var, ScaledArray] = {}
+
+    def read(var):
+        # Literals carry their value inline and are never stored in `env`.
+        if isinstance(var, core.Literal):
+            return var.val
+        return env[var]
+
+    def write(var, val):
+        env[var] = val
+
+    safe_map(write, jaxpr.invars, args)
+    safe_map(write, jaxpr.constvars, consts)
+
+    for eqn in jaxpr.eqns:
+        invals = safe_map(read, eqn.invars)
+        if eqn.primitive not in _scaled_ops_registry:
+            raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
+        # Forward the equation parameters so parametrised primitives keep their settings.
+        outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
+        if not eqn.primitive.multiple_results:
+            outvals = [outvals]
+        safe_map(write, eqn.outvars, outvals)
+
+    outvals = safe_map(read, jaxpr.outvars)
+    if len(outvals) == 1:
+        return outvals[0]
+    return outvals
diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py
index 9dc9fcb..65b52cc 100644
--- a/jax_scaled_arithmetics/lax/__init__.py
+++ b/jax_scaled_arithmetics/lax/__init__.py
@@ -1 +1,2 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from .scaled_ops import *  # noqa: F401, F403
diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
new file mode 100644
index 0000000..d1106f8
--- /dev/null
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+
+from jax import lax
+
+from jax_scaled_arithmetics import core
+from jax_scaled_arithmetics.core import ScaledArray
+
+
+def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    """Multiply two scaled arrays by multiplying their data and their scales separately."""
+    return ScaledArray(A.data * B.data, A.scale * B.scale)
+
+
+core.register_scaled_op(lax.mul_p, scaled_mul_p)
+
+__all__ = ["scaled_mul_p"]
diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py
new file mode 100644
index 0000000..5bbb4ad
--- /dev/null
+++ b/tests/interpreters/test_interpreter.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+
+import chex
+import jax
+import jax.numpy as jnp
+
+from jax_scaled_arithmetics.core import ScaledArray, autoscale
+
+
+class AutoScaleInterpreterTests(chex.TestCase):
+    def test__identity(self):
+        def func(x):
+            return x
+
+        asfunc = autoscale(func)
+
+        scale = jnp.array(1.0)
+        inputs = jnp.array([1.0, 2.0])
+        expected = jnp.array([1.0, 2.0])
+
+        scaled_inputs = ScaledArray(inputs, scale)
+        scaled_outputs = asfunc(scaled_inputs)
+
+        assert jnp.allclose(scaled_outputs.aval, expected)
+
+        jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr
+
+        # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
+
+        assert jaxpr.invars[0].aval.shape == inputs.shape
+        assert jaxpr.invars[1].aval.shape == ()
+
+        assert jaxpr.outvars[0].aval.shape == expected.shape
+        assert jaxpr.outvars[1].aval.shape == ()
+
+    def test__mul(self):
+        def func(x, y):
+            return x * y
+
+        asfunc = autoscale(func)
+
+        x_in = jnp.array([-2.0, 2.0])
+        x_scale = jnp.array(0.5)
+        x = ScaledArray(x_in, x_scale)
+
+        y_in = jnp.array([1.5, 1.5])
+        y_scale = jnp.array(2.0)
+        y = ScaledArray(y_in, y_scale)
+
+        expected = jnp.array([-3.0, 3.0])
+
+        out = asfunc(x, y)
+
+        assert jnp.allclose(out.aval, expected)