basic interpreter for scaled ops and scaled arrays #8
@@ -0,0 +1,2 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .autoscale import autoscale  # noqa: F401
@@ -0,0 +1,52 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps
from typing import Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray
from ..lax import scaled_ops_registry
Review comment: To follow the JAX way, I would actually declare the global registry here, and I would add a registration function (a sketch of this pattern follows the registry definition below).


def autoscale(fun):
    @wraps(fun)
    def wrapped(*args, **kwargs):
        unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args)
Review comment: Not blocking: it is probably sub-optimal to call … here.
Review comment: Meaning we need to add … (a possible approach is sketched after this file).
        # get jaxpr of unscaled graph
        closed_jaxpr = jax.make_jaxpr(fun)(*unscaled_args, **kwargs)
        # convert to scaled graph
        out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return out

    return wrapped


def autoscale_jaxpr(jaxpr, consts, *args):
    env: Dict[core.Var, ScaledArray] = {}

    def read(var):
        if type(var) is core.Literal:
            return var.val
Review comment: Not blocking: we may need to think about how we want to handle literals and constants? I believe we should also convert them to ScaledArray (a possible approach is sketched after this file).
        return env[var]

    def write(var, val):
        env[var] = val

    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)

    for eqn in jaxpr.eqns:
        invals = safe_map(read, eqn.invars)
        if eqn.primitive not in scaled_ops_registry:
            raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
        outvals = scaled_ops_registry[eqn.primitive](*invals)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]  # type: ignore
        safe_map(write, eqn.outvars, outvals)

    return safe_map(read, jaxpr.outvars)
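
Two of the inline comments above hint at possible follow-ups. The sketch below shows one way each could look. It is a hypothetical illustration under assumptions, not part of this PR: `_aval_like` and `read_promoting_literals` are made-up names, the first piece assumes `jax.make_jaxpr` accepts `jax.ShapeDtypeStruct` placeholders (as `jax.eval_shape` does), and the literal handling simply assigns a unit scale with simplified dtype handling.

```python
# Hypothetical sketches only; not part of the PR.
import jax
import jax.numpy as jnp
from jax import core

from jax_scaled_arithmetics.core import ScaledArray


def _aval_like(x):
    # Shape/dtype placeholder so the function can be traced without
    # materialising the unscaled data via `to_array` (assumes make_jaxpr
    # accepts ShapeDtypeStruct inputs, as jax.eval_shape does).
    if isinstance(x, ScaledArray):
        return jax.ShapeDtypeStruct(x.data.shape, x.data.dtype)
    return x

# Inside `wrapped`, instead of converting every argument with `to_array`:
#     closed_jaxpr = jax.make_jaxpr(fun)(*map(_aval_like, args), **kwargs)


def read_promoting_literals(var, env):
    # Promote jaxpr literals to ScaledArray with a unit scale so that the
    # scaled ops only ever receive ScaledArray inputs (dtypes simplified).
    if type(var) is core.Literal:
        return ScaledArray(jnp.asarray(var.val), jnp.array(1.0))
    return env[var]
```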
@@ -1 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from jax import lax

from .scaled_ops import scaled_mul

scaled_ops_registry = {lax.mul_p: scaled_mul}
Review comment: See above, I would move that to … (a sketch of the registration pattern follows).
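
A registration helper in the spirit of the comments above might look roughly like this. It is a hypothetical sketch: `register_scaled_op`, and the exact module it would live in, are assumptions rather than part of the PR.

```python
# Hypothetical sketch of a JAX-style registry plus registration function;
# not part of the PR.
from typing import Callable, Dict

from jax import core, lax

from .scaled_ops import scaled_mul

scaled_ops_registry: Dict[core.Primitive, Callable] = {}


def register_scaled_op(prim: core.Primitive, scaled_impl: Callable) -> None:
    """Associate a JAX primitive with its ScaledArray implementation."""
    scaled_ops_registry[prim] = scaled_impl


# Each op module (or this __init__) would then register its implementations:
register_scaled_op(lax.mul_p, scaled_mul)
```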
@@ -0,0 +1,8 @@
from ..core import ScaledArray


def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    # (A.data * B.data) * (A.scale * B.scale) == (A.data * A.scale) * (B.data * B.scale),
    # so the data and the scales can be multiplied independently.
    return ScaledArray(A.data * B.data, A.scale * B.scale)


__all__ = ["scaled_mul"]
Review comment: Add the call to … here.
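
To illustrate how further primitives would plug into the registry, a scaled addition could look roughly like the sketch below. It is hypothetical and not part of the PR: `scaled_add` is an assumed name, and taking the larger of the two scales as the output scale is just one possible convention.

```python
# Hypothetical sketch; not part of the PR.
import jax.numpy as jnp

from jax_scaled_arithmetics.core import ScaledArray


def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    # Unlike multiplication, addition needs a common scale: rescale both
    # operands to the larger of the two scales before adding.
    out_scale = jnp.maximum(A.scale, B.scale)
    data = A.data * (A.scale / out_scale) + B.data * (B.scale / out_scale)
    return ScaledArray(data, out_scale)


# Registered alongside scaled_mul, e.g. via the registry shown above:
# scaled_ops_registry[lax.add_p] = scaled_add
```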
@@ -0,0 +1,53 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import chex
import jax
import jax.numpy as jnp

from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.interpreters import autoscale


class AutoScaleInterpreterTests(chex.TestCase):
    def test__identity(self):
Review comment: Not a blocker: we can at some point use … (one possibility is sketched after this test file).
        def func(x):
            return x

        asfunc = autoscale(func)

        scale = jnp.array(1.0)
        inputs = jnp.array([1.0, 2.0])
        expected = jnp.array([1.0, 2.0])

        scaled_inputs = ScaledArray(inputs, scale)
        scaled_outputs = asfunc(scaled_inputs)[0]
Review comment: Why do we need to do the `[0]` indexing on the output here?

        assert jnp.allclose(scaled_outputs.to_array(), expected)

        jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

        assert jaxpr.invars[0].aval.shape == inputs.shape
Review comment: Maybe add a comment here to explain why the …
        assert jaxpr.invars[1].aval.shape == ()

        assert jaxpr.outvars[0].aval.shape == expected.shape
        assert jaxpr.outvars[1].aval.shape == ()

    def test__mul(self):
        def func(x, y):
            return x * y

        asfunc = autoscale(func)

        x_in = jnp.array([-2.0, 2.0])
        x_scale = jnp.array(0.5)
        x = ScaledArray(x_in, x_scale)

        y_in = jnp.array([1.5, 1.5])
        y_scale = jnp.array(2.0)
        y = ScaledArray(y_in, y_scale)

        expected = jnp.array([-3.0, 3.0])

        out = asfunc(x, y)[0]

        assert jnp.allclose(out.to_array(), expected)
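
If the earlier test comment is pointing at chex's test variants, the tests could eventually be written along these lines, reusing the imports from the test file above. This is a hypothetical sketch, not part of the PR; the `with_jit` variant would additionally require `ScaledArray` to work under `jax.jit` (e.g. be registered as a pytree).

```python
# Hypothetical sketch using chex variants; not part of the PR.
class AutoScaleInterpreterVariantTests(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test__mul(self):
        # self.variant wraps the function according to the active variant
        # (plain call or jitted call).
        asfunc = self.variant(autoscale(lambda x, y: x * y))
        x = ScaledArray(jnp.array([-2.0, 2.0]), jnp.array(0.5))
        y = ScaledArray(jnp.array([1.5, 1.5]), jnp.array(2.0))
        out = asfunc(x, y)[0]
        chex.assert_trees_all_close(out.to_array(), jnp.array([-3.0, 3.0]))
```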
Review comment: Not sure how many interpreters we will have! I would expect a main one, with config. When we know more, we can maybe consolidate into a single interpreter.py file in core.