From aeedd3d48e220272d7b77fc812693479212bd59e Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 12:06:45 +0000 Subject: [PATCH 01/11] basic interpeter for scaled ops and scaled arrays --- .../interpreters/__init__.py | 2 + .../interpreters/autoscale.py | 50 +++++++++++++++++ jax_scaled_arithmetics/lax/__init__.py | 5 ++ jax_scaled_arithmetics/lax/scaled_ops.py | 8 +++ tests/interpreters/test_interpreter.py | 53 +++++++++++++++++++ 5 files changed, 118 insertions(+) create mode 100644 jax_scaled_arithmetics/interpreters/__init__.py create mode 100644 jax_scaled_arithmetics/interpreters/autoscale.py create mode 100644 jax_scaled_arithmetics/lax/scaled_ops.py create mode 100644 tests/interpreters/test_interpreter.py diff --git a/jax_scaled_arithmetics/interpreters/__init__.py b/jax_scaled_arithmetics/interpreters/__init__.py new file mode 100644 index 0000000..24a370b --- /dev/null +++ b/jax_scaled_arithmetics/interpreters/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from .autoscale import autoscale diff --git a/jax_scaled_arithmetics/interpreters/autoscale.py b/jax_scaled_arithmetics/interpreters/autoscale.py new file mode 100644 index 0000000..6d003d9 --- /dev/null +++ b/jax_scaled_arithmetics/interpreters/autoscale.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +import jax +import numpy as np + +from jax import core +from jax._src.util import safe_map +from ..lax import scaled_ops_registry +from functools import wraps + + +def autoscale(fun): + @wraps(fun) + def wrapped(*args, **kwargs): + unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args) + # get jaxpr of unscaled graph + closed_jaxpr = jax.make_jaxpr(fun)(*unscaled_args, **kwargs) + # convert to scaled graph + out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) + return out + + return wrapped + + +def autoscale_jaxpr(jaxpr, consts, *args): + env = {} + + def read(var): + if type(var) is core.Literal: + return var.val + return env[var] + + def write(var, val): + env[var] = val + + safe_map(write, jaxpr.invars, args) + safe_map(write, jaxpr.constvars, consts) + + for eqn in jaxpr.eqns: + invals = safe_map(read, eqn.invars) + if eqn.primitive not in scaled_ops_registry: + raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet") + outvals = scaled_ops_registry[eqn.primitive](*invals) + if not eqn.primitive.multiple_results: + outvals = [outvals] + safe_map(write, eqn.outvars, outvals) + + safe_map(read, jaxpr.outvars) + + return safe_map(read, jaxpr.outvars) diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py index 9dc9fcb..1fdc265 100644 --- a/jax_scaled_arithmetics/lax/__init__.py +++ b/jax_scaled_arithmetics/lax/__init__.py @@ -1 +1,6 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +from jax import lax +from .scaled_ops import * + +scaled_ops_registry = {lax.mul_p: scaled_mul} diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py new file mode 100644 index 0000000..07532ed --- /dev/null +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -0,0 +1,8 @@ +from ..core import ScaledArray + + +def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray: + return ScaledArray(A.data * B.data, A.scale * B.scale) + + +__all__ = ["scaled_mul"] diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py new file mode 100644 index 0000000..08cbce9 --- /dev/null +++ b/tests/interpreters/test_interpreter.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +import chex +import jax +import jax.numpy as jnp + +from jax_scaled_arithmetics.core import ScaledArray +from jax_scaled_arithmetics.interpreters import autoscale + + +class AutoScaleInterpreterTests(chex.TestCase): + def test__identity(self): + def func(x): + return x + + asfunc = autoscale(func) + + scale = jnp.array(1.0) + inputs = jnp.array([1.0, 2.0]) + expected = jnp.array([1.0, 2.0]) + + scaled_inputs = ScaledArray(inputs, scale) + scaled_outputs = asfunc(scaled_inputs)[0] + + assert jnp.allclose(scaled_outputs.to_array(), expected) + + jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr + + assert jaxpr.invars[0].aval.shape == inputs.shape + assert jaxpr.invars[1].aval.shape == () + + assert jaxpr.outvars[0].aval.shape == expected.shape + assert jaxpr.outvars[1].aval.shape == () + + def test__mul(self): + def func(x, y): + return x * y + + asfunc = autoscale(func) + + x_in = jnp.array([-2.0, 2.0]) + x_scale = jnp.array(0.5) + x = ScaledArray(x_in, x_scale) + + y_in = jnp.array([1.5, 1.5]) + y_scale = jnp.array(2.0) + y = ScaledArray(y_in, y_scale) + + expected = jnp.array([-3.0, 3.0]) + + out = asfunc(x, y)[0] + + assert jnp.allclose(out.to_array(), expected) From 62b2289bc629cc324cb1d82d103ea772372f62ef Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 12:18:52 +0000 Subject: [PATCH 02/11] precommit --- jax_scaled_arithmetics/interpreters/autoscale.py | 5 +++-- jax_scaled_arithmetics/lax/__init__.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/jax_scaled_arithmetics/interpreters/autoscale.py b/jax_scaled_arithmetics/interpreters/autoscale.py index 6d003d9..65b2dc9 100644 --- a/jax_scaled_arithmetics/interpreters/autoscale.py +++ b/jax_scaled_arithmetics/interpreters/autoscale.py @@ -1,12 +1,13 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import wraps + import jax import numpy as np - from jax import core from jax._src.util import safe_map + from ..lax import scaled_ops_registry -from functools import wraps def autoscale(fun): diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py index 1fdc265..5d97aa8 100644 --- a/jax_scaled_arithmetics/lax/__init__.py +++ b/jax_scaled_arithmetics/lax/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from jax import lax + from .scaled_ops import * scaled_ops_registry = {lax.mul_p: scaled_mul} From 71df8668043da0cc6d360b115232f9c63b6323f1 Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 12:30:41 +0000 Subject: [PATCH 03/11] mypy fixes --- jax_scaled_arithmetics/interpreters/__init__.py | 2 +- jax_scaled_arithmetics/interpreters/autoscale.py | 7 ++++--- jax_scaled_arithmetics/lax/__init__.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/jax_scaled_arithmetics/interpreters/__init__.py b/jax_scaled_arithmetics/interpreters/__init__.py index 24a370b..d492077 100644 --- a/jax_scaled_arithmetics/interpreters/__init__.py +++ b/jax_scaled_arithmetics/interpreters/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from .autoscale import autoscale +from .autoscale import autoscale # noqa: F401 diff --git a/jax_scaled_arithmetics/interpreters/autoscale.py b/jax_scaled_arithmetics/interpreters/autoscale.py index 65b2dc9..7645de3 100644 --- a/jax_scaled_arithmetics/interpreters/autoscale.py +++ b/jax_scaled_arithmetics/interpreters/autoscale.py @@ -1,12 +1,13 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from functools import wraps +from typing import Dict import jax -import numpy as np from jax import core from jax._src.util import safe_map +from ..core import ScaledArray from ..lax import scaled_ops_registry @@ -24,7 +25,7 @@ def wrapped(*args, **kwargs): def autoscale_jaxpr(jaxpr, consts, *args): - env = {} + env: Dict[core.Var, ScaledArray] = {} def read(var): if type(var) is core.Literal: @@ -43,7 +44,7 @@ def write(var, val): raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet") outvals = scaled_ops_registry[eqn.primitive](*invals) if not eqn.primitive.multiple_results: - outvals = [outvals] + outvals = [outvals] # type: ignore safe_map(write, eqn.outvars, outvals) safe_map(read, jaxpr.outvars) diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py index 5d97aa8..ac61a4a 100644 --- a/jax_scaled_arithmetics/lax/__init__.py +++ b/jax_scaled_arithmetics/lax/__init__.py @@ -2,6 +2,6 @@ from jax import lax -from .scaled_ops import * +from .scaled_ops import scaled_mul scaled_ops_registry = {lax.mul_p: scaled_mul} From 0b43a916a795a4b7723f15e5e9b029687e9761e4 Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 14:51:10 +0000 Subject: [PATCH 04/11] use ScaledArray.aval --- jax_scaled_arithmetics/core/datatype.py | 4 ++++ jax_scaled_arithmetics/interpreters/autoscale.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 759e2ba..cfcc518 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -78,3 +78,7 @@ def to_array(self, dtype: DTypeLike = None) -> GenericArray: def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]: """Numpy array interface support.""" return np.asarray(self.to_array(dtype)) + + @property + def aval(self): + return self.data * self.scale diff --git a/jax_scaled_arithmetics/interpreters/autoscale.py b/jax_scaled_arithmetics/interpreters/autoscale.py index 7645de3..abb5f37 100644 --- a/jax_scaled_arithmetics/interpreters/autoscale.py +++ b/jax_scaled_arithmetics/interpreters/autoscale.py @@ -14,9 +14,9 @@ def autoscale(fun): @wraps(fun) def wrapped(*args, **kwargs): - unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args) + aval_args = safe_map(lambda x: x.aval, args) # get jaxpr of unscaled graph - closed_jaxpr = jax.make_jaxpr(fun)(*unscaled_args, **kwargs) + closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs) # convert to scaled graph out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) return out From 404a2f65f7418dba2fa72bdd80db408370971754 Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 15:59:15 +0000 Subject: [PATCH 05/11] move registry into core --- jax_scaled_arithmetics/core/__init__.py | 1 + .../autoscale.py => core/interpreters.py} | 11 ++++++++--- jax_scaled_arithmetics/interpreters/__init__.py | 2 -- jax_scaled_arithmetics/lax/__init__.py | 7 +------ jax_scaled_arithmetics/lax/scaled_ops.py | 15 ++++++++++++--- tests/interpreters/test_interpreter.py | 3 +-- 6 files changed, 23 insertions(+), 16 deletions(-) rename jax_scaled_arithmetics/{interpreters/autoscale.py => core/interpreters.py} (83%) delete mode 100644 jax_scaled_arithmetics/interpreters/__init__.py diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py index 61e13fa..d6f4c6a 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scaled_arithmetics/core/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from .datatype import ScaledArray # noqa: F401 +from .interpreters import autoscale, register_scaled_op # noqa: F401 diff --git a/jax_scaled_arithmetics/interpreters/autoscale.py b/jax_scaled_arithmetics/core/interpreters.py similarity index 83% rename from jax_scaled_arithmetics/interpreters/autoscale.py rename to jax_scaled_arithmetics/core/interpreters.py index abb5f37..603382e 100644 --- a/jax_scaled_arithmetics/interpreters/autoscale.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -8,7 +8,12 @@ from jax._src.util import safe_map from ..core import ScaledArray -from ..lax import scaled_ops_registry + +_scaled_ops_registry: Dict[callable, callable] = {} + + +def register_scaled_op(lax_func, scaled_func): + _scaled_ops_registry[lax_func] = scaled_func def autoscale(fun): @@ -40,9 +45,9 @@ def write(var, val): for eqn in jaxpr.eqns: invals = safe_map(read, eqn.invars) - if eqn.primitive not in scaled_ops_registry: + if eqn.primitive not in _scaled_ops_registry: raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet") - outvals = scaled_ops_registry[eqn.primitive](*invals) + outvals = _scaled_ops_registry[eqn.primitive](*invals) if not eqn.primitive.multiple_results: outvals = [outvals] # type: ignore safe_map(write, eqn.outvars, outvals) diff --git a/jax_scaled_arithmetics/interpreters/__init__.py b/jax_scaled_arithmetics/interpreters/__init__.py deleted file mode 100644 index d492077..0000000 --- a/jax_scaled_arithmetics/interpreters/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from .autoscale import autoscale # noqa: F401 diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py index ac61a4a..8e2df8b 100644 --- a/jax_scaled_arithmetics/lax/__init__.py +++ b/jax_scaled_arithmetics/lax/__init__.py @@ -1,7 +1,2 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. - -from jax import lax - -from .scaled_ops import scaled_mul - -scaled_ops_registry = {lax.mul_p: scaled_mul} +from .scaled_ops import * diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index 07532ed..ef86b17 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -1,8 +1,17 @@ -from ..core import ScaledArray +from core import ScaledArray +import core +from jax import lax +from functools import partial -def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray: +# Tried as decorator too + + +# @partial(core.register_scaled_op, lax_func=lax.mul) +def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray: return ScaledArray(A.data * B.data, A.scale * B.scale) -__all__ = ["scaled_mul"] +core.register_scaled_op(lax.mul_p, scaled_mul_p) + +__all__ = ["scaled_mul_p"] diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py index 08cbce9..0435d31 100644 --- a/tests/interpreters/test_interpreter.py +++ b/tests/interpreters/test_interpreter.py @@ -4,8 +4,7 @@ import jax import jax.numpy as jnp -from jax_scaled_arithmetics.core import ScaledArray -from jax_scaled_arithmetics.interpreters import autoscale +from jax_scaled_arithmetics.core import ScaledArray, autoscale class AutoScaleInterpreterTests(chex.TestCase): From 3f9d5ae6f857582d1b82bd439f6ecf92e0e2f530 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 8 Nov 2023 16:10:33 +0000 Subject: [PATCH 06/11] Fixing module imports --- jax_scaled_arithmetics/lax/scaled_ops.py | 5 +++-- tests/interpreters/test_interpreter.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index ef86b17..072814a 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -1,5 +1,6 @@ -from core import ScaledArray -import core +from jax_scaled_arithmetics import core +from jax_scaled_arithmetics.core import ScaledArray + from jax import lax from functools import partial diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py index 0435d31..7af0899 100644 --- a/tests/interpreters/test_interpreter.py +++ b/tests/interpreters/test_interpreter.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from jax_scaled_arithmetics.core import ScaledArray, autoscale - +import jax_scaled_arithmetics.lax class AutoScaleInterpreterTests(chex.TestCase): def test__identity(self): From e5f4707adaeb889a96bf33c08353e75f50986c93 Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 16:40:04 +0000 Subject: [PATCH 07/11] imports in top level __init__.py --- jax_scaled_arithmetics/__init__.py | 3 +++ tests/interpreters/test_interpreter.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/jax_scaled_arithmetics/__init__.py b/jax_scaled_arithmetics/__init__.py index 803aee4..602ec1c 100644 --- a/jax_scaled_arithmetics/__init__.py +++ b/jax_scaled_arithmetics/__init__.py @@ -1,2 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from ._version import __version__ + +from . import lax +from .core import ScaledArray, autoscale diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py index 7af0899..0435d31 100644 --- a/tests/interpreters/test_interpreter.py +++ b/tests/interpreters/test_interpreter.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from jax_scaled_arithmetics.core import ScaledArray, autoscale -import jax_scaled_arithmetics.lax + class AutoScaleInterpreterTests(chex.TestCase): def test__identity(self): From 7ae1648403d61dc3d07362f1d560b5a81680a20c Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 16:48:22 +0000 Subject: [PATCH 08/11] linting fixes --- jax_scaled_arithmetics/__init__.py | 5 ++--- jax_scaled_arithmetics/core/interpreters.py | 4 ++-- jax_scaled_arithmetics/lax/__init__.py | 2 +- jax_scaled_arithmetics/lax/scaled_ops.py | 9 +++------ 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/jax_scaled_arithmetics/__init__.py b/jax_scaled_arithmetics/__init__.py index 602ec1c..5a42b7e 100644 --- a/jax_scaled_arithmetics/__init__.py +++ b/jax_scaled_arithmetics/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from ._version import __version__ - from . import lax -from .core import ScaledArray, autoscale +from ._version import __version__ +from .core import ScaledArray, autoscale # noqa: F401 diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 603382e..075240d 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -9,7 +9,7 @@ from ..core import ScaledArray -_scaled_ops_registry: Dict[callable, callable] = {} +_scaled_ops_registry = {} def register_scaled_op(lax_func, scaled_func): @@ -49,7 +49,7 @@ def write(var, val): raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet") outvals = _scaled_ops_registry[eqn.primitive](*invals) if not eqn.primitive.multiple_results: - outvals = [outvals] # type: ignore + outvals = [outvals] safe_map(write, eqn.outvars, outvals) safe_map(read, jaxpr.outvars) diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py index 8e2df8b..65b52cc 100644 --- a/jax_scaled_arithmetics/lax/__init__.py +++ b/jax_scaled_arithmetics/lax/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from .scaled_ops import * +from .scaled_ops import * # noqa: F401, F403 diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index 072814a..d1106f8 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -1,14 +1,11 @@ -from jax_scaled_arithmetics import core -from jax_scaled_arithmetics.core import ScaledArray +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. from jax import lax -from functools import partial - -# Tried as decorator too +from jax_scaled_arithmetics import core +from jax_scaled_arithmetics.core import ScaledArray -# @partial(core.register_scaled_op, lax_func=lax.mul) def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray: return ScaledArray(A.data * B.data, A.scale * B.scale) From f2fbd51faa67e12516c88569939d765bf52e98e7 Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 17:02:03 +0000 Subject: [PATCH 09/11] return value if in singleton list --- jax_scaled_arithmetics/core/interpreters.py | 8 +++++--- tests/interpreters/test_interpreter.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 075240d..503cb1d 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -52,6 +52,8 @@ def write(var, val): outvals = [outvals] safe_map(write, eqn.outvars, outvals) - safe_map(read, jaxpr.outvars) - - return safe_map(read, jaxpr.outvars) + outvals = safe_map(read, jaxpr.outvars) + if len(outvals) == 1: + return outvals[0] + else: + return outvals diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py index 0435d31..5c26758 100644 --- a/tests/interpreters/test_interpreter.py +++ b/tests/interpreters/test_interpreter.py @@ -19,7 +19,7 @@ def func(x): expected = jnp.array([1.0, 2.0]) scaled_inputs = ScaledArray(inputs, scale) - scaled_outputs = asfunc(scaled_inputs)[0] + scaled_outputs = asfunc(scaled_inputs) assert jnp.allclose(scaled_outputs.to_array(), expected) @@ -47,6 +47,6 @@ def func(x, y): expected = jnp.array([-3.0, 3.0]) - out = asfunc(x, y)[0] + out = asfunc(x, y) assert jnp.allclose(out.to_array(), expected) From f5995993056bb48f49843190c6c4c7bfe46856c5 Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 17:03:24 +0000 Subject: [PATCH 10/11] return value if in singleton list --- tests/interpreters/test_interpreter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py index 5c26758..015ece5 100644 --- a/tests/interpreters/test_interpreter.py +++ b/tests/interpreters/test_interpreter.py @@ -50,3 +50,5 @@ def func(x, y): out = asfunc(x, y) assert jnp.allclose(out.to_array(), expected) + + breakpoint() From 9a850b943b3cee9d110c7831b28afbf705b08c47 Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 17:08:07 +0000 Subject: [PATCH 11/11] comment on reason for multiple vars per scaledarray --- tests/interpreters/test_interpreter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/interpreters/test_interpreter.py b/tests/interpreters/test_interpreter.py index 015ece5..5bbb4ad 100644 --- a/tests/interpreters/test_interpreter.py +++ b/tests/interpreters/test_interpreter.py @@ -21,10 +21,12 @@ def func(x): scaled_inputs = ScaledArray(inputs, scale) scaled_outputs = asfunc(scaled_inputs) - assert jnp.allclose(scaled_outputs.to_array(), expected) + assert jnp.allclose(scaled_outputs.aval, expected) jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr + # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray + assert jaxpr.invars[0].aval.shape == inputs.shape assert jaxpr.invars[1].aval.shape == () @@ -49,6 +51,4 @@ def func(x, y): out = asfunc(x, y) - assert jnp.allclose(out.to_array(), expected) - - breakpoint() + assert jnp.allclose(out.aval, expected)