diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py
index b108c37ee..6b66bf08c 100644
--- a/src/brevitas/utils/float_quant_utils.py
+++ b/src/brevitas/utils/float_quant_utils.py
@@ -5,9 +5,9 @@
 import torch
 
 
-def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
+def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False, normal: bool = True) -> float:
     # computes the decimal place value from a given binary tensor
-    res = 1.0
+    res = 1.0 if normal else 0.0
     for i, val in enumerate(bits):
         # iterating through from left to right
         res += ((2 ** -(i + 1)) * float(val))
@@ -23,14 +23,20 @@ def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> flo
     It expects the exponent and mantissa in their binary format.
     """
     exponent_value = int(exponent, 2)
-    mantissa_value = mantissa_bits_to_float(mantissa)
+
+    if exponent_value == 0:  # subnormal
+        exponent_bias -= 1  # exponent is e_min
+        mantissa_value = mantissa_bits_to_float(mantissa, normal=False)
+    else:  # normal
+        mantissa_value = mantissa_bits_to_float(mantissa, normal=True)
+
     return (2 ** (exponent_value - exponent_bias)) * mantissa_value
 
 
 def get_max_available_float(
-        exponent_bit_width: torch.Tensor,
-        mantissa_bit_width: torch.Tensor,
-        exponent_bias: torch.Tensor,
+        exponent_bit_width: int,
+        mantissa_bit_width: int,
+        exponent_bias: int,
         nan_values: Tuple[str],
         inf_values: Tuple[str],
         saturating: bool) -> torch.Tensor:
@@ -75,3 +81,17 @@ def get_max_available_float(
     max_value = get_minifloat_value(
         exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
     return max_value
+
+
+def get_min_available_float(
+        exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: int) -> torch.Tensor:
+    """
+    Returns the minimum subnormal minifloat value for the given exponent bit-width,
+    mantissa bit-width, and exponent bias.
+    """
+    exponent = '0' * exponent_bit_width
+    mantissa = '0' * (mantissa_bit_width - 1) + '1'
+
+    min_value = get_minifloat_value(
+        exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
+    return min_value
diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py
index 681caf8ca..64858f7dd 100644
--- a/tests/brevitas/core/minifloat_fixtures.py
+++ b/tests/brevitas/core/minifloat_fixtures.py
@@ -4,6 +4,9 @@
 import pytest_cases
 from pytest_cases import fixture_union
 
+from brevitas.inject import ExtendedInjector
+from brevitas.inject import value
+from brevitas.quant.experimental.float_base import FloatWeightBase
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
 
@@ -32,6 +35,51 @@ class Fp8e5m2(Fp8e5m2OCPWeight):
     return Fp8e5m2
 
 
-list_of_fixtures = ['fp8e4m3', 'fp8e5m2']
+class Fp8CustomMixin(ExtendedInjector):
+    bit_width = 8
+    saturating = True
+
+    hypothesis_internal_is_this_a_mock_check = True
+
+    @value
+    def mantissa_bit_width(bit_width, exponent_bit_width):
+        return bit_width - exponent_bit_width - 1  # Subtract the sign bit
+
+
+class Fp8e7m0Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 7
+
+
+class Fp8e6m1Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 6
+
+
+class Fp8e3m4Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 3
+
+
+class Fp8e2m5Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 2
+
+
+class Fp8e1m6Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 1
+
+
+@pytest_cases.fixture
+@pytest_cases.parametrize('exponent_bit_width', [1, 2, 3, 6, 7])  # at least 1 exponent bit
+def fp8Custom(exponent_bit_width):
+
+    custom_exponents = {
+        1: Fp8e1m6Weight,
+        2: Fp8e2m5Weight,
+        3: Fp8e3m4Weight,
+        6: Fp8e6m1Weight,
+        7: Fp8e7m0Weight,}
+
+    return custom_exponents[exponent_bit_width]
+
+
+list_of_fixtures = ['fp8e4m3', 'fp8e5m2', 'fp8Custom']
 
 fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures)
diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py
index 4b13032f2..96a999494 100644
--- a/tests/brevitas/core/test_clamp.py
+++ b/tests/brevitas/core/test_clamp.py
@@ -8,15 +8,41 @@
 from brevitas.function.ops import max_float
 from brevitas.quant.experimental.float import Fp8e4m3Weight
 from brevitas.quant.experimental.float import Fp8e5m2Weight
+from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeight
+from brevitas.quant.experimental.float_quant_fnuz import Fp8e5m2FNUZWeight
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
 from brevitas.utils.float_quant_utils import get_max_available_float
+from brevitas.utils.float_quant_utils import get_min_available_float
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 
 from .minifloat_fixtures import *
 
 FORMAT_MAXVAL_MAP = {
-    Fp8e5m2OCPWeight: 57344., Fp8e4m3OCPWeight: 448., Fp8e4m3Weight: 480., Fp8e5m2Weight: 114688.}
+    Fp8e5m2OCPWeight: 57344.,
+    Fp8e4m3OCPWeight: 448.,
+    Fp8e4m3Weight: 480.,
+    Fp8e5m2Weight: 114688.,
+    Fp8e4m3FNUZWeight: 240.,
+    Fp8e5m2FNUZWeight: 57344.,
+    Fp8e7m0Weight: 2.0 ** 64,  # Custom exponent_bit_width
+    Fp8e6m1Weight: 6442450944.0,
+    Fp8e3m4Weight: 31.0,
+    Fp8e2m5Weight: 7.875,
+    Fp8e1m6Weight: 3.96875}
+
+FORMAT_MINVAL_MAP = {
+    Fp8e5m2OCPWeight: 2.0 ** -16,
+    Fp8e4m3OCPWeight: 2.0 ** -9,
+    Fp8e4m3Weight: 2.0 ** -9,
+    Fp8e5m2Weight: 2.0 ** -16,
+    Fp8e4m3FNUZWeight: 2.0 ** -10,
+    Fp8e5m2FNUZWeight: 2.0 ** -17,
+    Fp8e7m0Weight: 2.0 ** -63,  # Custom exponent_bit_width
+    Fp8e6m1Weight: 2.0 ** -31,
+    Fp8e3m4Weight: 2.0 ** -6,
+    Fp8e2m5Weight: 2.0 ** -5,
+    Fp8e1m6Weight: 2.0 ** -5}
 
 
 @pytest.mark.parametrize(
@@ -40,6 +66,19 @@ def test_max_value(minifloat, expected_max_val):
 
     assert expected_max_val == max_val
 
 
+@pytest.mark.parametrize(
+    'minifloat, expected_min_val',
+    ((format, min_val) for format, min_val in FORMAT_MINVAL_MAP.items()))
+def test_min_value(minifloat, expected_min_val):
+    min_val = get_min_available_float(
+        minifloat.exponent_bit_width,
+        minifloat.mantissa_bit_width,
+        minifloat.exponent_bias,
+    )
+
+    assert expected_min_val == min_val
+
+
 @given(inp=float_tensor_random_shape_st())
 def test_float_clamp(inp, fp8_clamp):
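
Reviewer note: the `FORMAT_MINVAL_MAP` expectations follow directly from the subnormal encoding introduced in the first hunk. An all-zero exponent field maps to e_min = 1 - exponent_bias with no implicit leading one, so the smallest positive magnitude is 2^(1 - bias) * 2^-mantissa_bit_width. The sketch below is illustrative only; the `min_subnormal` helper is not part of the patch, and the bias values are assumptions based on the usual 2^(E-1) - 1 convention for OCP formats and 2^(E-1) for FNUZ.

```python
# Illustrative sanity check, not part of the patch: reproduces the arithmetic
# that get_min_available_float() performs via get_minifloat_value().


def min_subnormal(mantissa_bit_width: int, exponent_bias: int) -> float:
    # An all-zero exponent field selects the subnormal range: the effective
    # exponent is e_min = 1 - bias and the implicit leading bit is 0, so the
    # smallest mantissa ('0...01') contributes 2 ** -mantissa_bit_width.
    return 2.0 ** (1 - exponent_bias) * 2.0 ** -mantissa_bit_width


assert min_subnormal(3, 7) == 2.0 ** -9    # e4m3 OCP entry in FORMAT_MINVAL_MAP
assert min_subnormal(2, 15) == 2.0 ** -16  # e5m2 OCP
assert min_subnormal(3, 8) == 2.0 ** -10   # e4m3 FNUZ (bias one larger than OCP)
```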