Feat (tests): extended minifloat unit tests #979

Merged · 1 commit · Jul 30, 2024
32 changes: 26 additions & 6 deletions src/brevitas/utils/float_quant_utils.py
@@ -5,9 +5,9 @@
 import torch


-def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
+def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False, normal: bool = True) -> float:
     # computes the decimal place value from a given binary tensor
-    res = 1.0
+    res = 1.0 if normal else 0.0
     for i, val in enumerate(bits):
         # iterating through from left to right
         res += ((2 ** -(i + 1)) * float(val))
@@ -23,14 +23,20 @@ def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float:
     It expects the exponent and mantissa in their binary format.
     """
     exponent_value = int(exponent, 2)
-    mantissa_value = mantissa_bits_to_float(mantissa)
+
+    if exponent_value == 0:  # subnormal
+        exponent_bias -= 1  # exponent is e_min
+        mantissa_value = mantissa_bits_to_float(mantissa, normal=False)
+    else:  # normal
+        mantissa_value = mantissa_bits_to_float(mantissa, normal=True)
+
     return (2 ** (exponent_value - exponent_bias)) * mantissa_value


 def get_max_available_float(
-        exponent_bit_width: torch.Tensor,
-        mantissa_bit_width: torch.Tensor,
-        exponent_bias: torch.Tensor,
+        exponent_bit_width: int,
+        mantissa_bit_width: int,
+        exponent_bias: int,
         nan_values: Tuple[str],
         inf_values: Tuple[str],
         saturating: bool) -> torch.Tensor:
@@ -75,3 +81,17 @@ def get_max_available_float
     max_value = get_minifloat_value(
         exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
     return max_value
+
+
+def get_min_available_float(
+        exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: int) -> torch.Tensor:
+    """
+    Returns the minimum subnormal minifloat value for a given exponent and mantissa
+    bit-width, and exponent bias.
+    """
+    exponent = '0' * exponent_bit_width
+    mantissa = '0' * (mantissa_bit_width - 1) + '1'
+
+    min_value = get_minifloat_value(
+        exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
+    return min_value
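A note on the subnormal branch above: when the exponent field is all zeros the value is subnormal, so there is no implicit leading 1 and the effective exponent is e_min = 1 - bias. The diff expresses this by keeping the stored exponent at 0 and shifting the bias by one, which is equivalent to pinning the exponent to 1 - bias. The sketch below reproduces the same decoding arithmetic standalone (no Brevitas imports) and spot-checks it against FP8 E4M3 with bias 7; the helper name `minifloat_value` and the chosen bit strings are illustrative only, not part of the PR.

```python
# Standalone sketch of the decoding rule implemented by get_minifloat_value (illustrative only).
def minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float:
    exponent_value = int(exponent, 2)
    if exponent_value == 0:
        # Subnormal: exponent is pinned to e_min = 1 - bias and the implicit leading 1 is dropped.
        exponent_value = 1
        leading = 0.0
    else:
        leading = 1.0
    mantissa_value = leading + sum(
        int(bit) * 2.0 ** -(i + 1) for i, bit in enumerate(mantissa))
    return 2.0 ** (exponent_value - exponent_bias) * mantissa_value


# FP8 E4M3 with bias 7: the largest and the smallest subnormal magnitudes.
assert minifloat_value('0000', '111', 7) == 0.875 * 2.0 ** -6
# The smallest subnormal, which get_min_available_float(4, 3, 7) should also return for this format.
assert minifloat_value('0000', '001', 7) == 2.0 ** -9
```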
50 changes: 49 additions & 1 deletion tests/brevitas/core/minifloat_fixtures.py
@@ -4,6 +4,9 @@
 import pytest_cases
 from pytest_cases import fixture_union

+from brevitas.inject import ExtendedInjector
+from brevitas.inject import value
+from brevitas.quant.experimental.float_base import FloatWeightBase
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight

@@ -32,6 +35,51 @@ class Fp8e5m2(Fp8e5m2OCPWeight):
     return Fp8e5m2


-list_of_fixtures = ['fp8e4m3', 'fp8e5m2']
+class Fp8CustomMixin(ExtendedInjector):
+    bit_width = 8
+    saturating = True
+
+    hypothesis_internal_is_this_a_mock_check = True
+
+    @value
+    def mantissa_bit_width(bit_width, exponent_bit_width):
+        return bit_width - exponent_bit_width - 1  # Sign bit
+
+
+class Fp8e7m0Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 7
+
+
+class Fp8e6m1Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 6
+
+
+class Fp8e3m4Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 3
+
+
+class Fp8e2m5Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 2
+
+
+class Fp8e1m6Weight(Fp8CustomMixin, FloatWeightBase):
+    exponent_bit_width = 1
+
+
+@pytest_cases.fixture
+@pytest_cases.parametrize('exponent_bit_width', [1, 2, 3, 6, 7])  # at least 1 exponent bit
+def fp8Custom(exponent_bit_width):
+
+    custom_exponents = {
+        1: Fp8e1m6Weight,
+        2: Fp8e2m5Weight,
+        3: Fp8e3m4Weight,
+        6: Fp8e6m1Weight,
+        7: Fp8e7m0Weight,}
+
+    return custom_exponents[exponent_bit_width]
+
+
+list_of_fixtures = ['fp8e4m3', 'fp8e5m2', 'fp8Custom']

 fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures)
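For reference, the custom classes above just split a fixed 8-bit budget between sign, exponent and mantissa: mantissa_bit_width = 8 - exponent_bit_width - 1, so exponent widths 1, 2, 3, 6, 7 yield e1m6, e2m5, e3m4, e6m1 and e7m0. The exponent bias is left to the FloatWeightBase default, which, judging from the expected values in test_clamp.py, is the usual 2 ** (exponent_bit_width - 1) - 1 (an inference from this PR, not stated in the diff). A minimal sketch of the bit-budget arithmetic, independent of Brevitas and pytest:

```python
# Bit-budget arithmetic behind Fp8CustomMixin: 1 sign bit + exponent bits + mantissa bits = 8.
BIT_WIDTH = 8

def mantissa_bits(exponent_bit_width: int) -> int:
    return BIT_WIDTH - exponent_bit_width - 1  # subtract the sign bit

for e in (1, 2, 3, 6, 7):
    print(f"e{e}m{mantissa_bits(e)}")
# Prints e1m6, e2m5, e3m4, e6m1, e7m0 -- matching the Fp8e*Weight classes above.
```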
41 changes: 40 additions & 1 deletion tests/brevitas/core/test_clamp.py
@@ -8,15 +8,41 @@
 from brevitas.function.ops import max_float
 from brevitas.quant.experimental.float import Fp8e4m3Weight
 from brevitas.quant.experimental.float import Fp8e5m2Weight
+from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeight
+from brevitas.quant.experimental.float_quant_fnuz import Fp8e5m2FNUZWeight
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
 from brevitas.utils.float_quant_utils import get_max_available_float
+from brevitas.utils.float_quant_utils import get_min_available_float
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st

 from .minifloat_fixtures import *

 FORMAT_MAXVAL_MAP = {
-    Fp8e5m2OCPWeight: 57344., Fp8e4m3OCPWeight: 448., Fp8e4m3Weight: 480., Fp8e5m2Weight: 114688.}
+    Fp8e5m2OCPWeight: 57344.,
+    Fp8e4m3OCPWeight: 448.,
+    Fp8e4m3Weight: 480.,
+    Fp8e5m2Weight: 114688.,
+    Fp8e4m3FNUZWeight: 240.,
+    Fp8e5m2FNUZWeight: 57344.,
+    Fp8e7m0Weight: 2.0 ** 64,  # Custom exponent_bit_width
+    Fp8e6m1Weight: 6442450944.0,
+    Fp8e3m4Weight: 31.0,
+    Fp8e2m5Weight: 7.875,
+    Fp8e1m6Weight: 3.96875}
+
+FORMAT_MINVAL_MAP = {
+    Fp8e5m2OCPWeight: 2.0 ** -16,
+    Fp8e4m3OCPWeight: 2.0 ** -9,
+    Fp8e4m3Weight: 2.0 ** -9,
+    Fp8e5m2Weight: 2.0 ** -16,
+    Fp8e4m3FNUZWeight: 2.0 ** -10,
+    Fp8e5m2FNUZWeight: 2.0 ** -17,
+    Fp8e7m0Weight: 2.0 ** -63,  # Custom exponent_bit_width
+    Fp8e6m1Weight: 2.0 ** -31,
+    Fp8e3m4Weight: 2.0 ** -6,
+    Fp8e2m5Weight: 2.0 ** -5,
+    Fp8e1m6Weight: 2.0 ** -5}


@pytest.mark.parametrize(
@@ -40,6 +66,19 @@ def test_max_value(minifloat, expected_max_val):
     assert expected_max_val == max_val


+@pytest.mark.parametrize(
+    'minifloat, expected_min_val',
+    ((format, min_val) for format, min_val in FORMAT_MINVAL_MAP.items()))
+def test_min_value(minifloat, expected_min_val):
+    min_val = get_min_available_float(
+        minifloat.exponent_bit_width,
+        minifloat.mantissa_bit_width,
+        minifloat.exponent_bias,
+    )
+
+    assert expected_min_val == min_val
+
+
 @given(inp=float_tensor_random_shape_st())
 def test_float_clamp(inp, fp8_clamp):

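The expected values in FORMAT_MAXVAL_MAP and FORMAT_MINVAL_MAP follow directly from the format parameters. Assuming the default bias 2 ** (exponent_bit_width - 1) - 1 for the custom formats, the smallest magnitude is the subnormal with only the lowest mantissa bit set, and the largest is the all-ones encoding when no codes are reserved for inf/NaN; the OCP formats do reserve codes, which is why their maxima fall below what this formula gives. The sketch below spells out those two formulas with a few spot checks; the helper names are illustrative, not part of the test code.

```python
# Illustrative formulas behind the expected-value tables above (not part of the PR).
def expected_min(mantissa_bit_width: int, exponent_bias: int) -> float:
    # Smallest subnormal: all-zero exponent field, only the lowest mantissa bit set.
    # max(..., 1) mirrors get_min_available_float, which always sets one mantissa bit,
    # even when mantissa_bit_width is 0 (the e7m0 case).
    return 2.0 ** (1 - exponent_bias - max(mantissa_bit_width, 1))

def expected_max_no_specials(exponent_bit_width: int, mantissa_bit_width: int,
                             exponent_bias: int) -> float:
    # All-ones exponent and mantissa, assuming no encodings reserved for inf/NaN.
    return 2.0 ** (2 ** exponent_bit_width - 1 - exponent_bias) * (2.0 - 2.0 ** -mantissa_bit_width)

assert expected_min(3, 7) == 2.0 ** -9                  # Fp8e4m3OCPWeight / Fp8e4m3Weight
assert expected_min(2, 16) == 2.0 ** -17                # Fp8e5m2FNUZWeight (bias 16)
assert expected_max_no_specials(3, 4, 3) == 31.0        # Fp8e3m4Weight
assert expected_max_no_specials(1, 6, 0) == 3.96875     # Fp8e1m6Weight
```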