Skip to content

Commit

Permalink
basic interpeter for scaled ops and scaled arrays (#8)
Browse files Browse the repository at this point in the history
* basic interpeter for scaled ops and scaled arrays

* precommit

* mypy fixes

* use ScaledArray.aval

* move registry into core

* Fixing module imports

* imports in top level __init__.py

* linting fixes

* return value if in singleton list

* return value if in singleton list

* comment on reason for multiple vars per scaledarray

---------

Co-authored-by: Paul Balanca <paulb@graphcore.ai>
  • Loading branch information
lyprince and balancap committed Nov 8, 2023
1 parent f8f4a65 commit 8352e80
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax_scaled_arithmetics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import lax
from ._version import __version__
from .core import ScaledArray, autoscale # noqa: F401
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import ScaledArray # noqa: F401
from .interpreters import autoscale, register_scaled_op # noqa: F401
4 changes: 4 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,7 @@ def to_array(self, dtype: DTypeLike = None) -> GenericArray:
def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]:
"""Numpy array interface support."""
return np.asarray(self.to_array(dtype))

@property
def aval(self):
return self.data * self.scale
59 changes: 59 additions & 0 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps
from typing import Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray

_scaled_ops_registry = {}


def register_scaled_op(lax_func, scaled_func):
_scaled_ops_registry[lax_func] = scaled_func


def autoscale(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
aval_args = safe_map(lambda x: x.aval, args)
# get jaxpr of unscaled graph
closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs)
# convert to scaled graph
out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out

return wrapped


def autoscale_jaxpr(jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}

def read(var):
if type(var) is core.Literal:
return var.val
return env[var]

def write(var, val):
env[var] = val

safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

for eqn in jaxpr.eqns:
invals = safe_map(read, eqn.invars)
if eqn.primitive not in _scaled_ops_registry:
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
outvals = _scaled_ops_registry[eqn.primitive](*invals)
if not eqn.primitive.multiple_results:
outvals = [outvals]
safe_map(write, eqn.outvars, outvals)

outvals = safe_map(read, jaxpr.outvars)
if len(outvals) == 1:
return outvals[0]
else:
return outvals
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/lax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .scaled_ops import * # noqa: F401, F403
15 changes: 15 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from jax import lax

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import ScaledArray


def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)


core.register_scaled_op(lax.mul_p, scaled_mul_p)

__all__ = ["scaled_mul_p"]
54 changes: 54 additions & 0 deletions tests/interpreters/test_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import chex
import jax
import jax.numpy as jnp

from jax_scaled_arithmetics.core import ScaledArray, autoscale


class AutoScaleInterpreterTests(chex.TestCase):
def test__identity(self):
def func(x):
return x

asfunc = autoscale(func)

scale = jnp.array(1.0)
inputs = jnp.array([1.0, 2.0])
expected = jnp.array([1.0, 2.0])

scaled_inputs = ScaledArray(inputs, scale)
scaled_outputs = asfunc(scaled_inputs)

assert jnp.allclose(scaled_outputs.aval, expected)

jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

# Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray

assert jaxpr.invars[0].aval.shape == inputs.shape
assert jaxpr.invars[1].aval.shape == ()

assert jaxpr.outvars[0].aval.shape == expected.shape
assert jaxpr.outvars[1].aval.shape == ()

def test__mul(self):
def func(x, y):
return x * y

asfunc = autoscale(func)

x_in = jnp.array([-2.0, 2.0])
x_scale = jnp.array(0.5)
x = ScaledArray(x_in, x_scale)

y_in = jnp.array([1.5, 1.5])
y_scale = jnp.array(2.0)
y = ScaledArray(y_in, y_scale)

expected = jnp.array([-3.0, 3.0])

out = asfunc(x, y)

assert jnp.allclose(out.aval, expected)

0 comments on commit 8352e80

Please sign in to comment.