Skip to content

Commit

Permalink
Improve robustness of scaled translation with zero scale inputs. (#75)
Browse files Browse the repository at this point in the history
Adding `safe_div` and `safe_reciprocal` in order to avoid generating
`inf` and `nan` when input scales are zero.
  • Loading branch information
balancap authored Jan 9, 2024
1 parent 8949497 commit 23a74b1
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 13 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
register_scaled_op,
)
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up, safe_div, safe_reciprocal # noqa: F401
25 changes: 25 additions & 0 deletions jax_scaled_arithmetics/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from enum import IntEnum
from typing import Any, Dict

import jax
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray

Expand Down Expand Up @@ -77,3 +79,26 @@ def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
elif mode == Pow2RoundMode.UP:
return pow2_round_up(val)
raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")


def safe_div(lhs: Array, rhs: Array) -> Array:
    """Safe (scalar) division: returns zero when `rhs` is zero.

    Avoids generating `inf`/`nan` when dividing by a zero scale.

    Args:
        lhs: Scalar numerator (shape `()`), NumPy or JAX.
        rhs: Scalar denominator (shape `()`), NumPy or JAX.

    Returns:
        `lhs / rhs`, or zero if `rhs == 0`. When both inputs are NumPy values the
        result is a NumPy 0-d array; otherwise the JAX implementation is used.
        In both cases the result uses the promoted dtype of `lhs` and `rhs`
        (which is the common dtype when the two already match).
    """
    assert lhs.shape == ()
    assert rhs.shape == ()
    # Numpy inputs => direct computation.
    is_npy_inputs = isinstance(lhs, (np.number, np.ndarray)) and isinstance(rhs, (np.number, np.ndarray))
    if is_npy_inputs:
        # Pre-filled zero output: `where` leaves it untouched when rhs == 0, so no
        # division (and no warning) happens. Promoted dtype avoids silently
        # downcasting the quotient to `rhs.dtype` on mixed-precision inputs.
        out = np.zeros((), dtype=np.promote_types(lhs.dtype, rhs.dtype))
        return np.divide(lhs, rhs, out=out, where=rhs != 0)
    # JAX general implementation.
    quotient = jnp.divide(lhs, rhs)
    # Select against a fresh zero of the quotient dtype (rather than `rhs` itself):
    # `lax.select` requires both branches to share a dtype, and this also yields
    # +0.0 for a -0.0 divisor, consistently with the NumPy path.
    return jax.lax.select(rhs == 0, jnp.zeros_like(quotient), quotient)


def safe_reciprocal(val: Array) -> Array:
    """Safe (scalar) reciprocal: returns zero when `val` is zero.

    Avoids generating `inf` when inverting a zero scale.

    Args:
        val: Scalar input (shape `()`), NumPy or JAX.

    Returns:
        `1 / val`, or zero if `val == 0`, with the same dtype as `val`. NumPy
        inputs produce a NumPy 0-d array; other inputs use the JAX implementation.
    """
    assert val.shape == ()
    # Numpy inputs => direct computation.
    if isinstance(val, (np.number, np.ndarray)):
        # Pre-filled zero output, left untouched by `where` when val == 0.
        return np.reciprocal(val, out=np.array(0, dtype=val.dtype), where=val != 0)
    # JAX general implementation.
    inverse = jax.lax.reciprocal(val)
    # Select a fresh zero rather than `val` itself, so that a -0.0 input gives
    # +0.0, consistently with the NumPy path.
    return jax.lax.select(val == 0, jax.lax.full_like(inverse, 0), inverse)
11 changes: 8 additions & 3 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
asarray,
is_static_one_scalar,
register_scaled_op,
safe_div,
safe_reciprocal,
)

