
Fix (minifloat): restructuring quantizers
fabianandresgrob committed Feb 23, 2024
1 parent 254a91f commit 8f599cd
Showing 4 changed files with 31 additions and 47 deletions.
2 changes: 0 additions & 2 deletions src/brevitas/core/quant/float.py
@@ -7,13 +7,11 @@
 import torch.nn as nn

 import brevitas
-from brevitas.core.function_wrapper import FloatClamp
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.scaling import ConstScaling
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_float
 from brevitas.function.ops_ste import floor_ste
-from brevitas.utils.float_quant_utils import get_max_value


 class FloatQuant(brevitas.jit.ScriptModule):
61 changes: 26 additions & 35 deletions src/brevitas/quant/experimental/float_base.py
@@ -15,20 +15,37 @@
 from brevitas.utils.float_quant_utils import get_max_value


-class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum):
-    proxy_class = WeightQuantProxyFromInjector
+class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
     tensor_quant = FloatQuant
     signed = True
     float_to_int_impl_type = 'round'
     scaling_min_val = 1e-10
+    float_clamp_impl = FloatClamp
+    tensor_clamp_impl = TensorClamp
+
+    @value
+    def exponent_bias(exponent_bit_width):
+        return 2 ** (exponent_bit_width - 1) - 1

-class FloatActBase(SolveTensorQuantFloatToIntImplFromEnum):
+    @value
+    def max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
+            saturating):
+        return get_max_value(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values,
+            inf_values,
+            saturating)
+
+
+class FloatWeightBase(FloatBase):
+    proxy_class = WeightQuantProxyFromInjector


+class FloatActBase(FloatBase):
     proxy_class = ActQuantProxyFromInjector
-    tensor_quant = FloatQuant
-    signed = True
-    float_to_int_impl_type = 'round'
-    scaling_min_val = 1e-10


 class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):
@@ -46,45 +63,19 @@ class ScaledFloatActBase(FloatActBase, ActQuantSolver):
     float_scaling_impl = FloatScaling


-class ExponentBiasMixin(ExtendedInjector):
-
-    @value
-    def exponent_bias(exponent_bit_width):
-        return 2 ** (exponent_bit_width - 1) - 1
-
-
-class MaxFloatInfNaNMixin(ExtendedInjector):
-
-    @value
-    def max_value(
-            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
-            saturating):
-        return get_max_value(
-            exponent_bit_width,
-            mantissa_bit_width,
-            exponent_bias,
-            nan_values,
-            inf_values,
-            saturating)
-
-
-class Fp8e4m3Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
+class Fp8e4m3Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    float_clamp_impl = FloatClamp
-    tensor_clamp_impl = TensorClamp
     nan_values = (('111',))
     inf_values = None
     saturating = True


-class Fp8e5m2Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
+class Fp8e5m2Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    float_clamp_impl = FloatClamp
-    tensor_clamp_impl = TensorClamp
     nan_values = ('01', '11', '10')
     inf_values = (('00',))
     saturating = True
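
The exponent_bias provider that moves into FloatBase is plain arithmetic: the standard IEEE-style bias for a given exponent field width. A minimal standalone sketch (ordinary Python, bypassing brevitas's injector machinery), checked against the two FP8 mixins declared above:

    def exponent_bias(exponent_bit_width: int) -> int:
        # Same formula as the @value provider in FloatBase: 2**(E-1) - 1.
        return 2 ** (exponent_bit_width - 1) - 1

    assert exponent_bias(4) == 7   # Fp8e4m3Mixin: exponent_bit_width = 4
    assert exponent_bias(5) == 15  # Fp8e5m2Mixin: exponent_bit_width = 5
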
9 changes: 2 additions & 7 deletions tests/brevitas/core/minifloat_fixtures.py
@@ -4,18 +4,14 @@
 import pytest_cases
 from pytest_cases import fixture_union

-from brevitas.core.function_wrapper import FloatClamp
 from brevitas.inject.enum import BitWidthImplType
-from brevitas.quant.experimental.float_base import ExponentBiasMixin
-from brevitas.quant.experimental.float_base import MaxFloatInfNaNMixin
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


-class Fp8e4m3Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase):
+class Fp8e4m3Base(ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    float_clamp_impl = FloatClamp
     nan_values = None
     inf_values = None
     saturating = True
@@ -24,11 +20,10 @@ class Fp8e4m3Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase)
     hypothesis_internal_is_this_a_mock_check = False


-class Fp8e5m2Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase):
+class Fp8e5m2Base(ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    float_clamp_impl = FloatClamp
     nan_values = None
     inf_values = None
     saturating = True
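
The fixture simplification above shows the pattern this commit enables: a format class now only declares its bit widths and special-value policy, while exponent_bias, max_value, and both clamp implementations are inherited from FloatBase. As a hypothetical illustration (Fp6e3m2Weight does not exist anywhere in brevitas), a new minifloat format after this commit could be declared as:

    from brevitas.quant.experimental.float_base import ScaledFloatWeightBase

    class Fp6e3m2Weight(ScaledFloatWeightBase):
        # Hypothetical FP6 weight quantizer: only the format-specific fields
        # are declared; everything else comes from FloatBase.
        bit_width = 6
        exponent_bit_width = 3
        mantissa_bit_width = 2
        nan_values = None
        inf_values = None
        saturating = True
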
6 changes: 3 additions & 3 deletions tests/brevitas/core/test_minifloat.py
@@ -4,13 +4,13 @@
 from hypothesis import given
 import pytest

-from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
-from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
+from brevitas.quant.experimental.float import Fp8e4m3Weight
+from brevitas.quant.experimental.float import Fp8e5m2Weight
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st

 from .minifloat_fixtures import *

-FORMATS = {Fp8e5m2Mixin: 57344., Fp8e4m3Mixin: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.}
+FORMATS = {Fp8e5m2Weight: 57344., Fp8e4m3Weight: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.}


 @pytest.mark.parametrize(
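
The four constants in FORMATS follow from the minifloat max-value arithmetic that FloatBase now derives via get_max_value. The sketch below is an assumption about those semantics (reserved NaN/inf mantissa codes at the all-ones exponent shrink the representable range), not a copy of brevitas's get_max_value, but it reproduces all four expected values:

    def minifloat_max_value(exp_bits, mant_bits, bias, nan_values=None, inf_values=None):
        # Mantissa codes reserved for NaN/inf at the all-ones exponent.
        special = set(nan_values or ()) | set(inf_values or ())
        codes = {format(m, f'0{mant_bits}b') for m in range(2 ** mant_bits)}
        usable = sorted(int(c, 2) for c in codes - special)
        top_exp = 2 ** exp_bits - 1 - bias
        if usable:
            # Some mantissa codes at the top exponent are ordinary numbers.
            return 2.0 ** top_exp * (1 + usable[-1] / 2 ** mant_bits)
        # The whole top exponent row is reserved (IEEE-style inf/NaN).
        return 2.0 ** (top_exp - 1) * (2 - 2.0 ** -mant_bits)

    assert minifloat_max_value(5, 2, 15, ('01', '11', '10'), ('00',)) == 57344.  # Fp8e5m2Weight
    assert minifloat_max_value(4, 3, 7, ('111',)) == 448.                        # Fp8e4m3Weight
    assert minifloat_max_value(4, 3, 7) == 480.                                  # Fp8e4m3Base
    assert minifloat_max_value(5, 2, 15) == 114688.                              # Fp8e5m2Base
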
