Skip to content

Commit

Permalink
Fix ScaledArray aval and create factory method. (#10)
Browse files Browse the repository at this point in the history
* `ScaledArray.aval` returning JAX `ShapedArray`, like JAX API;
* Factory method `scaled_array`, similar to `jnp.array`;

The latter makes testing code simpler & clearer.
  • Loading branch information
balancap authored Nov 10, 2023
1 parent b90436f commit 85bf6d2
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 45 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import lax
from ._version import __version__
from .core import ScaledArray, autoscale # noqa: F401
from .core import ScaledArray, autoscale, scaled_array # noqa: F401
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import ScaledArray # noqa: F401
from .datatype import ScaledArray, scaled_array # noqa: F401
from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
25 changes: 22 additions & 3 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import Any, Union

import jax
import jax.numpy as jnp
import numpy as np
from chex import Shape
from jax.core import ShapedArray
from jax.tree_util import register_pytree_node_class
from numpy.typing import DTypeLike, NDArray
from numpy.typing import ArrayLike, DTypeLike, NDArray

GenericArray = Union[jax.Array, np.ndarray]

Expand Down Expand Up @@ -80,5 +82,22 @@ def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]:
return np.asarray(self.to_array(dtype))

@property
def aval(self):
return self.data * self.scale
def aval(self) -> ShapedArray:
"""Abstract value of the scaled array, i.e. shape and dtype."""
return ShapedArray(self.data.shape, self.data.dtype)


def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray:
"""ScaledArray (helper) factory method, similar to `(j)np.array`.
Args:
data: Main data/values.
scale: Scale tensor.
dtype: Optional dtype to use for the data.
npapi: Numpy API to use.
Returns:
Scaled array instance.
"""
data = npapi.asarray(data, dtype=dtype)
scale = npapi.asarray(scale)
return ScaledArray(data, scale)
59 changes: 39 additions & 20 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,73 @@
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax.core import ShapedArray

from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics import ScaledArray, scaled_array


class ScaledArrayDataclassTests(chex.TestCase):
@parameterized.parameters(
{"npb": np},
{"npb": jnp},
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__init__multi_numpy_backend(self, npb):
sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float32), scale=npb.array(1))
assert isinstance(sarr.data, npb.ndarray)
assert isinstance(sarr.scale, npb.ndarray)
def test__scaled_array__init__multi_numpy_backend(self, npapi):
sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(1))
assert isinstance(sarr.data, npapi.ndarray)
assert isinstance(sarr.scale, npapi.ndarray)
assert sarr.scale.shape == ()

def test__scaled_array__basic_properties(self):
sarr = ScaledArray(data=jnp.array([1.0, 2.0]), scale=jnp.array(1))
@parameterized.parameters(
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__factory_method__multi_numpy_backend(self, npapi):
sarr = scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16, npapi=npapi)
assert isinstance(sarr, ScaledArray)
assert isinstance(sarr.data, npapi.ndarray)
assert isinstance(sarr.scale, npapi.ndarray)
assert sarr.data.dtype == ShapedArray((2,), np.float16)
assert sarr.scale.shape == ()
npt.assert_array_almost_equal(sarr, [3, 6])

@parameterized.parameters(
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__basic_properties(self, npapi):
sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(1))
assert sarr.dtype == np.float32
assert sarr.shape == (2,)
assert sarr.aval == ShapedArray((2,), np.float32)

@parameterized.parameters(
{"npb": np},
{"npb": jnp},
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__to_array__multi_numpy_backend(self, npb):
sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float16), scale=npb.array(3))
def test__scaled_array__to_array__multi_numpy_backend(self, npapi):
sarr = scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16, npapi=npapi)
# No dtype specified.
out = sarr.to_array()
assert isinstance(out, npb.ndarray)
assert isinstance(out, npapi.ndarray)
assert out.dtype == sarr.dtype
npt.assert_array_equal(out, sarr.data * sarr.scale)
# Custom float dtype.
out = sarr.to_array(dtype=np.float32)
assert isinstance(out, npb.ndarray)
assert isinstance(out, npapi.ndarray)
assert out.dtype == np.float32
npt.assert_array_equal(out, sarr.data * sarr.scale)
# Custom int dtype.
out = sarr.to_array(dtype=np.int8)
assert isinstance(out, npb.ndarray)
assert isinstance(out, npapi.ndarray)
assert out.dtype == np.int8
npt.assert_array_equal(out, sarr.data * sarr.scale)

@parameterized.parameters(
{"npb": np},
{"npb": jnp},
{"npapi": np},
{"npapi": jnp},
)
def test__scaled_array__numpy_array_interface(self, npb):
sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float32), scale=npb.array(3))
def test__scaled_array__numpy_array_interface(self, npapi):
sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(3))
out = np.asarray(sarr)
assert isinstance(out, np.ndarray)
npt.assert_array_equal(out, sarr.data * sarr.scale)
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import chex
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, autoscale
from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array


class AutoScaleInterpreterTests(chex.TestCase):
Expand All @@ -14,20 +16,16 @@ def func(x):

asfunc = autoscale(func)

scale = jnp.array(1.0)
inputs = jnp.array([1.0, 2.0])
expected = jnp.array([1.0, 2.0])

scaled_inputs = ScaledArray(inputs, scale)
scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32)
scaled_outputs = asfunc(scaled_inputs)
expected = jnp.array([1.0, 2.0])

assert jnp.allclose(scaled_outputs.aval, expected)

assert isinstance(scaled_outputs, ScaledArray)
npt.assert_array_almost_equal(scaled_outputs, expected)
jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

# Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray

assert jaxpr.invars[0].aval.shape == inputs.shape
assert jaxpr.invars[0].aval.shape == scaled_inputs.shape
assert jaxpr.invars[1].aval.shape == ()

assert jaxpr.outvars[0].aval.shape == expected.shape
Expand All @@ -39,16 +37,10 @@ def func(x, y):

asfunc = autoscale(func)

x_in = jnp.array([-2.0, 2.0])
x_scale = jnp.array(0.5)
x = ScaledArray(x_in, x_scale)

y_in = jnp.array([1.5, 1.5])
y_scale = jnp.array(2.0)
y = ScaledArray(y_in, y_scale)

x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32)
y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
expected = jnp.array([-3.0, 3.0])

out = asfunc(x, y)

assert jnp.allclose(out.aval, expected)
assert isinstance(out, ScaledArray)
npt.assert_array_almost_equal(out, expected)

0 comments on commit 85bf6d2

Please sign in to comment.