Skip to content

Commit

Permalink
Improve make_scaled_scalar with optional scale dtype parameter. (#93)
Browse files Browse the repository at this point in the history
Allowing the scale dtype (e.g. FP32) to be passed directly improves subnormal
support, as small values are no longer flushed to zero.
  • Loading branch information
balancap authored Jan 25, 2024
1 parent 478cf62 commit 9b862ef
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 25 deletions.
26 changes: 19 additions & 7 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def astype(self, dtype) -> "ScaledArray":
return ScaledArray(self.data.astype(dtype), self.scale)


def make_scaled_scalar(val: Array) -> ScaledArray:
def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
"""Make a scaled scalar (array), from a single value.
The returned scalar will always be built such that:
Expand All @@ -118,8 +118,11 @@ def make_scaled_scalar(val: Array) -> ScaledArray:
val = np.float32(val)
assert np.ndim(val) == 0
assert np.issubdtype(val.dtype, np.floating)
# Scale dtype to use.
# TODO: check the scale dtype?
scale_dtype = scale_dtype or val.dtype
# Split mantissa and exponent in data and scale components.
scale = pow2_round_down(val)
scale = pow2_round_down(val.astype(scale_dtype))
npapi = get_numpy_api(scale)
return ScaledArray(npapi.asarray(get_mantissa(val)), scale)

Expand Down Expand Up @@ -155,8 +158,16 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
return scaled_array_base(data, scale, dtype, npapi)


def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[Array, ScaledArray]:
"""ScaledArray (helper) base factory method, similar to `(j)np.array`."""
def as_scaled_array_base(
val: Any, scale: Optional[ArrayLike] = None, scale_dtype: Optional[DTypeLike] = None
) -> Union[Array, ScaledArray]:
"""ScaledArray (helper) base factory method, similar to `(j)np.array`.
Args:
val: Value to convert to scaled array.
scale: Optional scale value.
scale_dtype: Optional (default) scale dtype.
"""
if isinstance(val, ScaledArray):
return val

Expand All @@ -166,17 +177,18 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[A
if is_static_one_scale and isinstance(val, (bool, int)):
return val
if is_static_one_scale and isinstance(val, float):
return make_scaled_scalar(np.float32(val))
return make_scaled_scalar(np.float32(val), scale_dtype)

# Ignored dtypes by default: int and bool
ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_)
if ignored_dtype:
return val
# Floating point scalar
if val.ndim == 0 and is_static_one_scale:
return make_scaled_scalar(val)
return make_scaled_scalar(val, scale_dtype)

scale = np.array(1, dtype=val.dtype) if scale is None else scale
scale_dtype = scale_dtype or val.dtype
scale = np.array(1, dtype=scale_dtype) if scale is None else scale
if isinstance(val, (np.ndarray, Array)):
if is_static_one_scale:
return ScaledArray(val, scale)
Expand Down
27 changes: 9 additions & 18 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple

import jax
import numpy as np
Expand All @@ -15,7 +15,7 @@
)
from jax._src.util import safe_map

from .datatype import Array, DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .datatype import Array, DTypeLike, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode


Expand Down Expand Up @@ -96,24 +96,13 @@ def _get_data(val: Any) -> Array:
return val


def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
def promote_scalar_to_scaled_array(val: Any, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
"""Promote a scalar (Numpy, JAX, ...) to a Scaled Array.
Note: needs to work with any input type, including JAX tracer ones.
"""
# Use `as_scaled_array` promotion rules.
return as_scaled_array_base(val)


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
"""Get the ScaledArray corresponding to a Numpy constant.
Only supporting Numpy scalars at the moment.
"""
# TODO: generalized rules!
assert np.ndim(val) == 0
assert np.issubdtype(val.dtype, np.floating)
return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
return as_scaled_array_base(val, scale_dtype=scale_dtype)


def register_scaled_op(
Expand Down Expand Up @@ -200,6 +189,8 @@ def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}
# Check dtype consistency between normal and scaled modes.
safe_check_dtypes: bool = False
# AutoScale config to use.
autoscale_cfg = get_autoscale_config()

def read(var):
if type(var) is core.Literal:
Expand All @@ -209,11 +200,11 @@ def read(var):
def write(var, val):
env[var] = val

def promote_to_scaled_array(val):
def promote_to_scaled_array(val, scale_dtype):
if isinstance(val, ScaledArray):
return val
elif np.ndim(val) == 0:
return promote_scalar_to_scaled_array(val)
return promote_scalar_to_scaled_array(val, scale_dtype)
# No promotion rule => just return as such.
return val

Expand Down Expand Up @@ -245,7 +236,7 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se
)
else:
# Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
scaled_invals = list(map(promote_to_scaled_array, invals))
scaled_invals = list(map(lambda v: promote_to_scaled_array(v, autoscale_cfg.scale_dtype), invals))
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
Expand Down
20 changes: 20 additions & 0 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,26 @@ def test__make_scaled_scalar__zero_scalar_input(self, val):
assert scaled_val.shape == ()
assert scaled_val.dtype == val.dtype

def test__make_scaled_scalar__optional_scale_dtype(self):
    """Check `make_scaled_scalar` with an explicit `scale_dtype` argument.

    An FP16 scalar is converted while requesting an FP32 scale: the result
    must report the input dtype, carry an FP32 scale, and convert back to
    the original value.
    """
    fp16_input = np.float16(0.25)
    result = make_scaled_scalar(fp16_input, scale_dtype=np.float32)

    assert isinstance(result, ScaledArray)
    # The scaled value reports the dtype of the input ...
    assert result.dtype == fp16_input.dtype
    # ... whereas the scale uses the dtype explicitly requested.
    assert result.scale.dtype == np.float32
    # Converting back to a plain array recovers the input value.
    npt.assert_equal(np.asarray(result), fp16_input)

@parameterized.parameters(
    dict(val=np.finfo(np.float16).smallest_normal),
    dict(val=np.finfo(np.float16).smallest_subnormal),
    dict(val=np.float16(3.123283386230469e-05)),
)
def test__make_scaled_scalar__fp16_subnormal_support(self, val):
    """Check that very small FP16 inputs survive with an FP32 scale.

    Args:
        val: FP16 value under test, supplied by the parameterized decorator.
    """
    # Reference: the input value promoted to FP32.
    fp32_reference = np.float32(val)
    # An FP32 scale dtype is requested so the scale has enough range.
    # NOTE: the same check is failing with an FP16 scale!
    scaled_val = make_scaled_scalar(val, scale_dtype=np.float32)
    # Once everything is converted to FP32, no information must be lost.
    fp32_roundtrip = np.asarray(scaled_val, dtype=np.float32)
    npt.assert_equal(fp32_roundtrip, fp32_reference)

@parameterized.parameters(
{"val": np.array(1.0)},
{"val": np.float32(-0.5)},
Expand Down
17 changes: 17 additions & 0 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,20 @@ def test__autoscale_config__context_manager(self):
assert isinstance(cfg, AutoScaleConfig)
assert cfg.rounding_mode == Pow2RoundMode.NONE
assert cfg.scale_dtype == np.float32

def test__autoscale_config__scale_dtype_used_in_interpreter_promotion(self):
    """Check that the configured scale dtype is used when promoting scalars.

    A scaled FP16 input is multiplied by a very small constant under two
    `AutoScaleConfig` settings: with an FP32 scale dtype the result must match
    the reference value, with an FP16 scale dtype the output scale is
    expected to be (almost) zero.
    """
    # NOTE(review): `with` nesting reconstructed from the flattened source — confirm.
    tiny_factor = 3.123283386230469e-05

    def fn(x):
        # Underflowing to zero in `autoscale` mode if scale_dtype == np.float16.
        return x * tiny_factor

    x_scaled = scaled_array(np.array(2.0, np.float16), scale=np.float32(0.5))
    # Reference value, computed on a plain (unscaled) FP16 scalar.
    reference = fn(np.float16(1))

    # FP32 scale dtype: the result is compared against the reference.
    with AutoScaleConfig(scale_dtype=np.float32):
        out_fp32_scale = autoscale(fn)(x_scaled)
        assert out_fp32_scale.scale.dtype == np.float32
        npt.assert_equal(np.asarray(out_fp32_scale, dtype=np.float32), reference)

    # FP16 scale dtype: the output scale is expected to be (almost) zero.
    with AutoScaleConfig(scale_dtype=np.float16):
        out_fp16_scale = autoscale(fn)(x_scaled)
        npt.assert_almost_equal(out_fp16_scale.scale, 0)

0 comments on commit 9b862ef

Please sign in to comment.