basic interpreter for scaled ops and scaled arrays #8
Merged
Changes from all commits (12 commits):
- `aeedd3d` basic interpeter for scaled ops and scaled arrays (lyprince)
- `9611bb3` Merge branch 'main' of github.com:graphcore-research/jax-scaled-arith… (lyprince)
- `62b2289` precommit (lyprince)
- `71df866` mypy fixes (lyprince)
- `0b43a91` use ScaledArray.aval (lyprince)
- `404a2f6` move registry into core (lyprince)
- `3f9d5ae` Fixing module imports (balancap)
- `e5f4707` imports in top level __init__.py (lyprince)
- `7ae1648` linting fixes (lyprince)
- `f2fbd51` return value if in singleton list (lyprince)
- `f599599` return value if in singleton list (lyprince)
- `9a850b9` comment on reason for multiple vars per scaledarray (lyprince)
`jax_scaled_arithmetics/__init__.py` (@@ -1,2 +1,4 @@):

```python
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import lax
from ._version import __version__
from .core import ScaledArray, autoscale  # noqa: F401
```
`jax_scaled_arithmetics/core/__init__.py` (@@ -1,2 +1,3 @@):

```python
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import ScaledArray  # noqa: F401
from .interpreters import autoscale, register_scaled_op  # noqa: F401
```
`jax_scaled_arithmetics/core/interpreters.py` (@@ -0,0 +1,59 @@):

```python
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps
from typing import Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray

_scaled_ops_registry = {}


def register_scaled_op(lax_func, scaled_func):
    _scaled_ops_registry[lax_func] = scaled_func


def autoscale(fun):
    @wraps(fun)
    def wrapped(*args, **kwargs):
        aval_args = safe_map(lambda x: x.aval, args)
        # get jaxpr of unscaled graph
        closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs)
        # convert to scaled graph
        out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return out

    return wrapped


def autoscale_jaxpr(jaxpr, consts, *args):
    env: Dict[core.Var, ScaledArray] = {}

    def read(var):
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)

    for eqn in jaxpr.eqns:
        invals = safe_map(read, eqn.invars)
        if eqn.primitive not in _scaled_ops_registry:
            raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
        outvals = _scaled_ops_registry[eqn.primitive](*invals)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        safe_map(write, eqn.outvars, outvals)

    outvals = safe_map(read, jaxpr.outvars)
    if len(outvals) == 1:
        return outvals[0]
    else:
        return outvals
```
`jax_scaled_arithmetics/lax/__init__.py` (@@ -1 +1,2 @@):

```python
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .scaled_ops import *  # noqa: F401, F403
```
`jax_scaled_arithmetics/lax/scaled_ops.py` (@@ -0,0 +1,15 @@):

```python
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from jax import lax

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import ScaledArray


def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    return ScaledArray(A.data * B.data, A.scale * B.scale)


core.register_scaled_op(lax.mul_p, scaled_mul_p)

__all__ = ["scaled_mul_p"]
```
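The multiplication rule rests on the identity (a·s_a)·(b·s_b) = (a·b)·(s_a·s_b): multiplying data and scales separately preserves the logical values. A quick self-contained check with plain Python lists (no JAX; the `(data, scale)` tuple stand-in for `ScaledArray` is an assumption for illustration):

```python
# Check that scaled multiplication preserves the represented values.
# A (data, scale) pair logically represents data * scale elementwise.

def to_value(data, scale):
    # The logical value represented by a (data, scale) pair.
    return [d * scale for d in data]

def scaled_mul(a, b):
    # Mirrors scaled_mul_p: multiply data and scales independently.
    a_data, a_scale = a
    b_data, b_scale = b
    return ([p * q for p, q in zip(a_data, b_data)], a_scale * b_scale)

x = ([-2.0, 2.0], 0.5)   # represents [-1.0, 1.0]
y = ([1.5, 1.5], 2.0)    # represents [3.0, 3.0]

out_data, out_scale = scaled_mul(x, y)
print(to_value(out_data, out_scale))  # [-3.0, 3.0]
```

The inputs match those in `test__mul` below, and the result agrees with its `expected` array.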
Test file for the interpreter (@@ -0,0 +1,54 @@):

```python
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import chex
import jax
import jax.numpy as jnp

from jax_scaled_arithmetics.core import ScaledArray, autoscale


class AutoScaleInterpreterTests(chex.TestCase):
    def test__identity(self):
        def func(x):
            return x

        asfunc = autoscale(func)

        scale = jnp.array(1.0)
        inputs = jnp.array([1.0, 2.0])
        expected = jnp.array([1.0, 2.0])

        scaled_inputs = ScaledArray(inputs, scale)
        scaled_outputs = asfunc(scaled_inputs)

        assert jnp.allclose(scaled_outputs.aval, expected)

        jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

        # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
        assert jaxpr.invars[0].aval.shape == inputs.shape
        assert jaxpr.invars[1].aval.shape == ()

        assert jaxpr.outvars[0].aval.shape == expected.shape
        assert jaxpr.outvars[1].aval.shape == ()

    def test__mul(self):
        def func(x, y):
            return x * y

        asfunc = autoscale(func)

        x_in = jnp.array([-2.0, 2.0])
        x_scale = jnp.array(0.5)
        x = ScaledArray(x_in, x_scale)

        y_in = jnp.array([1.5, 1.5])
        y_scale = jnp.array(2.0)
        y = ScaledArray(y_in, y_scale)

        expected = jnp.array([-3.0, 3.0])

        out = asfunc(x, y)

        assert jnp.allclose(out.aval, expected)
```

Review comment (on the `invars` shape assertions): Maybe here a comment to explain why the …
Review comment:

> Not a blocker: we can at some point use `@chex.variants(with_jit=True, without_jit=True)` to run the same unit test with `jax.jit` applied or not. This helps keep consistent checks that our logic works in both eager/interpreter mode and jitted mode.