set_scaling_p = core.Primitive("set_scaling_p")
Expand All @@ -24,6 +26,9 @@
In JAX Scaled Arithmetics/AutoScale mode, it will rebalance the data term to
return a ScaledArray semantically equivalent.
NOTE: there is a specific corner case when passing zero to `set_scaling`. In this
situation, the tensor is assumed to be zeroed by the user.
"""


Expand All @@ -46,7 +51,7 @@ def set_scaling_impl(values: Array, scale: Array) -> Array:
# Automatic promotion should ensure we always get a scaled scalar here!
scale_value = asarray(scale)
# Rebalancing data tensor using the new scale.
data = values.data * (values.scale / scale_value).astype(values.dtype)
data = values.data * safe_div(values.scale, scale_value).astype(values.dtype)
return ScaledArray(data, scale_value)
# No scaled array => no-op.
return values
Expand Down Expand Up @@ -75,9 +80,9 @@ def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray:
scale_value = asarray(scale)
if not isinstance(values, ScaledArray):
# Simple case, with no pre-existing scale.
return ScaledArray(values / scale_value.astype(values.dtype), scale_value)
return ScaledArray(values * safe_reciprocal(scale_value.astype(values.dtype)), scale_value)
# Rebalancing data tensor using the new scale.
data = values.data * (values.scale / scale_value).astype(values.dtype)
data = values.data * safe_div(values.scale, scale_value).astype(values.dtype)
return ScaledArray(data, scale_value)


Expand Down
7 changes: 4 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
as_scaled_array,
get_scale_dtype,
is_static_zero,
safe_div,
)

from .base_scaling_primitives import scaled_set_scaling
Expand Down Expand Up @@ -89,7 +90,7 @@ def scaled_concatenate(operands: Sequence[ScaledArray], dimension: int) -> Scale
# TODO: explore alternative strategies?
outdtype = operands[0].dtype
scale_max = jnp.max(scales)
datas = [v.data * (v.scale / scale_max).astype(outdtype) for v in operands]
datas = [v.data * safe_div(v.scale, scale_max).astype(outdtype) for v in operands]
data_concat = lax.concatenate(datas, dimension=dimension)
return ScaledArray(data_concat, scale_max)

Expand Down Expand Up @@ -219,8 +220,8 @@ def scaled_minmax(prim: jax.core.Primitive, lhs: ScaledArray, rhs: ScaledArray)
output_scale = lax.max(lhs.scale, rhs.scale)
# TODO: isolate this "binary" rescale logic into separate function.
outdtype = jnp.promote_types(lhs.dtype, rhs.dtype)
lhs_rescale = (lhs.scale / output_scale).astype(outdtype)
rhs_rescale = (rhs.scale / output_scale).astype(outdtype)
lhs_rescale = safe_div(lhs.scale, output_scale).astype(outdtype)
rhs_rescale = safe_div(rhs.scale, output_scale).astype(outdtype)
output_data = prim.bind(lhs_rescale * lhs.data, rhs_rescale * rhs.data)
return ScaledArray(output_data, output_scale)

Expand Down
7 changes: 4 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_autoscale_config,
pow2_round,
register_scaled_op,
safe_div,
)

from .scaled_ops_common import check_scalar_scales, promote_scale_types
Expand All @@ -32,14 +33,14 @@ def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArra
# More stable than direct L2 norm, to avoid scale overflow.
ABscale_max = lax.max(A.scale, B.scale)
ABscale_min = lax.min(A.scale, B.scale)
ABscale_ratio = ABscale_min / ABscale_max
ABscale_ratio = safe_div(ABscale_min, ABscale_max)
output_scale = ABscale_max * lax.sqrt(1 + ABscale_ratio * ABscale_ratio)
# Transform back to power-of-2
output_scale = pow2_round(output_scale, pow2_rounding_mode)
# Output dtype => promotion of A and B dtypes.
outdtype = jnp.promote_types(A.dtype, B.dtype)
Arescale = (A.scale / output_scale).astype(outdtype)
Brescale = (B.scale / output_scale).astype(outdtype)
Arescale = safe_div(A.scale, output_scale).astype(outdtype)
Brescale = safe_div(B.scale, output_scale).astype(outdtype)
# check correct type output if mismatch between data and scale precision
output_data = binary_op(Arescale * A.data, Brescale * B.data)
return ScaledArray(output_data, output_scale)
Expand Down
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/ops/rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray:
data_sq = jax.lax.integer_pow(data.astype(np.float32), 2)
axes = tuple(range(data.ndim))
# Get L2 norm + pow2 rounding.
norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes)) / data.size
norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size)
norm = pow2_round(norm.astype(scale.dtype))
# Rebalancing based on norm.
return rebalance(arr, norm)
Expand Down
40 changes: 38 additions & 2 deletions tests/core/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import pow2_round_down, pow2_round_up
from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa
from jax_scaled_arithmetics.core import Array, pow2_round_down, pow2_round_up
from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa, safe_div, safe_reciprocal


class Pow2RoundingUtilTests(chex.TestCase):
Expand Down Expand Up @@ -54,3 +55,38 @@ def test__pow2_round_up__proper_rounding__multi_dtypes(self, val_exp, dtype):
assert pow2_val.dtype == val.dtype
assert type(pow2_val) in {type(val), np.ndarray}
npt.assert_equal(pow2_val, exp)


class SafeDivOpTests(chex.TestCase):
    """Zero-input checks for the `safe_div` and `safe_reciprocal` utilities."""

    @parameterized.parameters(
        {"lhs": np.float16(0), "rhs": np.float16(0)},
        {"lhs": np.float32(0), "rhs": np.float32(0)},
        {"lhs": np.float16(2), "rhs": np.float16(0)},
        {"lhs": np.float32(4), "rhs": np.float32(0)},
    )
    def test__safe_div__zero_div__numpy_inputs(self, lhs, rhs):
        """NumPy scalars with a zero divisor: NumPy output, `lhs` dtype, value zero."""
        out = safe_div(lhs, rhs)
        assert isinstance(out, (np.number, np.ndarray))
        assert out.dtype == lhs.dtype
        npt.assert_equal(out, 0)

    @parameterized.parameters(
        {"lhs": np.float16(0), "rhs": jnp.float16(0)},
        {"lhs": jnp.float32(0), "rhs": np.float32(0)},
        {"lhs": jnp.float16(2), "rhs": np.float16(0)},
        {"lhs": np.float32(4), "rhs": jnp.float32(0)},
    )
    def test__safe_div__zero_div__jax_inputs(self, lhs, rhs):
        """Mixed NumPy/JAX scalars with a zero divisor: `lhs` dtype, value zero."""
        out = safe_div(lhs, rhs)
        assert isinstance(out, Array)
        assert out.dtype == lhs.dtype
        npt.assert_almost_equal(out, 0)

    @parameterized.parameters(
        {"val": np.float16(0)},
        {"val": jnp.float16(0)},
    )
    def test__safe_reciprocal__zero_div(self, val):
        """Reciprocal of a zero NumPy or JAX scalar: input dtype kept, value zero."""
        out = safe_reciprocal(val)
        assert out.dtype == val.dtype
        npt.assert_almost_equal(out, 0)
16 changes: 16 additions & 0 deletions tests/lax/test_base_scaling_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ def test__set_scaling_primitive__scaled_array__eager_mode(self, npapi):
npt.assert_equal(output.scale, npapi.float16(4))
npt.assert_array_equal(output, values)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
    {"arr": np.array([-1.0, 2.0], dtype=np.float32)},
    {"arr": scaled_array([-1.0, 2.0], 1.0, dtype=np.float16)},
    {"arr": scaled_array([-1.0, 2.0], 0.0, dtype=np.float32)},
)
def test__set_scaling_primitive__zero_scaling(self, arr):
    """`set_scaling` with a zero scale under `autoscale`: zero scale and zero data.

    Parameterized over a plain NumPy array and scaled arrays built with scale
    1.0 and 0.0.
    """
    def fn(arr, scale):
        return set_scaling(arr, scale)

    # Zero scale, using the same dtype as the input array.
    scale = np.array(0, dtype=arr.dtype)
    out = self.variant(autoscale(fn))(arr, scale)
    assert isinstance(out, ScaledArray)
    # Both the scale and the data are expected to be (almost) zero: no inf/nan.
    npt.assert_array_almost_equal(out.scale, 0)
    npt.assert_array_almost_equal(out.data, 0)

@chex.variants(with_jit=True, without_jit=True)
def test__set_scaling_primitive__proper_result_without_autoscale(self):
def fn(arr, scale):
Expand Down
9 changes: 9 additions & 0 deletions tests/lax/test_scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def test__scaled_concatenate__proper_scaling(self):
npt.assert_array_equal(z.scale, y.scale)
npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0))

def test__scaled_concatenate__zero_input_scales(self):
    """`scaled_concatenate` of two inputs built with scale 0.0: zero output scale."""
    x = scaled_array(self.rs.rand(2, 3), 0.0, dtype=np.float16)
    y = scaled_array(self.rs.rand(5, 3), 0.0, dtype=np.float16)
    z = scaled_concatenate([x, y], dimension=0)
    assert isinstance(z, ScaledArray)
    assert z.dtype == x.dtype
    npt.assert_array_equal(z.scale, 0)
    # Output must match the concatenation of the inputs converted with `np.asarray`.
    npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0))

def test__scaled_convert_element_type__proper_scaling(self):
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
z = scaled_convert_element_type(x, new_dtype=np.float16)
Expand Down
22 changes: 22 additions & 0 deletions tests/lax/test_scaled_ops_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,28 @@ def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype, sdtyp
assert z.scale.dtype == sdtype
npt.assert_array_almost_equal(z, expected_z, decimal=4)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
    prim=[lax.add_p, lax.sub_p, lax.mul_p, lax.min_p, lax.max_p],
    dtype=[np.float16, np.float32],
    sdtype=[np.float16, np.float32],
)
def test__scaled_binary_op__proper_zero_scale_handling(self, prim, dtype, sdtype):
    """Scaled binary ops on two zero-scale inputs match the unscaled primitive.

    Runs over every combination of primitive, data dtype and scale dtype, and
    checks the output data dtype, scale dtype and values.
    """
    scaled_op, _ = find_registered_scaled_op(prim)
    # NOTE: direct construction to avoid weirdness between NumPy array and scalar!
    x = ScaledArray(np.array([-1.0, 2.0], dtype), sdtype(0.0))
    y = ScaledArray(np.array([1.5, 4.5], dtype), sdtype(0.0))
    # Ensure scale factor has the right dtype.
    assert x.scale.dtype == sdtype
    assert y.scale.dtype == sdtype

    z = self.variant(scaled_op)(x, y)
    # Reference result: the same primitive applied to the `np.asarray` conversions.
    expected_z = prim.bind(np.asarray(x), np.asarray(y))

    assert z.dtype == x.dtype
    assert z.scale.dtype == sdtype
    npt.assert_array_almost_equal(z, expected_z, decimal=4)

@parameterized.parameters(
{"prim": lax.add_p},
{"prim": lax.sub_p},
Expand Down

0 comments on commit 23a74b1

Please sign in to comment.