Skip to content

Commit

Permalink
Feat (tests): extended minifloat unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexredd99 committed Jul 1, 2024
1 parent 1394889 commit 27dabb9
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 8 deletions.
32 changes: 26 additions & 6 deletions src/brevitas/utils/float_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch


def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False, normal: bool = True) -> float:
# computes the decimal place value from a given binary tensor
res = 1.0
res = 1.0 if normal else 0.0
for i, val in enumerate(bits):
# iterating through from left to right
res += ((2 ** -(i + 1)) * float(val))
Expand All @@ -23,14 +23,20 @@ def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> flo
It expects the exponent and mantissa in their binary format.
"""
exponent_value = int(exponent, 2)
mantissa_value = mantissa_bits_to_float(mantissa)

if exponent_value == 0: # subnormal
exponent_bias -= 1 # exponent is e_min
mantissa_value = mantissa_bits_to_float(mantissa, normal=False)
else: # normal
mantissa_value = mantissa_bits_to_float(mantissa, normal=True)

return (2 ** (exponent_value - exponent_bias)) * mantissa_value


def get_max_available_float(
exponent_bit_width: torch.Tensor,
mantissa_bit_width: torch.Tensor,
exponent_bias: torch.Tensor,
exponent_bit_width: int,
mantissa_bit_width: int,
exponent_bias: int,
nan_values: Tuple[str],
inf_values: Tuple[str],
saturating: bool) -> torch.Tensor:
Expand Down Expand Up @@ -75,3 +81,17 @@ def get_max_available_float(
max_value = get_minifloat_value(
exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
return max_value


def get_min_available_float(
exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: int) -> torch.Tensor:
"""
Returns the minimum subnormal minifloat value for a given exponent and mantissa
bit-width, and exponent bias.
"""
exponent = '0' * exponent_bit_width
mantissa = '0' * (mantissa_bit_width - 1) + '1'

min_value = get_minifloat_value(
exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
return min_value
50 changes: 49 additions & 1 deletion tests/brevitas/core/minifloat_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pytest_cases
from pytest_cases import fixture_union

from brevitas.inject import ExtendedInjector
from brevitas.inject import value
from brevitas.quant.experimental.float_base import FloatWeightBase
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight

Expand Down Expand Up @@ -32,6 +35,51 @@ class Fp8e5m2(Fp8e5m2OCPWeight):
return Fp8e5m2


list_of_fixtures = ['fp8e4m3', 'fp8e5m2']
class Fp8CustomMixin(ExtendedInjector):
bit_width = 8
saturating = True

hypothesis_internal_is_this_a_mock_check = True

@value
def mantissa_bit_width(bit_width, exponent_bit_width):
return bit_width - exponent_bit_width - 1 # Sign bit


class Fp8e7m0Weight(Fp8CustomMixin, FloatWeightBase):
exponent_bit_width = 7


class Fp8e6m1Weight(Fp8CustomMixin, FloatWeightBase):
exponent_bit_width = 6


class Fp8e3m4Weight(Fp8CustomMixin, FloatWeightBase):
exponent_bit_width = 3


class Fp8e2m5Weight(Fp8CustomMixin, FloatWeightBase):
exponent_bit_width = 2


class Fp8e1m6Weight(Fp8CustomMixin, FloatWeightBase):
exponent_bit_width = 1


@pytest_cases.fixture
@pytest_cases.parametrize('exponent_bit_width', [1, 2, 3, 6, 7]) # at least 1 exponent bit
def fp8Custom(exponent_bit_width):

custom_exponents = {
1: Fp8e1m6Weight,
2: Fp8e2m5Weight,
3: Fp8e3m4Weight,
6: Fp8e6m1Weight,
7: Fp8e7m0Weight,}

return custom_exponents[exponent_bit_width]


list_of_fixtures = ['fp8e4m3', 'fp8e5m2', 'fp8Custom']

fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures)
41 changes: 40 additions & 1 deletion tests/brevitas/core/test_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,41 @@
from brevitas.function.ops import max_float
from brevitas.quant.experimental.float import Fp8e4m3Weight
from brevitas.quant.experimental.float import Fp8e5m2Weight
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeight
from brevitas.quant.experimental.float_quant_fnuz import Fp8e5m2FNUZWeight
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
from brevitas.utils.float_quant_utils import get_max_available_float
from brevitas.utils.float_quant_utils import get_min_available_float
from tests.brevitas.hyp_helper import float_tensor_random_shape_st

from .minifloat_fixtures import *

FORMAT_MAXVAL_MAP = {
Fp8e5m2OCPWeight: 57344., Fp8e4m3OCPWeight: 448., Fp8e4m3Weight: 480., Fp8e5m2Weight: 114688.}
Fp8e5m2OCPWeight: 57344.,
Fp8e4m3OCPWeight: 448.,
Fp8e4m3Weight: 480.,
Fp8e5m2Weight: 114688.,
Fp8e4m3FNUZWeight: 240.,
Fp8e5m2FNUZWeight: 57344.,
Fp8e7m0Weight: 2.0 ** 64, # Custom exponent_bit_width
Fp8e6m1Weight: 6442450944.0,
Fp8e3m4Weight: 31.0,
Fp8e2m5Weight: 7.875,
Fp8e1m6Weight: 3.96875}

FORMAT_MINVAL_MAP = {
Fp8e5m2OCPWeight: 2.0 ** -16,
Fp8e4m3OCPWeight: 2.0 ** -9,
Fp8e4m3Weight: 2.0 ** -9,
Fp8e5m2Weight: 2.0 ** -16,
Fp8e4m3FNUZWeight: 2.0 ** -10,
Fp8e5m2FNUZWeight: 2.0 ** -17,
Fp8e7m0Weight: 2.0 ** -63, # Custom exponent_bit_width
Fp8e6m1Weight: 2.0 ** -31,
Fp8e3m4Weight: 2.0 ** -6,
Fp8e2m5Weight: 2.0 ** -5,
Fp8e1m6Weight: 2.0 ** -5}


@pytest.mark.parametrize(
Expand All @@ -40,6 +66,19 @@ def test_max_value(minifloat, expected_max_val):
assert expected_max_val == max_val


@pytest.mark.parametrize(
'minifloat, expected_min_val',
((format, min_val) for format, min_val in FORMAT_MINVAL_MAP.items()))
def test_min_value(minifloat, expected_min_val):
min_val = get_min_available_float(
minifloat.exponent_bit_width,
minifloat.mantissa_bit_width,
minifloat.exponent_bias,
)

assert expected_min_val == min_val


@given(inp=float_tensor_random_shape_st())
def test_float_clamp(inp, fp8_clamp):

Expand Down

0 comments on commit 27dabb9

Please sign in to comment.