
basic interpreter for scaled ops and scaled arrays #8

Merged: 12 commits, Nov 8, 2023
2 changes: 2 additions & 0 deletions jax_scaled_arithmetics/interpreters/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .autoscale import autoscale # noqa: F401
52 changes: 52 additions & 0 deletions jax_scaled_arithmetics/interpreters/autoscale.py
@@ -0,0 +1,52 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
Contributor:
Not sure how many interpreters we will have! I would expect a main one, with config. When we know more, we can maybe consolidate into a single interpreter.py file in core.


from functools import wraps
from typing import Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray
from ..lax import scaled_ops_registry
Contributor:
To follow the JAX way, I would actually declare the global dict scaled_ops_registry here and import it in the lax module. It feels weird to do it the other way around.

And I would add a function register_scaled_op to mutate it, so we avoid importing a global variable everywhere (and at some point we can add additional checks).
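For example, a minimal sketch of that pattern (the helper name and checks are hypothetical, following the suggestion above):

from typing import Callable, Dict

from jax import core

# Registry declared here in autoscale.py; the lax module imports from it.
scaled_ops_registry: Dict[core.Primitive, Callable] = {}


def register_scaled_op(prim: core.Primitive, scaled_fn: Callable) -> None:
    # Hypothetical helper; additional validation can be added here later.
    if prim in scaled_ops_registry:
        raise KeyError(f"{prim} already has a scaled implementation")
    scaled_ops_registry[prim] = scaled_fn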



def autoscale(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args)
Contributor:
Not blocking: it is probably sub-optimal to call to_array here. Just getting the .aval abstract array should be enough for tracing; then you avoid the potential cost of allocating full arrays.

Contributor:
Meaning we need to add an .aval property to ScaledArray.
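A minimal sketch of such a property (assuming ScaledArray is a simple dataclass holding data and scale):

from dataclasses import dataclass

import jax.numpy as jnp
from jax import core


@dataclass
class ScaledArray:
    data: jnp.ndarray
    scale: jnp.ndarray

    @property
    def aval(self) -> core.ShapedArray:
        # Abstract value for tracing: same shape/dtype as the unscaled data.
        return core.ShapedArray(self.data.shape, self.data.dtype)

The wrapper could then trace from these avals instead of calling to_array(), avoiding the array allocation.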

# get jaxpr of unscaled graph
closed_jaxpr = jax.make_jaxpr(fun)(*unscaled_args, **kwargs)
# convert to scaled graph
out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out

return wrapped


def autoscale_jaxpr(jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}

def read(var):
if type(var) is core.Literal:
return var.val
Contributor:
Not blocking: we may need to think about how we want to handle literals and constants. I believe we should also convert them to ScaledArray; see the sketch below the read function.

return env[var]
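A possible sketch of that conversion, as a drop-in replacement for read above (hypothetical: literals are promoted with a unit scale; assumes import jax.numpy as jnp):

def read(var):
    if type(var) is core.Literal:
        # Promote literals to ScaledArray with a trivial scale of 1,
        # so downstream scaled ops see a uniform input type.
        val = jnp.asarray(var.val)
        return ScaledArray(val, jnp.ones((), dtype=val.dtype))
    return env[var]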

def write(var, val):
env[var] = val

safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

for eqn in jaxpr.eqns:
invals = safe_map(read, eqn.invars)
if eqn.primitive not in scaled_ops_registry:
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
outvals = scaled_ops_registry[eqn.primitive](*invals)
if not eqn.primitive.multiple_results:
outvals = [outvals] # type: ignore
safe_map(write, eqn.outvars, outvals)

return safe_map(read, jaxpr.outvars)
6 changes: 6 additions & 0 deletions jax_scaled_arithmetics/lax/__init__.py
@@ -1 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from jax import lax

from .scaled_ops import scaled_mul

scaled_ops_registry = {lax.mul_p: scaled_mul}
Contributor:
See above: I would move that to autoscale.py.

8 changes: 8 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -0,0 +1,8 @@
from ..core import ScaledArray


def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)


__all__ = ["scaled_mul"]
Contributor:
Add the call to register_scaled_op(mul_p, scaled_mul) here. It's the pattern used everywhere in the JAX codebase.
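With that pattern, scaled_ops.py could look like this sketch (assuming register_scaled_op lives in the interpreters module, as suggested above; the import path is hypothetical):

from jax import lax

from ..core import ScaledArray
from ..interpreters.autoscale import register_scaled_op


def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    return ScaledArray(A.data * B.data, A.scale * B.scale)


register_scaled_op(lax.mul_p, scaled_mul)

__all__ = ["scaled_mul"]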

53 changes: 53 additions & 0 deletions tests/interpreters/test_interpreter.py
@@ -0,0 +1,53 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import chex
import jax
import jax.numpy as jnp

from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.interpreters import autoscale


class AutoScaleInterpreterTests(chex.TestCase):
def test__identity(self):
Contributor:
Not a blocker: at some point we can use @chex.variants(with_jit=True, without_jit=True) to run the same unit test with and without jax.jit. It helps to have consistent checks that our logic works in both eager (interpreter) mode and jitted mode.
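A sketch of that pattern (assuming ScaledArray is registered as a JAX pytree so it can cross the jit boundary):

class AutoScaleInterpreterTests(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test__identity(self):
        def func(x):
            return x

        # self.variant applies jax.jit in the with_jit variant and leaves
        # the function unchanged in the without_jit variant.
        asfunc = self.variant(autoscale(func))
        scaled_inputs = ScaledArray(jnp.array([1.0, 2.0]), jnp.array(1.0))
        assert jnp.allclose(asfunc(scaled_inputs)[0].to_array(), jnp.array([1.0, 2.0]))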

def func(x):
return x

asfunc = autoscale(func)

scale = jnp.array(1.0)
inputs = jnp.array([1.0, 2.0])
expected = jnp.array([1.0, 2.0])

scaled_inputs = ScaledArray(inputs, scale)
scaled_outputs = asfunc(scaled_inputs)[0]
Contributor:
Why do we need to do [0]? It feels like we are changing the semantics of the function here.
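One way to keep the original semantics, sketched as a replacement for the tail of wrapped in autoscale.py:

out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
# A jaxpr always produces a list of outputs; unwrap singletons so that
# autoscale(f) keeps the same calling convention as f.
return out[0] if len(out) == 1 else out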


assert jnp.allclose(scaled_outputs.to_array(), expected)

jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr
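# Note: the jaxpr has twice as many invars/outvars as the Python function,
# because each ScaledArray input/output is flattened into a (data, scale) pair.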

assert jaxpr.invars[0].aval.shape == inputs.shape
Contributor:
Maybe add a comment here to explain why the jaxpr has double the number of inputs/outputs, due to the scale tensors. It's not obvious if you're not a JAX expert!

assert jaxpr.invars[1].aval.shape == ()

assert jaxpr.outvars[0].aval.shape == expected.shape
assert jaxpr.outvars[1].aval.shape == ()

def test__mul(self):
def func(x, y):
return x * y

asfunc = autoscale(func)

x_in = jnp.array([-2.0, 2.0])
x_scale = jnp.array(0.5)
x = ScaledArray(x_in, x_scale)

y_in = jnp.array([1.5, 1.5])
y_scale = jnp.array(2.0)
y = ScaledArray(y_in, y_scale)

expected = jnp.array([-3.0, 3.0])

out = asfunc(x, y)[0]

assert jnp.allclose(out.to_array(), expected)