Skip to content

Commit

Permalink
basic interpeter for scaled ops and scaled arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
lyprince committed Nov 8, 2023
1 parent d9dfb39 commit aeedd3d
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax_scaled_arithmetics/interpreters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .autoscale import autoscale
50 changes: 50 additions & 0 deletions jax_scaled_arithmetics/interpreters/autoscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import jax
import numpy as np

from jax import core
from jax._src.util import safe_map
from ..lax import scaled_ops_registry
from functools import wraps


def autoscale(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args)
# get jaxpr of unscaled graph
closed_jaxpr = jax.make_jaxpr(fun)(*unscaled_args, **kwargs)
# convert to scaled graph
out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out

return wrapped


def autoscale_jaxpr(jaxpr, consts, *args):
env = {}

def read(var):
if type(var) is core.Literal:
return var.val
return env[var]

def write(var, val):
env[var] = val

safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

for eqn in jaxpr.eqns:
invals = safe_map(read, eqn.invars)
if eqn.primitive not in scaled_ops_registry:
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
outvals = scaled_ops_registry[eqn.primitive](*invals)
if not eqn.primitive.multiple_results:
outvals = [outvals]
safe_map(write, eqn.outvars, outvals)

safe_map(read, jaxpr.outvars)

return safe_map(read, jaxpr.outvars)
5 changes: 5 additions & 0 deletions jax_scaled_arithmetics/lax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from jax import lax
from .scaled_ops import *

scaled_ops_registry = {lax.mul_p: scaled_mul}
8 changes: 8 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ..core import ScaledArray


def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)


__all__ = ["scaled_mul"]
53 changes: 53 additions & 0 deletions tests/interpreters/test_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import chex
import jax
import jax.numpy as jnp

from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.interpreters import autoscale


class AutoScaleInterpreterTests(chex.TestCase):
def test__identity(self):
def func(x):
return x

asfunc = autoscale(func)

scale = jnp.array(1.0)
inputs = jnp.array([1.0, 2.0])
expected = jnp.array([1.0, 2.0])

scaled_inputs = ScaledArray(inputs, scale)
scaled_outputs = asfunc(scaled_inputs)[0]

assert jnp.allclose(scaled_outputs.to_array(), expected)

jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

assert jaxpr.invars[0].aval.shape == inputs.shape
assert jaxpr.invars[1].aval.shape == ()

assert jaxpr.outvars[0].aval.shape == expected.shape
assert jaxpr.outvars[1].aval.shape == ()

def test__mul(self):
def func(x, y):
return x * y

asfunc = autoscale(func)

x_in = jnp.array([-2.0, 2.0])
x_scale = jnp.array(0.5)
x = ScaledArray(x_in, x_scale)

y_in = jnp.array([1.5, 1.5])
y_scale = jnp.array(2.0)
y = ScaledArray(y_in, y_scale)

expected = jnp.array([-3.0, 3.0])

out = asfunc(x, y)[0]

assert jnp.allclose(out.to_array(), expected)

0 comments on commit aeedd3d

Please sign in to comment.