diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py
index 22692c4..37890b6 100644
--- a/jax_scaled_arithmetics/core/__init__.py
+++ b/jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from .datatype import ScaledArray, scaled_array  # noqa: F401
+from .datatype import DTypeLike, ScaledArray, Shape, scaled_array  # noqa: F401
 from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op  # noqa: F401
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index 3e32bac..9810535 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 from functools import wraps
-from typing import Dict
+from typing import Any, Dict
 
 import jax
 from jax import core
 
@@ -9,31 +9,48 @@
 from ..core import ScaledArray
 
 
-_scaled_ops_registry = {}
+_scaled_ops_registry: Dict[core.Primitive, Any] = {}
 
 
-def register_scaled_op(lax_func, scaled_func):
-    _scaled_ops_registry[lax_func] = scaled_func
+def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
+    """Register the scaled translation of a JAX primitive.
 
+    Raises an error if a scaled translation already exists for this primitive.
 
-def _get_lax_prim(scaled_func):
+    Args:
+        prim: JAX primitive.
+        scaled_func: Scaled translation of the primitive, with the same interface.
+    """
+    assert isinstance(prim, core.Primitive)
+    if prim in _scaled_ops_registry:
+        raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
+    _scaled_ops_registry[prim] = scaled_func
+
+
+def _get_lax_prim(scaled_func: Any) -> core.Primitive:
     try:
-        op = getattr(jax.lax, scaled_func.__name__.replace("scaled_", ""))
+        prim_name = scaled_func.__name__.replace("scaled_", "") + "_p"
+        prim = getattr(jax.lax, prim_name)
     except AttributeError:
-        raise AttributeError(f"Could not find corresponding jax.lax primitive for {scaled_func.__name__}")
-    return op
+        raise AttributeError(f"Could not find corresponding 'jax.lax' primitive for '{scaled_func.__name__}'.")
+    # Also check that it is a proper JAX primitive, and not something else present in `jax.lax`.
+    if not isinstance(prim, core.Primitive):
+        raise AttributeError(f"The object `{prim}` is not a proper JAX primitive for '{scaled_func.__name__}'.")
+    return prim
 
 
 def register_scaled_lax_op(scaled_func):
     """
-    Registers a scaled function into the scaled_ops_registry by matching
-    the function name with pattern `scaled_{func_name}` to a function in the
+    Registers a scaled function/translation into the scaled_ops_registry by matching
+    the function name with pattern `scaled_{func_name}` to a primitive in the
     `jax.lax` namespace.
 
-    Example: `scaled_mul_p` is matched to `jax.lax.mul_p`
+    Example: `scaled_mul` is matched to `jax.lax.mul_p`
     """
     lax_prim = _get_lax_prim(scaled_func)
    register_scaled_op(lax_prim, scaled_func)
+    # Always return the function, to support use as a decorator.
+    return scaled_func
 
 
 def autoscale(fun):
@@ -49,7 +66,7 @@ def wrapped(*args, **kwargs):
     return wrapped
 
 
-def autoscale_jaxpr(jaxpr, consts, *args):
+def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
     env: Dict[core.Var, ScaledArray] = {}
 
     def read(var):
@@ -67,7 +84,7 @@ def write(var, val):
         invals = safe_map(read, eqn.invars)
         if eqn.primitive not in _scaled_ops_registry:
             raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
-        outvals = _scaled_ops_registry[eqn.primitive](*invals)
+        outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
         if not eqn.primitive.multiple_results:
             outvals = [outvals]
         safe_map(write, eqn.outvars, outvals)
diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index 75e6f21..7611ff9 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -1,12 +1,39 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from typing import Optional, Sequence
+
+from jax import lax
 
 from jax_scaled_arithmetics import core
-from jax_scaled_arithmetics.core import ScaledArray
+from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape
 
 
 @core.register_scaled_lax_op
-def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
-    return ScaledArray(A.data * B.data, A.scale * B.scale)
+def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray:
+    return ScaledArray(lax.broadcast_in_dim(A.data, shape=shape, broadcast_dimensions=broadcast_dimensions), A.scale)
 
 
-__all__ = ["scaled_mul_p"]
+@core.register_scaled_lax_op
+def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False) -> ScaledArray:
+    # NOTE: by default, no rescaling is done before casting.
+    # Adding an optional rescaling op beforehand, and choosing its strategy, is left to the user.
+    # NOTE bis: the scale is not cast by default either!
+    return ScaledArray(lax.convert_element_type(A.data, new_dtype=new_dtype), A.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_slice(
+    A: ScaledArray, start_indices: Sequence[int], limit_indices: Sequence[int], strides: Optional[Sequence[int]] = None
+) -> ScaledArray:
+    return ScaledArray(
+        lax.slice(A.data, start_indices=start_indices, limit_indices=limit_indices, strides=strides), A.scale
+    )
+
+
+@core.register_scaled_lax_op
+def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray:
+    return ScaledArray(lax.transpose(A.data, permutation=permutation), A.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    return ScaledArray(A.data * B.data, A.scale * B.scale)
diff --git a/pyproject.toml b/pyproject.toml
index d90e1e9..728e247 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,9 @@ Website = "https://github.com/graphcore-research/jax-scaled-arithmetics/#readme"
 [project.optional-dependencies]
 test = ["pytest"]
 
+[tool.setuptools]
+packages = ["jax_scaled_arithmetics"]
+
 [tool.pytest.ini_options]
 minversion = "6.0"
 addopts = ["-ra", "--showlocals", "--strict-config", "-p no:hypothesispytest"]
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index b906020..affe5ef 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -6,10 +6,14 @@ import numpy as np
 import numpy.testing as npt
 
-from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
+from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array
 
 
 class AutoScaleInterpreterTests(chex.TestCase):
+    def test__register_scaled_op__error_if_already_registered(self):
+        with self.assertRaises(KeyError):
+            register_scaled_op(jax.lax.mul_p, lambda a, _: a)
+
     @chex.variants(with_jit=True, without_jit=True)
     def test__scaled_identity_function(self):
         def func(x):
             return x
@@ -34,7 +38,7 @@ def func(x):
         assert jaxpr.outvars[1].aval.shape == ()
 
     @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_mul_function(self):
+    def test__scaled_mul__no_attributes(self):
         def func(x, y):
             return x * y
 
@@ -48,3 +52,16 @@ def func(x, y):
         out = asfunc(x, y)
         assert isinstance(out, ScaledArray)
         npt.assert_array_almost_equal(out, expected)
+
+    @chex.variants(with_jit=True, without_jit=True)
+    def test__scaled_convert_element_type__attributes_passing(self):
+        def func(x):
+            return jax.lax.convert_element_type(x, np.float16)
+
+        # Autoscale + (optional) jitting.
+        asfunc = self.variant(autoscale(func))
+        x = scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)
+        out = asfunc(x)
+        assert isinstance(out, ScaledArray)
+        assert out.dtype == np.float16
+        npt.assert_array_almost_equal(out, x)
diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py
new file mode 100644
index 0000000..cc72fcb
--- /dev/null
+++ b/tests/lax/test_scaled_ops.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+import chex
+import numpy as np
+import numpy.testing as npt
+
+from jax_scaled_arithmetics.core import ScaledArray, scaled_array
+from jax_scaled_arithmetics.lax import (
+    scaled_broadcast_in_dim,
+    scaled_convert_element_type,
+    scaled_mul,
+    scaled_slice,
+    scaled_transpose,
+)
+
+
+class ScaledTranslationPrimitivesTests(chex.TestCase):
+    def test__scaled_broadcast_in_dim__proper_scaling(self):
+        x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
+        z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,))
+        assert isinstance(z, ScaledArray)
+        npt.assert_array_equal(z.scale, x.scale)
+        npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1)))
+
+    def test__scaled_convert_element_type__proper_scaling(self):
+        x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
+        z = scaled_convert_element_type(x, new_dtype=np.float16)
+        assert isinstance(z, ScaledArray)
+        npt.assert_array_equal(z.scale, x.scale)
+        npt.assert_array_almost_equal(z.data, x.data.astype(z.dtype))
+
+    def test__scaled_transpose__proper_scaling(self):
+        x = scaled_array(np.random.rand(3, 5), 2, dtype=np.float32)
+        z = scaled_transpose(x, (1, 0))
+        assert isinstance(z, ScaledArray)
+        assert z.scale == x.scale
+        npt.assert_array_almost_equal(z.data, x.data.T)
+
+    def test__scaled_slice__proper_scaling(self):
+        x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
+        z = scaled_slice(x, (1,), (4,), (2,))
+        assert isinstance(z, ScaledArray)
+        assert z.scale == x.scale
+        npt.assert_array_almost_equal(z.data, x.data[1:4:2])
+
+    def test__scaled_mul__proper_scaling(self):
+        x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
+        y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
+        z = scaled_mul(x, y)
+        assert isinstance(z, ScaledArray)
+        assert z.scale == 6
+        npt.assert_array_almost_equal(z, np.asarray(x) * np.asarray(y))
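
For context on how the pieces above compose, a minimal end-to-end sketch (not part of the patch itself). The `scaled_add` translation is hypothetical, written only to illustrate the `scaled_{prim_name}` naming convention picked up by `register_scaled_lax_op`; `scaled_array` and `autoscale` are used exactly as exercised in the tests above.

    import jax.numpy as jnp
    import numpy as np

    from jax_scaled_arithmetics import core
    from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array


    @core.register_scaled_lax_op
    def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
        # Hypothetical translation of `jax.lax.add_p` (for illustration only):
        # renormalize both operands to a common scale before adding the data.
        scale = jnp.maximum(A.scale, B.scale)
        return ScaledArray(A.data * (A.scale / scale) + B.data * (B.scale / scale), scale)


    def func(x, y):
        return (x + y) * y


    # `autoscale` traces `func` and dispatches every primitive of the jaxpr
    # (`add_p`, `mul_p`) to its registered scaled translation, now forwarding
    # primitive attributes via `**eqn.params`.
    x = scaled_array([1.0, 2.0], 0.5, dtype=np.float32)
    y = scaled_array([-1.0, 0.5], 2.0, dtype=np.float32)
    out = autoscale(func)(x, y)
    assert isinstance(out, ScaledArray)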