From 23a74b142dc5bbd680de5ce81171282e80149064 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Tue, 9 Jan 2024 15:33:13 +0000
Subject: [PATCH] Improve robustness of scaled translation with zero scale
 inputs. (#75)

Adding `safe_div` and `safe_reciprocal` in order to avoid generating `inf`
and `nan` when input scales are zero.
---
 jax_scaled_arithmetics/core/__init__.py    |  2 +-
 jax_scaled_arithmetics/core/utils.py       | 25 ++++++++++++
 .../lax/base_scaling_primitives.py         | 11 +++--
 .../lax/scaled_ops_common.py               |  7 ++--
 jax_scaled_arithmetics/lax/scaled_ops_l2.py|  7 ++--
 jax_scaled_arithmetics/ops/rescaling.py    |  2 +-
 tests/core/test_utils.py                   | 40 ++++++++++++++++++-
 tests/lax/test_base_scaling_primitives.py  | 16 ++++++++
 tests/lax/test_scaled_ops_common.py        |  9 +++++
 tests/lax/test_scaled_ops_l2.py            | 22 ++++++++++
 10 files changed, 128 insertions(+), 13 deletions(-)

diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py
index f99bf1f..fc1277f 100644
--- a/jax_scaled_arithmetics/core/__init__.py
+++ b/jax_scaled_arithmetics/core/__init__.py
@@ -23,4 +23,4 @@
     register_scaled_op,
 )
 from .typing import Array, ArrayTypes, get_numpy_api  # noqa: F401
-from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up  # noqa: F401
+from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up, safe_div, safe_reciprocal  # noqa: F401
diff --git a/jax_scaled_arithmetics/core/utils.py b/jax_scaled_arithmetics/core/utils.py
index 32552db..9e8ef96 100644
--- a/jax_scaled_arithmetics/core/utils.py
+++ b/jax_scaled_arithmetics/core/utils.py
@@ -2,6 +2,8 @@
 from enum import IntEnum
 from typing import Any, Dict

+import jax
+import jax.numpy as jnp
 import numpy as np
 from numpy.typing import NDArray

@@ -77,3 +79,26 @@ def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
     elif mode == Pow2RoundMode.UP:
         return pow2_round_up(val)
     raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")
+
+
+def safe_div(lhs: Array, rhs: Array) -> Array:
+    """Safe (scalar) div: if rhs is zero, returns zero."""
+    assert lhs.shape == ()
+    assert rhs.shape == ()
+    # assert lhs.dtype == rhs.dtype
+    # NumPy inputs => direct computation.
+    is_npy_inputs = isinstance(lhs, (np.number, np.ndarray)) and isinstance(rhs, (np.number, np.ndarray))
+    if is_npy_inputs:
+        return np.divide(lhs, rhs, out=np.array(0, dtype=rhs.dtype), where=rhs != 0)
+    # JAX general implementation.
+    return jax.lax.select(rhs == 0, rhs, jnp.divide(lhs, rhs))
+
+
+def safe_reciprocal(val: Array) -> Array:
+    """Safe (scalar) reciprocal: if val is zero, returns zero."""
+    assert val.shape == ()
+    # NumPy inputs => direct computation.
+    if isinstance(val, (np.number, np.ndarray)):
+        return np.reciprocal(val, out=np.array(0, dtype=val.dtype), where=val != 0)
+    # JAX general implementation.
+    return jax.lax.select(val == 0, val, jax.lax.reciprocal(val))
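
[Reviewer note] A minimal sketch of the semantics the two helpers guarantee, mirroring the
unit tests added below; the `jax.jit` wrapping here is illustrative, not part of the patch:

import jax
import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import safe_div, safe_reciprocal

# NumPy scalar inputs: a zero denominator returns zero instead of
# producing `inf`/`nan` (the `where` mask also suppresses the warning).
assert safe_div(np.float32(4), np.float32(0)) == 0
assert safe_reciprocal(np.float16(0)) == 0

# JAX scalar inputs: same semantics, and traceable under `jit`.
out = jax.jit(safe_div)(jnp.float32(2), jnp.float32(0))
assert out == 0
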
diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
index ace19b4..8aa2466 100644
--- a/jax_scaled_arithmetics/lax/base_scaling_primitives.py
+++ b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
@@ -14,6 +14,8 @@
     asarray,
     is_static_one_scalar,
     register_scaled_op,
+    safe_div,
+    safe_reciprocal,
 )

 set_scaling_p = core.Primitive("set_scaling_p")
@@ -24,6 +26,9 @@

 In JAX Scaled Arithmetics/AutoScale mode, it will rebalance the data term to
 return a semantically equivalent ScaledArray.
+
+NOTE: there is a specific corner case when passing a zero scale to
+`set_scaling`: in this situation, the tensor is assumed to be zeroed by the user.
 """

@@ -46,7 +51,7 @@ def set_scaling_impl(values: Array, scale: Array) -> Array:
         # Automatic promotion should ensure we always get a scaled scalar here!
         scale_value = asarray(scale)
         # Rebalancing data tensor using the new scale.
-        data = values.data * (values.scale / scale_value).astype(values.dtype)
+        data = values.data * safe_div(values.scale, scale_value).astype(values.dtype)
         return ScaledArray(data, scale_value)
     # No scaled array => no-op.
     return values
@@ -75,9 +80,9 @@ def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray:
     scale_value = asarray(scale)
     if not isinstance(values, ScaledArray):
         # Simple case, with no pre-existing scale.
-        return ScaledArray(values / scale_value.astype(values.dtype), scale_value)
+        return ScaledArray(values * safe_reciprocal(scale_value.astype(values.dtype)), scale_value)
     # Rebalancing data tensor using the new scale.
-    data = values.data * (values.scale / scale_value).astype(values.dtype)
+    data = values.data * safe_div(values.scale, scale_value).astype(values.dtype)
     return ScaledArray(data, scale_value)
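
[Reviewer note] To make the `set_scaling` zero-scale corner case concrete, a small sketch
mirroring the new unit test below; it assumes the `scaled_array`, `autoscale` and
`set_scaling` entry points used throughout the test suite:

import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import autoscale, scaled_array
from jax_scaled_arithmetics.lax import set_scaling

def fn(arr, scale):
    return set_scaling(arr, scale)

# Setting a zero scale on a scaled array: the data term is rebalanced with
# safe_div(1.0, 0.0) == 0, i.e. the output is fully zeroed instead of `inf`.
arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float16)
out = autoscale(fn)(arr, np.float16(0))
npt.assert_array_almost_equal(out.scale, 0)
npt.assert_array_almost_equal(out.data, 0)
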
diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py
index 1dde3c8..91f0f5f 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops_common.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -16,6 +16,7 @@
     as_scaled_array,
     get_scale_dtype,
     is_static_zero,
+    safe_div,
 )

 from .base_scaling_primitives import scaled_set_scaling
@@ -89,7 +90,7 @@ def scaled_concatenate(operands: Sequence[ScaledArray], dimension: int) -> Scale
     # TODO: explore alternative strategies?
     outdtype = operands[0].dtype
     scale_max = jnp.max(scales)
-    datas = [v.data * (v.scale / scale_max).astype(outdtype) for v in operands]
+    datas = [v.data * safe_div(v.scale, scale_max).astype(outdtype) for v in operands]
     data_concat = lax.concatenate(datas, dimension=dimension)
     return ScaledArray(data_concat, scale_max)
@@ -219,8 +220,8 @@ def scaled_minmax(prim: jax.core.Primitive, lhs: ScaledArray, rhs: ScaledArray)
     output_scale = lax.max(lhs.scale, rhs.scale)
     # TODO: isolate this "binary" rescale logic into separate function.
     outdtype = jnp.promote_types(lhs.dtype, rhs.dtype)
-    lhs_rescale = (lhs.scale / output_scale).astype(outdtype)
-    rhs_rescale = (rhs.scale / output_scale).astype(outdtype)
+    lhs_rescale = safe_div(lhs.scale, output_scale).astype(outdtype)
+    rhs_rescale = safe_div(rhs.scale, output_scale).astype(outdtype)
     output_data = prim.bind(lhs_rescale * lhs.data, rhs_rescale * rhs.data)
     return ScaledArray(output_data, output_scale)
diff --git a/jax_scaled_arithmetics/lax/scaled_ops_l2.py b/jax_scaled_arithmetics/lax/scaled_ops_l2.py
index 2af3c90..c476a16 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops_l2.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops_l2.py
@@ -14,6 +14,7 @@
     get_autoscale_config,
     pow2_round,
     register_scaled_op,
+    safe_div,
 )

 from .scaled_ops_common import check_scalar_scales, promote_scale_types
@@ -32,14 +33,14 @@ def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArra
     # More stable than direct L2 norm, to avoid scale overflow.
     ABscale_max = lax.max(A.scale, B.scale)
     ABscale_min = lax.min(A.scale, B.scale)
-    ABscale_ratio = ABscale_min / ABscale_max
+    ABscale_ratio = safe_div(ABscale_min, ABscale_max)
     output_scale = ABscale_max * lax.sqrt(1 + ABscale_ratio * ABscale_ratio)
     # Transform back to power-of-2
     output_scale = pow2_round(output_scale, pow2_rounding_mode)
     # Output dtype => promotion of A and B dtypes.
     outdtype = jnp.promote_types(A.dtype, B.dtype)
-    Arescale = (A.scale / output_scale).astype(outdtype)
-    Brescale = (B.scale / output_scale).astype(outdtype)
+    Arescale = safe_div(A.scale, output_scale).astype(outdtype)
+    Brescale = safe_div(B.scale, output_scale).astype(outdtype)
     # check correct type output if mismatch between data and scale precision
     output_data = binary_op(Arescale * A.data, Brescale * B.data)
     return ScaledArray(output_data, output_scale)
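
[Reviewer note] On the `scaled_add_sub` change: the output scale is the usual overflow-safe
"hypot" combination, max * sqrt(1 + (min/max)^2) == sqrt(a^2 + b^2), and `safe_div` now turns
the two-zero-scales case into a zero ratio (hence a zero output scale) rather than `nan`.
A quick NumPy check of the identity:

import numpy as np

a, b = np.float32(3.0), np.float32(4.0)
ratio = min(a, b) / max(a, b)
# The large scale is never squared on its own, avoiding overflow in low precision.
out_scale = max(a, b) * np.sqrt(1 + ratio * ratio)
np.testing.assert_allclose(out_scale, np.hypot(a, b))  # sqrt(3^2 + 4^2) == 5
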
diff --git a/jax_scaled_arithmetics/ops/rescaling.py b/jax_scaled_arithmetics/ops/rescaling.py
index 1d3dd26..cd7687b 100644
--- a/jax_scaled_arithmetics/ops/rescaling.py
+++ b/jax_scaled_arithmetics/ops/rescaling.py
@@ -78,7 +78,7 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.integer_pow(data.astype(np.float32), 2)
     axes = tuple(range(data.ndim))
     # Get L2 norm + pow2 rounding.
-    norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes)) / data.size
+    norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py
index 8a3932b..e174a34 100644
--- a/tests/core/test_utils.py
+++ b/tests/core/test_utils.py
@@ -1,11 +1,12 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 import chex
+import jax.numpy as jnp
 import numpy as np
 import numpy.testing as npt
 from absl.testing import parameterized

-from jax_scaled_arithmetics.core import pow2_round_down, pow2_round_up
-from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa
+from jax_scaled_arithmetics.core import Array, pow2_round_down, pow2_round_up
+from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa, safe_div, safe_reciprocal


 class Pow2RoundingUtilTests(chex.TestCase):
@@ -54,3 +55,38 @@ def test__pow2_round_up__proper_rounding__multi_dtypes(self, val_exp, dtype):
         assert pow2_val.dtype == val.dtype
         assert type(pow2_val) in {type(val), np.ndarray}
         npt.assert_equal(pow2_val, exp)
+
+
+class SafeDivOpTests(chex.TestCase):
+    @parameterized.parameters(
+        {"lhs": np.float16(0), "rhs": np.float16(0)},
+        {"lhs": np.float32(0), "rhs": np.float32(0)},
+        {"lhs": np.float16(2), "rhs": np.float16(0)},
+        {"lhs": np.float32(4), "rhs": np.float32(0)},
+    )
+    def test__safe_div__zero_div__numpy_inputs(self, lhs, rhs):
+        out = safe_div(lhs, rhs)
+        assert isinstance(out, (np.number, np.ndarray))
+        assert out.dtype == lhs.dtype
+        npt.assert_equal(out, 0)
+
+    @parameterized.parameters(
+        {"lhs": np.float16(0), "rhs": jnp.float16(0)},
+        {"lhs": jnp.float32(0), "rhs": np.float32(0)},
+        {"lhs": jnp.float16(2), "rhs": np.float16(0)},
+        {"lhs": np.float32(4), "rhs": jnp.float32(0)},
+    )
+    def test__safe_div__zero_div__jax_inputs(self, lhs, rhs):
+        out = safe_div(lhs, rhs)
+        assert isinstance(out, Array)
+        assert out.dtype == lhs.dtype
+        npt.assert_almost_equal(out, 0)
+
+    @parameterized.parameters(
+        {"val": np.float16(0)},
+        {"val": jnp.float16(0)},
+    )
+    def test__safe_reciprocal__zero_div(self, val):
+        out = safe_reciprocal(val)
+        assert out.dtype == val.dtype
+        npt.assert_almost_equal(out, 0)
diff --git a/tests/lax/test_base_scaling_primitives.py b/tests/lax/test_base_scaling_primitives.py
index a13bda1..af8e401 100644
--- a/tests/lax/test_base_scaling_primitives.py
+++ b/tests/lax/test_base_scaling_primitives.py
@@ -31,6 +31,22 @@ def test__set_scaling_primitive__scaled_array__eager_mode(self, npapi):
         npt.assert_equal(output.scale, npapi.float16(4))
         npt.assert_array_equal(output, values)

+    @chex.variants(with_jit=True, without_jit=True)
+    @parameterized.parameters(
+        {"arr": np.array([-1.0, 2.0], dtype=np.float32)},
+        {"arr": scaled_array([-1.0, 2.0], 1.0, dtype=np.float16)},
+        {"arr": scaled_array([-1.0, 2.0], 0.0, dtype=np.float32)},
+    )
+    def test__set_scaling_primitive__zero_scaling(self, arr):
+        def fn(arr, scale):
+            return set_scaling(arr, scale)
+
+        scale = np.array(0, dtype=arr.dtype)
+        out = self.variant(autoscale(fn))(arr, scale)
+        assert isinstance(out, ScaledArray)
+        npt.assert_array_almost_equal(out.scale, 0)
+        npt.assert_array_almost_equal(out.data, 0)
+
     @chex.variants(with_jit=True, without_jit=True)
     def test__set_scaling_primitive__proper_result_without_autoscale(self):
         def fn(arr, scale):
diff --git a/tests/lax/test_scaled_ops_common.py b/tests/lax/test_scaled_ops_common.py
index ec83711..aeabb90 100644
--- a/tests/lax/test_scaled_ops_common.py
+++ b/tests/lax/test_scaled_ops_common.py
@@ -81,6 +81,15 @@ def test__scaled_concatenate__proper_scaling(self):
         npt.assert_array_equal(z.scale, y.scale)
         npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0))

+    def test__scaled_concatenate__zero_input_scales(self):
+        x = scaled_array(self.rs.rand(2, 3), 0.0, dtype=np.float16)
+        y = scaled_array(self.rs.rand(5, 3), 0.0, dtype=np.float16)
+        z = scaled_concatenate([x, y], dimension=0)
+        assert isinstance(z, ScaledArray)
+        assert z.dtype == x.dtype
+        npt.assert_array_equal(z.scale, 0)
+        npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0))
+
     def test__scaled_convert_element_type__proper_scaling(self):
         x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
         z = scaled_convert_element_type(x, new_dtype=np.float16)
diff --git a/tests/lax/test_scaled_ops_l2.py b/tests/lax/test_scaled_ops_l2.py
index 7641b98..3d9bf90 100644
--- a/tests/lax/test_scaled_ops_l2.py
+++ b/tests/lax/test_scaled_ops_l2.py
@@ -92,6 +92,28 @@ def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype, sdtyp
         assert z.scale.dtype == sdtype
         npt.assert_array_almost_equal(z, expected_z, decimal=4)

+    @chex.variants(with_jit=True, without_jit=True)
+    @parameterized.product(
+        prim=[lax.add_p, lax.sub_p, lax.mul_p, lax.min_p, lax.max_p],
+        dtype=[np.float16, np.float32],
+        sdtype=[np.float16, np.float32],
+    )
+    def test__scaled_binary_op__proper_zero_scale_handling(self, prim, dtype, sdtype):
+        scaled_op, _ = find_registered_scaled_op(prim)
+        # NOTE: direct construction to avoid weirdness between NumPy arrays and scalars!
+        x = ScaledArray(np.array([-1.0, 2.0], dtype), sdtype(0.0))
+        y = ScaledArray(np.array([1.5, 4.5], dtype), sdtype(0.0))
+        # Ensure scale factor has the right dtype.
+        assert x.scale.dtype == sdtype
+        assert y.scale.dtype == sdtype
+
+        z = self.variant(scaled_op)(x, y)
+        expected_z = prim.bind(np.asarray(x), np.asarray(y))
+
+        assert z.dtype == x.dtype
+        assert z.scale.dtype == sdtype
+        npt.assert_array_almost_equal(z, expected_z, decimal=4)
+
     @parameterized.parameters(
         {"prim": lax.add_p},
         {"prim": lax.sub_p},