diff --git a/jax_scalify/core/pow2.py b/jax_scalify/core/pow2.py index a9f1edc..ec2fe6d 100644 --- a/jax_scalify/core/pow2.py +++ b/jax_scalify/core/pow2.py @@ -4,6 +4,7 @@ from functools import partial from typing import Any, Dict, Optional, Sequence, Tuple, Union +import jax.numpy as jnp import numpy as np from jax import core from jax.interpreters import mlir @@ -14,6 +15,9 @@ # Exponent bits masking. _exponent_bits_mask: Dict[Any, NDArray[Any]] = { + np.dtype(jnp.bfloat16): np.packbits( + np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8) + ).view(np.int16), np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view( np.int16 ), @@ -31,6 +35,24 @@ """ +def dtype_exponent_mask(dtype: DTypeLike, sign_bit: bool = False) -> NDArray[Any]: + """Get the exponent mask for a given Numpy/JAX dtype. + + Args: + dtype: Numpy/JAX dtype. + sign_bit: Include sign bit in the mask. + Returns: + Array mask as integer dtype. + """ + mask = _exponent_bits_mask[dtype] + if sign_bit: + # Negative value to add sign. + intdtype = mask.dtype + mask = (-mask.view(dtype)).view(intdtype) + return mask + return mask + + def pow2_decompose_round_down_impl(vin: Array, scale_dtype: DTypeLike) -> Array: """Pow-2 decompose with rounding down. @@ -42,7 +64,7 @@ def pow2_decompose_round_down_impl(vin: Array, scale_dtype: DTypeLike) -> Array: # NOTE: `jnp.frexp` is buggy for subnormals. dtype = np.dtype(np.float32) minval = np.finfo(dtype).smallest_normal - exponent_mask = _exponent_bits_mask[dtype] + exponent_mask = dtype_exponent_mask(dtype) intdtype = exponent_mask.dtype val = vin.astype(dtype) # Masking mantissa bits, keeping only the exponents ones. diff --git a/jax_scalify/quantization/__init__.py b/jax_scalify/quantization/__init__.py new file mode 100644 index 0000000..83707b8 --- /dev/null +++ b/jax_scalify/quantization/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +from .scale import as_e8m0 # noqa: F401 diff --git a/jax_scalify/quantization/scale.py b/jax_scalify/quantization/scale.py new file mode 100644 index 0000000..330d30f --- /dev/null +++ b/jax_scalify/quantization/scale.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +import jax.numpy as jnp +import numpy as np + +from jax_scalify.core import Array, DTypeLike, get_numpy_api +from jax_scalify.core.pow2 import dtype_exponent_mask + + +def pow2_truncate(arr: Array) -> Array: + """Convert an Array to a power of 2, using mantissa truncation. + + NOTE: all sub-normals values are flushed to zero. + """ + np_api = get_numpy_api(arr) + # Masking mantissa & sign-bit, keeping only exponent values. + exponent_mask = dtype_exponent_mask(arr.dtype, sign_bit=True) + intdtype = exponent_mask.dtype + # Masking mantissa bits, keeping only the exponents ones. + arr_pow2 = np_api.bitwise_and(arr.view(intdtype), exponent_mask).view(arr.dtype).reshape(arr.shape) + return arr_pow2 + + +def as_e8m0(arr: Array) -> Array: + """Convert an Array to e8m0 format (i.e. power of two values). + + This function is only implementing a truncation + saturation variant, in line with + the MX OCP format. + + Args: + arr: Input array (FP16, FP32 or BF16). + Returns: + E8M0 array (as uint8). + """ + np_api = get_numpy_api(arr) + # assert len(arr.shape) < 2 + assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(jnp.float32)} + # Saturation => negative values saturating to min value (i.e. zero bits) in E8M0. + arr = np_api.maximum(arr, np.array(0, arr.dtype)) + arr = pow2_truncate(arr) + + # Bit masking to extract the exponent as uint8 array. + arr_u8 = arr.view(np.uint8).reshape((*arr.shape, -1)) + arr_e8m0 = np_api.bitwise_or(np_api.left_shift(arr_u8[..., -1], 1), np_api.right_shift(arr_u8[..., -2], 7)) + return arr_e8m0 + + +def from_e8m0(arr: Array, dtype: DTypeLike) -> Array: + """Convert an Array of e8m0 values (i.e. power of two values) to a given dtype. + + Args: + arr: E8M0 array (assuming uint8 storage dtype). + dtype: Output dtype. FP32 or BF16 supported. + Returns: + Converted output. + """ + np_api = get_numpy_api(arr) + assert arr.dtype == np.uint8 + assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(jnp.float32)} + # Avoid issues with 7 mantissa bits in BF16. + # TODO: more efficient implementation! + arr = np_api.exp2(arr.astype(np.float32) - 127) + return arr.astype(dtype) diff --git a/tests/quantization/test_scale.py b/tests/quantization/test_scale.py new file mode 100644 index 0000000..fd11c63 --- /dev/null +++ b/tests/quantization/test_scale.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +import chex +import ml_dtypes +import numpy as np +import numpy.testing as npt +from absl.testing import parameterized + +from jax_scalify.quantization.scale import as_e8m0, from_e8m0, pow2_truncate + + +class QuantizationScaleTests(chex.TestCase): + @parameterized.parameters( + {"dtype": np.float16}, + {"dtype": np.float32}, + {"dtype": ml_dtypes.bfloat16}, + ) + def test__pow2_truncate__proper_result(self, dtype): + vin = np.array([-2, 0, 2, 1, 9, 15]).astype(dtype) + vout = pow2_truncate(vin) + assert vout.dtype == vin.dtype + npt.assert_array_equal(vout, [-2.0, 0.0, 2.0, 1.0, 8.0, 8.0]) + + @parameterized.parameters( + # {"dtype": np.float16}, + {"dtype": np.float32}, + {"dtype": ml_dtypes.bfloat16}, + ) + def test__as_e8m0__positive_values(self, dtype): + vin = np.array([0.6, 2, 1, 9, 15, 127]).astype(dtype).reshape((-1, 2)) + vout = as_e8m0(vin) + assert vout.dtype == np.uint8 + assert vout.shape == vin.shape + npt.assert_array_equal(vout, np.log2(pow2_truncate(vin)) + 127) + + @parameterized.parameters( + # {"dtype": np.float16}, + {"dtype": np.float32}, + {"dtype": ml_dtypes.bfloat16}, + ) + def test__as_e8m0__negative_values(self, dtype): + vin = np.array([-0.1, -3, 0, 2**-127]).astype(dtype) + vout = as_e8m0(vin) + assert vout.dtype == np.uint8 + # NOTE: uint8(0) is the smallest positive scale in E8M0. + npt.assert_array_equal(vout, np.uint8(0)) + + @parameterized.parameters( + # {"dtype": np.float16}, + {"dtype": np.float32}, + {"dtype": ml_dtypes.bfloat16}, + ) + def test__from_e8m0(self, dtype): + vin = np.array([2**-127, 0.25, 1, 2, 8, 2**127.0]).astype(dtype).reshape((-1, 2)) + vin_e8m0 = as_e8m0(vin) + vout = from_e8m0(vin_e8m0, dtype) + assert vin.dtype == vout.dtype + assert vout.shape == vin.shape + npt.assert_array_equal(vout, vin)