Use power-of-two scaling in autoscale scaled translation ops rules.
As shown in issue #60, propagating non power-of-two scaling factors can decrease training accuracy in low precision (typically FP16): the additional rescaling operations introduce non-negligible accumulated floating-point errors.

This PR adds the option to round the scale to a power of two in scaled translations. Only rounding up and down are supported at the moment. The rounding mode can be modified in the config dataclass `AutoScaleConfig`.
balancap committed Jan 2, 2024
1 parent 591645c commit 2c02f66
Showing 9 changed files with 189 additions and 12 deletions.
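
Before diving into the diff, a minimal usage sketch of the new option (hedged: it assumes `autoscale` can be applied as a tracing decorator and that `scaled_array`, `AutoScaleConfig` and `Pow2RoundMode` are exported from `jax_scaled_arithmetics.core`, as the diffs below suggest):

import numpy as np
from jax_scaled_arithmetics.core import AutoScaleConfig, Pow2RoundMode, autoscale, scaled_array

@autoscale
def fn(a, b):
    return a + b

a = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32)
b = scaled_array([1.5, 4.5], 2.0, dtype=np.float32)

# Default config: the propagated output scale is rounded down to a power of two.
out = fn(a, b)

# The rounding mode can be changed locally with the config context manager.
with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE):
    out = fn(a, b)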
3 changes: 3 additions & 0 deletions jax_scaled_arithmetics/core/__init__.py
@@ -12,10 +12,13 @@
)
from .debug import debug_callback # noqa: F401
from .interpreters import ( # noqa: F401
AutoScaleConfig,
ScaledPrimitiveType,
autoscale,
find_registered_scaled_op,
get_autoscale_config,
register_scaled_lax_op,
register_scaled_op,
)
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
30 changes: 30 additions & 0 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, Sequence, Tuple
@@ -15,6 +16,35 @@
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode


@dataclass(frozen=True)
class AutoScaleConfig:
"""AutoScale configuration/parameters when tracing a graph.
NOTE: this config can be locally changed using a Python context manager:
`with AutoScaleConfig(...):`
"""

rounding_mode: Pow2RoundMode = Pow2RoundMode.DOWN

def __enter__(self):
global _autoscale_config_stack
_autoscale_config_stack.append(self)

def __exit__(self, exc_type, exc_val, exc_tb):
global _autoscale_config_stack
_autoscale_config_stack.pop()


# AutoScale config stack.
_autoscale_config_stack = [AutoScaleConfig()]


def get_autoscale_config() -> AutoScaleConfig:
"""Get current/local autoscale config."""
return _autoscale_config_stack[-1]


class ScaledPrimitiveType(IntEnum):
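
A short sketch of the intended context-manager semantics of the config stack above (using only the exports shown in `core/__init__.py`):

from jax_scaled_arithmetics.core import AutoScaleConfig, Pow2RoundMode, get_autoscale_config

# The stack starts with a single default entry.
assert get_autoscale_config().rounding_mode == Pow2RoundMode.DOWN

# Entering a config pushes it on the stack; nested configs shadow outer ones.
with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE):
    assert get_autoscale_config().rounding_mode == Pow2RoundMode.NONE
    with AutoScaleConfig(rounding_mode=Pow2RoundMode.UP):
        assert get_autoscale_config().rounding_mode == Pow2RoundMode.UP

# Exiting pops the config, restoring the default.
assert get_autoscale_config().rounding_mode == Pow2RoundMode.DOWN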
9 changes: 6 additions & 3 deletions jax_scaled_arithmetics/core/typing.py
@@ -19,10 +19,13 @@
def get_numpy_api(val: Any) -> Any:
"""Get the Numpy API corresponding to an array.
Using the NumPy API whenever possible when tracing a JAX graph
allows for simple constant folding optimization.
JAX or classic Numpy supported.
"""
if isinstance(val, jax.Array):
return jnp
elif isinstance(val, (np.ndarray, np.number)):
if isinstance(val, (np.ndarray, np.number)):
return np
if isinstance(val, ArrayTypes):
return jnp
raise NotImplementedError(f"Unsupported input type '{type(val)}'. No matching Numpy API.")
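
For illustration, a small sketch of the dispatch implemented by `get_numpy_api` (assuming it is re-exported from `jax_scaled_arithmetics.core` as in `core/__init__.py` above):

import jax.numpy as jnp
import numpy as np
from jax_scaled_arithmetics.core import get_numpy_api

# NumPy arrays and scalars map to the NumPy API, allowing constant folding while tracing.
assert get_numpy_api(np.float32(2.0)) is np
assert get_numpy_api(np.ones((2,))) is np
# JAX arrays map to the jax.numpy API.
assert get_numpy_api(jnp.ones((2,))) is jnp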
65 changes: 65 additions & 0 deletions jax_scaled_arithmetics/core/utils.py
@@ -0,0 +1,65 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from enum import IntEnum
from typing import Any, Dict

import numpy as np
from numpy.typing import NDArray

from .typing import Array, get_numpy_api

# Exponent bits masking.
_exponent_bits_mask: Dict[Any, NDArray[Any]] = {
np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view(
np.int16
),
np.dtype(np.float32): np.packbits(
np.array(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
dtype=np.uint8,
)
).view(np.int32),
np.dtype(np.float64): np.array(np.inf, np.float64).view(np.int64),
}
"""Exponents bit masking: explicit bitmask to keep only exponent bits in floating point values.
NOTE: normally should also correspond to `np.inf` value for FP16 and FP32.
"""


class Pow2RoundMode(IntEnum):
"""Power-of-two supported rounded mode."""

NONE = 0
DOWN = 1
UP = 2
STOCHASTIC = 3


def pow2_round_down(val: Array) -> Array:
"""Round down to the closest power of 2."""
np_api = get_numpy_api(val)
exponent_mask = _exponent_bits_mask[val.dtype]
intdtype = exponent_mask.dtype
pow2_val = np_api.bitwise_and(val.view(intdtype), exponent_mask).view(val.dtype).reshape(val.shape)
return pow2_val


def pow2_round_up(val: Array) -> Array:
"""Round up to the closest power of 2.
NOTE: may overflow to inf.
"""
# FIXME: rounding when already a power of 2.
# Should do additional masking to check that.
pow2_val = pow2_round_down(val) * np.array(2, dtype=val.dtype)
return pow2_val


def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
"""Power-of-two rounding."""
if mode == Pow2RoundMode.NONE:
return val
elif mode == Pow2RoundMode.DOWN:
return pow2_round_down(val)
elif mode == Pow2RoundMode.UP:
return pow2_round_up(val)
raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")
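
To make the exponent-masking trick above concrete, a plain NumPy sketch (independent of the library code):

import numpy as np

val = np.float32(5.25)
# Keep only the exponent bits: the float32 mask is the bit pattern of +inf (0x7F800000),
# exactly as in `_exponent_bits_mask` above.
mask = np.float32(np.inf).view(np.int32)
pow2_down = np.bitwise_and(val.view(np.int32), mask).view(np.float32)
print(pow2_down)       # 4.0: closest power of two below 5.25 (pow2_round_down).
print(pow2_down * 2)   # 8.0: pow2_round_up, which may overflow to inf near the dtype maximum.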
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -77,7 +77,7 @@ def scaled_reduce_precision(A: ScaledArray, exponent_bits: int, mantissa_bits: i
def scaled_concatenate(operands: Sequence[ScaledArray], dimension: int) -> ScaledArray:
# TODO: inputs checking (dtype and cie).
scales = jnp.array([v.scale for v in operands])
# Max rescaling of the collection of operands.
# Max rescaling of the collection of operands. Preserving pow2 scaling.
# TODO: explore alternative strategies?
outdtype = operands[0].dtype
scale_max = jnp.max(scales)
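
A rough sketch of the max-scale strategy used by `scaled_concatenate` (hedged: it assumes the op is importable from `jax_scaled_arithmetics.lax`, mirroring how scaled ops are used in the tests below):

import numpy as np
from jax_scaled_arithmetics.core import scaled_array
from jax_scaled_arithmetics.lax import scaled_concatenate  # assumed import path

a = scaled_array([1.0, 2.0], 2.0, dtype=np.float32)
b = scaled_array([3.0, 4.0], 8.0, dtype=np.float32)
out = scaled_concatenate([a, b], dimension=0)
# The output scale is the max of the input scales (8.0). Since both scales are powers
# of two, the per-operand rescaling factors (2/8 = 0.25 and 8/8 = 1) are exact in FP16/FP32.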
24 changes: 21 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
@@ -7,7 +7,14 @@
from jax._src.ad_util import add_any_p

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, as_scaled_array, register_scaled_op
from jax_scaled_arithmetics.core import (
DTypeLike,
ScaledArray,
as_scaled_array,
get_autoscale_config,
pow2_round,
register_scaled_op,
)

from .scaled_ops_common import check_scalar_scales, promote_scale_types

@@ -19,12 +26,16 @@ def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArra
check_scalar_scales(A, B)
A, B = promote_scale_types(A, B)
assert np.issubdtype(A.scale.dtype, np.floating)
# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
# TODO: what happens to `sqrt` for non-floating scale?
# More stable than direct L2 norm, to avoid scale overflow.
ABscale_max = lax.max(A.scale, B.scale)
ABscale_min = lax.min(A.scale, B.scale)
ABscale_ratio = ABscale_min / ABscale_max
output_scale = ABscale_max * lax.sqrt(1 + ABscale_ratio * ABscale_ratio)
# Transform back to power-of-2
output_scale = pow2_round(output_scale, pow2_rounding_mode)
# Output dtype => promotion of A and B dtypes.
outdtype = jnp.promote_types(A.dtype, B.dtype)
Arescale = (A.scale / output_scale).astype(outdtype)
@@ -63,10 +74,13 @@ def scaled_dot_general(
assert len(lhs_contracting_dims) == 1
assert len(rhs_contracting_dims) == 1

# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
contracting_dim_size = lhs.shape[lhs_contracting_dims[0]]
# "unit scaling" rule, based on the contracting axis.
outscale_dtype = jnp.promote_types(lhs.scale.dtype, rhs.scale.dtype)
contracting_rescale = np.sqrt(contracting_dim_size)
contracting_rescale = pow2_round(np.sqrt(contracting_dim_size), pow2_rounding_mode)
# Keeping power of 2 scale.
output_scale = lhs.scale * rhs.scale * contracting_rescale.astype(outscale_dtype)
# NOTE: need to be a bit careful about scale promotion?
output_data = lax.dot_general(
@@ -94,8 +108,11 @@ def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
axes_size = np.array([shape[idx] for idx in axes])
# Rescale data component following reduction axes.
# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
# Rescale data component following reduction axes & round to power of 2 value.
axes_rescale = np.sqrt(np.prod(axes_size))
axes_rescale = pow2_round(axes_rescale, pow2_rounding_mode)
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale.astype(val.data.dtype)
outscale = val.scale * axes_rescale.astype(val.scale.dtype)
return ScaledArray(data, outscale)
@@ -107,6 +124,7 @@ def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
shape = val.shape
data = lax.reduce_prod_p.bind(val.data, axes=axes)
axes_size = np.prod(np.array([shape[idx] for idx in axes]))
# Stable for power of 2.
scale = lax.integer_pow(val.scale, axes_size)
return ScaledArray(data, scale)

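
To see the effect of the new rounding on the add/sub rule, a worked example matching the updated `test__scaled_addsub__proper_scaling` below (default `Pow2RoundMode.DOWN` config):

import numpy as np

# A has scale 4.0, B has scale 2.0.
scale_max, scale_min = 4.0, 2.0
ratio = scale_min / scale_max                        # 0.5
output_scale = scale_max * np.sqrt(1 + ratio**2)     # 4 * sqrt(1.25) ~= 4.47 (stable L2 norm of scales)
# pow2_round(4.47, Pow2RoundMode.DOWN) -> 4.0, so the propagated scale stays a power of two.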
14 changes: 14 additions & 0 deletions tests/core/test_interpreter.py
@@ -9,9 +9,12 @@

from jax_scaled_arithmetics.core import (
Array,
AutoScaleConfig,
Pow2RoundMode,
ScaledArray,
asarray,
autoscale,
get_autoscale_config,
is_scaled_leaf,
register_scaled_op,
scaled_array,
@@ -229,3 +232,14 @@ def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input):
def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input):
out = promote_scalar_to_scaled_array(input)
assert out is input

def test__autoscale_config__default_values(self):
cfg = get_autoscale_config()
assert isinstance(cfg, AutoScaleConfig)
assert cfg.rounding_mode == Pow2RoundMode.DOWN

def test__autoscale_config__context_manager(self):
with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE):
cfg = get_autoscale_config()
assert isinstance(cfg, AutoScaleConfig)
assert cfg.rounding_mode == Pow2RoundMode.NONE
42 changes: 42 additions & 0 deletions tests/core/test_utils.py
@@ -0,0 +1,42 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import pow2_round_down, pow2_round_up
from jax_scaled_arithmetics.core.utils import _exponent_bits_mask


class Pow2RoundingUtilTests(chex.TestCase):
@parameterized.parameters(
{"dtype": np.float16},
{"dtype": np.float32},
)
def test__exponent_bitmask__inf_value(self, dtype):
val = _exponent_bits_mask[np.dtype(dtype)].view(dtype)
expected_val = dtype(np.inf)
npt.assert_equal(val, expected_val)

@parameterized.product(
val_exp=[(1, 1), (2.1, 2), (0.3, 0.25), (0.51, 0.5), (65500, 32768)],
dtype=[np.float16, np.float32, np.float64],
)
def test__pow2_round_down__proper_rounding__multi_dtypes(self, val_exp, dtype):
val, exp = dtype(val_exp[0]), dtype(val_exp[1])
pow2_val = pow2_round_down(val)
assert pow2_val.dtype == val.dtype
assert pow2_val.shape == ()
assert type(pow2_val) in {type(val), np.ndarray}
npt.assert_equal(pow2_val, exp)

@parameterized.product(
val_exp=[(2.1, 4), (0.3, 0.5), (0.51, 1), (17000, 32768)],
dtype=[np.float16],
)
def test__pow2_round_up__proper_rounding__multi_dtypes(self, val_exp, dtype):
val, exp = dtype(val_exp[0]), dtype(val_exp[1])
pow2_val = pow2_round_up(val)
assert pow2_val.dtype == val.dtype
assert type(pow2_val) in {type(val), np.ndarray}
npt.assert_equal(pow2_val, exp)
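
One caveat worth keeping in mind, following the `pow2_round_up` docstring and FIXME above (a hedged sketch using the exported helpers):

import numpy as np
from jax_scaled_arithmetics.core import pow2_round_up

# Rounding up close to the dtype maximum overflows: float16 max is 65504.
print(pow2_round_up(np.float16(65500.0)))   # 2 * 32768 = 65536 -> inf in float16.
# Exact powers of two are not kept as-is yet (see the FIXME above).
print(pow2_round_up(np.float16(2.0)))       # 4.0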
12 changes: 7 additions & 5 deletions tests/lax/test_scaled_ops_l2.py
@@ -22,8 +22,9 @@ def setUp(self):
{"ldtype": np.float16, "rdtype": np.float16},
)
def test__scaled_dot_general__proper_scaling(self, ldtype, rdtype):
# Reduction dimension: 5 => sqrt(5) ~ 2
lhs = scaled_array(self.rs.rand(3, 5), 2.0, dtype=ldtype)
rhs = scaled_array(self.rs.rand(5, 2), 3.0, dtype=rdtype)
rhs = scaled_array(self.rs.rand(5, 2), 4.0, dtype=rdtype)

dimension_numbers = (((1,), (0,)), ((), ()))
out = scaled_dot_general(lhs, rhs, dimension_numbers)
Expand All @@ -32,7 +33,7 @@ def test__scaled_dot_general__proper_scaling(self, ldtype, rdtype):
assert isinstance(out, ScaledArray)
assert out.dtype == expected_out.dtype
assert out.scale.dtype == np.float32 # TODO: more test coverage.
npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5))
npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * 2)
npt.assert_array_almost_equal(out, expected_out, decimal=2)


@@ -97,12 +98,13 @@ def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype, sdtyp
)
def test__scaled_addsub__proper_scaling(self, prim):
scaled_op, _ = find_registered_scaled_op(prim)
x = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32)
x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32)
y = scaled_array([1.5, 4.5], 2.0, dtype=np.float32)
z = scaled_op(x, y)
assert isinstance(z, ScaledArray)
assert z.dtype == x.dtype
npt.assert_almost_equal(z.scale, np.sqrt(4.0 + 9.0))
# Round down to power-of-2
npt.assert_almost_equal(z.scale, 4)

@parameterized.parameters(
{"prim": lax.add_p},
@@ -142,7 +144,7 @@ def setUp(self):
self.rs = np.random.RandomState(42)

@parameterized.parameters(
{"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * np.sqrt(5)},
{"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * 2},
{"reduce_prim": lax.reduce_prod_p, "expected_scale": 2**5},
{"reduce_prim": lax.reduce_min_p, "expected_scale": 2},
{"reduce_prim": lax.reduce_max_p, "expected_scale": 2},
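
For reference, the arithmetic behind the updated `reduce_sum` expectation (input scale 2, reduction axis of size 5, default `Pow2RoundMode.DOWN`):

import numpy as np
from jax_scaled_arithmetics.core import pow2_round_down

axes_rescale = pow2_round_down(np.float32(np.sqrt(5)))   # sqrt(5) ~= 2.236 rounds down to 2.0
expected_scale = 2.0 * axes_rescale                      # 4.0, previously 2 * sqrt(5) ~= 4.47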
