Fix (tests): adding tests for FloatQuant #815

Merged: 5 commits, Feb 7, 2024
6 changes: 6 additions & 0 deletions src/brevitas/core/quant/float.py
@@ -35,10 +35,16 @@ def __init__(
        self.bit_width = StatelessBuffer(torch.tensor(float(bit_width), device=device, dtype=dtype))
        self.signed: bool = signed
        self.float_to_int_impl = float_to_int_impl
        if exponent_bit_width == 0:
            raise RuntimeError("Exponent bit width cannot be 0.")
        self.exponent_bit_width = StatelessBuffer(
            torch.tensor(float(exponent_bit_width), device=device, dtype=dtype))
        if mantissa_bit_width == 0:
            raise RuntimeError("Mantissa bit width cannot be 0.")
        self.mantissa_bit_width = StatelessBuffer(
            torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype))
        if exponent_bias is None:
            exponent_bias = 2 ** (exponent_bit_width - 1) - 1
        self.exponent_bias = StatelessBuffer(
            torch.tensor(float(exponent_bias), device=device, dtype=dtype))
        self.fp_max_val = StatelessBuffer(
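A note on the default exercised by the new tests below: when exponent_bias is left as None, the constructor falls back to the IEEE-754-style bias 2 ** (exponent_bit_width - 1) - 1. A quick hand-check of that formula, with values chosen purely for illustration (not taken from the diff):

exponent_bit_width = 4                             # e.g. an E4M3-style minifloat exponent
default_bias = 2 ** (exponent_bit_width - 1) - 1   # 2**3 - 1
assert default_bias == 7                           # matches the standard E4M3 bias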
158 changes: 158 additions & 0 deletions tests/brevitas/core/test_float_quant.py
@@ -0,0 +1,158 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from hypothesis import given
import mock
import pytest
import torch

from brevitas.core.function_wrapper import RoundSte
from brevitas.core.quant.float import FloatQuant
from brevitas.core.scaling import ConstScaling
from tests.brevitas.hyp_helper import float_st
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
from tests.brevitas.hyp_helper import random_minifloat_format
from tests.marker import jit_disabled_for_mock


@given(minifloat_format=random_minifloat_format())
def test_float_quant_defaults(minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    # specifically don't set exponent bias to see if default works
    expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1
    if exponent_bit_width == 0 or mantissa_bit_width == 0:
        with pytest.raises(RuntimeError):
            float_quant = FloatQuant(
                bit_width=bit_width,
                exponent_bit_width=exponent_bit_width,
                mantissa_bit_width=mantissa_bit_width,
                signed=signed)
    else:
        float_quant = FloatQuant(
            bit_width=bit_width,
            exponent_bit_width=exponent_bit_width,
            mantissa_bit_width=mantissa_bit_width,
            signed=signed)
        assert expected_exponent_bias == float_quant.exponent_bias()
        assert isinstance(float_quant.float_to_int_impl, RoundSte)
        assert isinstance(float_quant.float_scaling_impl, ConstScaling)
        assert isinstance(float_quant.scaling_impl, ConstScaling)


@given(minifloat_format=random_minifloat_format())
def test_minifloat(minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
def test_float_to_quant_float(inp, minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    if exponent_bit_width == 0 or mantissa_bit_width == 0:
        with pytest.raises(RuntimeError):
            float_quant = FloatQuant(
                bit_width=bit_width,
                exponent_bit_width=exponent_bit_width,
                mantissa_bit_width=mantissa_bit_width,
                signed=signed)
    else:
        float_quant = FloatQuant(
            bit_width=bit_width,
            exponent_bit_width=exponent_bit_width,
            mantissa_bit_width=mantissa_bit_width,
            signed=signed)
        expected_out, _, _, bit_width_out = float_quant(inp)

        out_quant, scale = float_quant.quantize(inp)
        assert bit_width_out == bit_width
        assert torch.equal(expected_out, out_quant * scale)
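The final assertion encodes the contract being tested: the forward pass of FloatQuant returns a dequantized tensor, so quantizing and then multiplying by the returned scale must reproduce it. A minimal sketch of that relationship, reusing the names from the test above (illustrative only, not part of the diff):

out_quant, scale = float_quant.quantize(inp)   # minifloat values on the quantized grid, plus their scale
dequantized = out_quant * scale                # manual dequantization
# float_quant(inp)[0] is expected to match `dequantized` element-wise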


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
@jit_disabled_for_mock()
def test_scaling_impls_called_once(inp, minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    scaling_impl = mock.Mock(side_effect=lambda x: 1.)
    float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
    if exponent_bit_width == 0 or mantissa_bit_width == 0:
        with pytest.raises(RuntimeError):
            float_quant = FloatQuant(
                bit_width=bit_width,
                exponent_bit_width=exponent_bit_width,
                mantissa_bit_width=mantissa_bit_width,
                signed=signed,
                scaling_impl=scaling_impl,
                float_scaling_impl=float_scaling_impl)
    else:
        float_quant = FloatQuant(
            bit_width=bit_width,
            exponent_bit_width=exponent_bit_width,
            mantissa_bit_width=mantissa_bit_width,
            signed=signed,
            scaling_impl=scaling_impl,
            float_scaling_impl=float_scaling_impl)
        output = float_quant.quantize(inp)
        # scaling implementations should be called exactly once on the input
        scaling_impl.assert_called_once_with(inp)
        float_scaling_impl.assert_called_once_with(inp)


@given(
    inp=float_tensor_random_shape_st(),
    minifloat_format=random_minifloat_format(),
    scale=float_st())
@jit_disabled_for_mock()
def test_inner_scale(inp, minifloat_format, scale):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    # have scaling_impl return `scale` and float_scaling_impl return 1. so that
    # FloatQuant uses the same scale as the manual computation below
    scaling_impl = mock.Mock(side_effect=lambda x: scale)
    float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
    if exponent_bit_width == 0 or mantissa_bit_width == 0:
        with pytest.raises(RuntimeError):
            float_quant = FloatQuant(
                bit_width=bit_width,
                exponent_bit_width=exponent_bit_width,
                mantissa_bit_width=mantissa_bit_width,
                signed=signed,
                scaling_impl=scaling_impl,
                float_scaling_impl=float_scaling_impl)
    else:
        float_quant = FloatQuant(
            bit_width=bit_width,
            exponent_bit_width=exponent_bit_width,
            mantissa_bit_width=mantissa_bit_width,
            signed=signed,
            scaling_impl=scaling_impl,
            float_scaling_impl=float_scaling_impl)

        # scale inp manually
        scaled_inp = inp / scale

        # call internal scale
        internal_scale = float_quant.internal_scale(scaled_inp)
        val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale)
        if signed:
            val_fp_quant = torch.clip(
                val_fp_quant, -1. * float_quant.fp_max_val(), float_quant.fp_max_val())
        else:
            val_fp_quant = torch.clip(val_fp_quant, 0., float_quant.fp_max_val())

        # dequantize manually
        out = val_fp_quant * scale

        expected_out, expected_scale, _, _ = float_quant(inp)

        assert scale == expected_scale
        if scale == 0.0:
            # outputs should only receive 0s or nan
            assert torch.tensor([
                True if val == 0. or val.isnan() else False for val in out.flatten()]).all()
            assert torch.tensor([
                True if val == 0. or val.isnan() else False for val in expected_out.flatten()
            ]).all()
        else:
            # filter out NaN values as we can't compare them
            # Note: this still checks if NaN appears at the same positions
            out_nans = out.isnan()
            expected_out_nans = expected_out.isnan()
            assert torch.equal(out[~out_nans], expected_out[~expected_out_nans])
21 changes: 21 additions & 0 deletions tests/brevitas/hyp_helper.py
@@ -13,6 +13,8 @@
import torch

from tests.brevitas.common import FP32_BIT_WIDTH
from tests.brevitas.common import MAX_INT_BIT_WIDTH
from tests.brevitas.common import MIN_INT_BIT_WIDTH
from tests.conftest import SEED

# Remove Hypothesis check for slow tests and function scoped fixtures.
@@ -218,3 +220,22 @@ def min_max_tensor_random_shape_st(draw, min_dims=1, max_dims=4, max_size=3, wid
    min_tensor = torch.tensor(min_list).view(*shape)
    max_tensor = torch.tensor(max_list).view(*shape)
    return min_tensor, max_tensor


@st.composite
def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_width=MAX_INT_BIT_WIDTH):
    """
    Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed.
    """
    # TODO: add support for new minifloat format that comes with FloatQuantTensor
    bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_width))
    exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
    signed = draw(st.booleans())
    # if no bit budget is left for the mantissa, return early
    if bit_width == exponent_bit_width:
        return bit_width, exponent_bit_width, 0, False
    elif bit_width == (exponent_bit_width + int(signed)):
        return bit_width, exponent_bit_width, 0, signed
    mantissa_bit_width = bit_width - exponent_bit_width - int(signed)

    return bit_width, exponent_bit_width, mantissa_bit_width, signed
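For intuition about the bit budget above, a standard FP8 E4M3 layout fits this arithmetic: one sign bit, four exponent bits, and three mantissa bits. A quick hand-check with values chosen for illustration, not drawn from the Hypothesis strategy:

bit_width, exponent_bit_width, signed = 8, 4, True
mantissa_bit_width = bit_width - exponent_bit_width - int(signed)          # 8 - 4 - 1 == 3
assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)  # 8 == 4 + 3 + 1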