From 07df3b6ad5837dad6e5554bc85d5c782d50818c8 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 16 Nov 2023 16:33:39 +0000 Subject: [PATCH] Implement `set_scaling` and `stop_scaling` JAX primitives. * `set_scaling`: Set the scaling of a tensor, transforming into `ScaledArray` in `autoscale` mode. * `stop_scaling`: Stop scale propagation of a tensor, transforming back into a JAX array. Both operations are no-op identity operations in normal JAX mode. Note: as pointed by @DouglasOrr, these primitives could also be formalized as casting operations, where ScaledDtypes are properly defined. To be clarified whether it may be a better setting for the JAX implementation! --- jax_scaled_arithmetics/core/interpreters.py | 4 +- jax_scaled_arithmetics/lax/__init__.py | 1 + .../lax/base_scaling_primitives.py | 115 ++++++++++++++++++ tests/lax/test_base_scaling_primitives.py | 68 +++++++++++ tests/lax/test_scaled_ops.py | 21 ++-- 5 files changed, 200 insertions(+), 9 deletions(-) create mode 100644 jax_scaled_arithmetics/lax/base_scaling_primitives.py create mode 100644 tests/lax/test_base_scaling_primitives.py diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 91ef749..6e46c13 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -113,7 +113,9 @@ def to_scaled_array(val): # Primitive is supported by `autoscale`? if eqn.primitive not in _scaled_ops_registry: - raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet") + raise NotImplementedError( + f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet." + ) outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params) if not eqn.primitive.multiple_results: outvals = [outvals] diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py index 65b52cc..148fe69 100644 --- a/jax_scaled_arithmetics/lax/__init__.py +++ b/jax_scaled_arithmetics/lax/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from .base_scaling_primitives import set_scaling, set_scaling_p, stop_scaling, stop_scaling_p # noqa: F401 from .scaled_ops import * # noqa: F401, F403 diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scaled_arithmetics/lax/base_scaling_primitives.py new file mode 100644 index 0000000..81f51d8 --- /dev/null +++ b/jax_scaled_arithmetics/lax/base_scaling_primitives.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import Optional, Sequence, Union + +import jax +from jax import core +from jax.interpreters import mlir +from jax.interpreters.mlir import LoweringRuleContext, ir + +from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, register_scaled_op + +set_scaling_p = core.Primitive("set_scaling_p") +"""`set_scaling` JAX primitive. + +In standard JAX, this is just an identity operation, ignoring the `scale` +input, just returning unchanged the `data` component. + +In JAX Scaled Arithmetics/AutoScale mode, it will rebalance the data term to +return a ScaledArray semantically equivalent. +""" + + +def set_scaling(values: jax.Array, scale: jax.Array) -> jax.Array: + """`set_scaling` primitive call method.""" + return set_scaling_p.bind(values, scale) + + +def set_scaling_impl(values: jax.Array, scale: jax.Array) -> jax.Array: + return values + + +def set_scaling_abstract_eval(values: core.ShapedArray, scale: core.ShapedArray) -> core.ShapedArray: + return values + + +def set_scaling_mlir_lowering( + ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]] +) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: + # Just forwarding `values` term, ignoring the `scale`. + return (args[0],) + + +def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray: + """Scaled `set_scaling` implementation: rebalancing the data using the new scale value.""" + assert isinstance(values, ScaledArray) + assert isinstance(scale, ScaledArray) + assert scale.shape == () + # TODO/FIXME: handle not scaled inputs!!! + scale_value = scale.to_array() + # Rebalancing data tensor using the new scale. + data = values.data * (values.scale / scale_value) + return ScaledArray(data, scale_value) + + +# Register as standard JAX primitive +set_scaling_p.multiple_results = False +set_scaling_p.def_abstract_eval(set_scaling_abstract_eval) +set_scaling_p.def_impl(set_scaling_impl) +mlir.register_lowering(set_scaling_p, set_scaling_mlir_lowering) +# Register "scaled" translation. +register_scaled_op(set_scaling_p, scaled_set_scaling) + + +stop_scaling_p = core.Primitive("stop_scaling_p") +"""`stop_scaling` JAX primitive. + +In standard JAX, this is just an identity operation (with optional casting). + +In JAX Scaled Arithmetics/AutoScale mode, it will return the value tensor, +with optional casting. + +Similar in principle to `jax.lax.stop_gradient` +""" + + +def stop_scaling(values: jax.Array, dtype: Optional[DTypeLike] = None) -> jax.Array: + """`stop_scaling` primitive call method.""" + return stop_scaling_p.bind(values, dtype=dtype) + + +def stop_scaling_impl(values: jax.Array, dtype: Optional[DTypeLike]) -> jax.Array: + if dtype is not None: + values = values.astype(dtype) + return values + + +def stop_scaling_abstract_eval(values: core.ShapedArray, dtype: Optional[DTypeLike]) -> core.ShapedArray: + return values.update(dtype=dtype) + + +def stop_scaling_mlir_lowering( + ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params +) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: + dtype = params.get("dtype", None) + if dtype is not None: + # TODO: caching of the MLIR lowered function? + stop_scaling_mlir_fn = mlir.lower_fun(lambda x: x.astype(dtype), multiple_results=False) + return stop_scaling_mlir_fn(ctx, *args) + # By default: forward tensor. + return (args[0],) + + +def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) -> jax.Array: + """Scaled `stop_scaling` implementation: returning tensor values (with optional cast).""" + assert isinstance(values, ScaledArray) + # TODO/FIXME: how to handle not scaled input. + return values.to_array(dtype=dtype) + + +# Register as standard JAX primitive +stop_scaling_p.multiple_results = False +stop_scaling_p.def_abstract_eval(stop_scaling_abstract_eval) +stop_scaling_p.def_impl(stop_scaling_impl) +mlir.register_lowering(stop_scaling_p, stop_scaling_mlir_lowering) +# Register "scaled" translation. +register_scaled_op(stop_scaling_p, scaled_stop_scaling) diff --git a/tests/lax/test_base_scaling_primitives.py b/tests/lax/test_base_scaling_primitives.py new file mode 100644 index 0000000..fa5f188 --- /dev/null +++ b/tests/lax/test_base_scaling_primitives.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import chex +import jax +import jax.numpy as jnp +import numpy as np +import numpy.testing as npt + +from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array +from jax_scaled_arithmetics.lax import set_scaling, stop_scaling + + +class SetScalingPrimitiveTests(chex.TestCase): + @chex.variants(with_jit=True, without_jit=True) + def test__set_scaling_primitive__proper_result_without_autoscale(self): + def fn(arr, scale): + return set_scaling(arr, scale) + + fn = self.variant(fn) + arr = jnp.array([2, 3], dtype=np.float32) + scale = jnp.array(4, dtype=np.float32) + out = fn(arr, scale) + npt.assert_array_equal(out, arr) + + @chex.variants(with_jit=True, without_jit=True) + def test__set_scaling_primitive__proper_result_with_autoscale(self): + def fn(arr, scale): + return set_scaling(arr, scale) + + fn = self.variant(autoscale(fn)) + arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float32) + # TODO: support scalar here! + scale = scaled_array(1.0, 4.0, dtype=np.float32) + out = fn(arr, scale) + # Unchanged output tensor! + assert isinstance(out, ScaledArray) + npt.assert_array_equal(out.scale, scale) + npt.assert_array_equal(out, arr) + + +class StopScalingPrimitiveTests(chex.TestCase): + @chex.variants(with_jit=True, without_jit=True) + def test__stop_scaling_primitive__proper_result_without_autoscale(self): + def fn(arr): + # Testing both variants. + return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) + + arr = jnp.array([2, 3], dtype=np.float32) + out0, out1 = self.variant(fn)(arr) + assert out0.dtype == arr.dtype + assert out1.dtype == np.float16 + npt.assert_array_equal(out0, arr) + npt.assert_array_almost_equal(out1, arr) + + @chex.variants(with_jit=True, without_jit=True) + def test__stop_scaling_primitive__proper_result_with_autoscale(self): + def fn(arr): + # Testing both variants. + return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) + + fn = self.variant(autoscale(fn)) + arr = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32) + out0, out1 = fn(arr) + assert isinstance(out0, jax.Array) + assert isinstance(out1, jax.Array) + assert out0.dtype == arr.dtype + assert out1.dtype == np.float16 + npt.assert_array_equal(out0, arr) + npt.assert_array_almost_equal(out1, arr) diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py index e2ad2f8..f733f05 100644 --- a/tests/lax/test_scaled_ops.py +++ b/tests/lax/test_scaled_ops.py @@ -18,37 +18,42 @@ class ScaledTranslationPrimitivesTests(chex.TestCase): + def setUp(self): + super().setUp() + # Use random state for reproducibility! + self.rs = np.random.RandomState(42) + def test__scaled_broadcast_in_dim__proper_scaling(self): - x = scaled_array(np.random.rand(5), 2, dtype=np.float32) + x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,)) assert isinstance(z, ScaledArray) npt.assert_array_equal(z.scale, x.scale) npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1))) def test__scaled_concatenate__proper_scaling(self): - x = scaled_array(np.random.rand(2, 3), 0.5, dtype=np.float32) - y = scaled_array(np.random.rand(5, 3), 2, dtype=np.float32) + x = scaled_array(self.rs.rand(2, 3), 0.5, dtype=np.float32) + y = scaled_array(self.rs.rand(5, 3), 2, dtype=np.float32) z = scaled_concatenate([x, y], dimension=0) assert isinstance(z, ScaledArray) npt.assert_array_equal(z.scale, y.scale) npt.assert_array_almost_equal(z, np.concatenate([x, y], axis=0)) def test__scaled_convert_element_type__proper_scaling(self): - x = scaled_array(np.random.rand(5), 2, dtype=np.float32) + x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) z = scaled_convert_element_type(x, new_dtype=np.float16) assert isinstance(z, ScaledArray) npt.assert_array_equal(z.scale, x.scale) npt.assert_array_almost_equal(z.data, x.data.astype(z.dtype)) def test__scaled_transpose__proper_scaling(self): - x = scaled_array(np.random.rand(3, 5), 2, dtype=np.float32) + x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32) z = scaled_transpose(x, (1, 0)) assert isinstance(z, ScaledArray) assert z.scale == x.scale npt.assert_array_almost_equal(z.data, x.data.T) def test__scaled_slice__proper_scaling(self): - x = scaled_array(np.random.rand(5), 2, dtype=np.float32) + x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) z = scaled_slice(x, (1,), (4,), (2,)) assert isinstance(z, ScaledArray) assert z.scale == x.scale @@ -81,8 +86,8 @@ def test__scaled_sub__proper_scaling(self): npt.assert_array_almost_equal(z, np.asarray(x) - np.asarray(y)) def test__scaled_dot_general__proper_scaling(self): - lhs = scaled_array(np.random.rand(3, 5), 2.0, dtype=np.float32) - rhs = scaled_array(np.random.rand(5, 2), 3.0, dtype=np.float32) + lhs = scaled_array(self.rs.rand(3, 5), 2.0, dtype=np.float32) + rhs = scaled_array(self.rs.rand(5, 2), 3.0, dtype=np.float32) out = scaled_dot_general(lhs, rhs, (((1,), (0,)), ((), ()))) assert isinstance(out, ScaledArray) assert out.dtype == lhs.dtype