-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
basic interpeter for scaled ops and scaled arrays
- Loading branch information
Showing
5 changed files
with
118 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from .autoscale import autoscale |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
|
||
import jax | ||
import numpy as np | ||
|
||
from jax import core | ||
from jax._src.util import safe_map | ||
from ..lax import scaled_ops_registry | ||
from functools import wraps | ||
|
||
|
||
def autoscale(fun): | ||
@wraps(fun) | ||
def wrapped(*args, **kwargs): | ||
unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args) | ||
# get jaxpr of unscaled graph | ||
closed_jaxpr = jax.make_jaxpr(fun)(*unscaled_args, **kwargs) | ||
# convert to scaled graph | ||
out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) | ||
return out | ||
|
||
return wrapped | ||
|
||
|
||
def autoscale_jaxpr(jaxpr, consts, *args): | ||
env = {} | ||
|
||
def read(var): | ||
if type(var) is core.Literal: | ||
return var.val | ||
return env[var] | ||
|
||
def write(var, val): | ||
env[var] = val | ||
|
||
safe_map(write, jaxpr.invars, args) | ||
safe_map(write, jaxpr.constvars, consts) | ||
|
||
for eqn in jaxpr.eqns: | ||
invals = safe_map(read, eqn.invars) | ||
if eqn.primitive not in scaled_ops_registry: | ||
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet") | ||
outvals = scaled_ops_registry[eqn.primitive](*invals) | ||
if not eqn.primitive.multiple_results: | ||
outvals = [outvals] | ||
safe_map(write, eqn.outvars, outvals) | ||
|
||
safe_map(read, jaxpr.outvars) | ||
|
||
return safe_map(read, jaxpr.outvars) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,6 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
|
||
from jax import lax | ||
from .scaled_ops import * | ||
|
||
scaled_ops_registry = {lax.mul_p: scaled_mul} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from ..core import ScaledArray | ||
|
||
|
||
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray: | ||
return ScaledArray(A.data * B.data, A.scale * B.scale) | ||
|
||
|
||
__all__ = ["scaled_mul"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
|
||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from jax_scaled_arithmetics.core import ScaledArray | ||
from jax_scaled_arithmetics.interpreters import autoscale | ||
|
||
|
||
class AutoScaleInterpreterTests(chex.TestCase): | ||
def test__identity(self): | ||
def func(x): | ||
return x | ||
|
||
asfunc = autoscale(func) | ||
|
||
scale = jnp.array(1.0) | ||
inputs = jnp.array([1.0, 2.0]) | ||
expected = jnp.array([1.0, 2.0]) | ||
|
||
scaled_inputs = ScaledArray(inputs, scale) | ||
scaled_outputs = asfunc(scaled_inputs)[0] | ||
|
||
assert jnp.allclose(scaled_outputs.to_array(), expected) | ||
|
||
jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr | ||
|
||
assert jaxpr.invars[0].aval.shape == inputs.shape | ||
assert jaxpr.invars[1].aval.shape == () | ||
|
||
assert jaxpr.outvars[0].aval.shape == expected.shape | ||
assert jaxpr.outvars[1].aval.shape == () | ||
|
||
def test__mul(self): | ||
def func(x, y): | ||
return x * y | ||
|
||
asfunc = autoscale(func) | ||
|
||
x_in = jnp.array([-2.0, 2.0]) | ||
x_scale = jnp.array(0.5) | ||
x = ScaledArray(x_in, x_scale) | ||
|
||
y_in = jnp.array([1.5, 1.5]) | ||
y_scale = jnp.array(2.0) | ||
y = ScaledArray(y_in, y_scale) | ||
|
||
expected = jnp.array([-3.0, 3.0]) | ||
|
||
out = asfunc(x, y)[0] | ||
|
||
assert jnp.allclose(out.to_array(), expected) |