Skip to content

Commit

Permalink
Fix (minifloat): restructure OCP format quantizers
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Feb 26, 2024
1 parent 80d118b commit cf35b91
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 42 deletions.
12 changes: 6 additions & 6 deletions src/brevitas/quant/experimental/float_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ def exponent_bias(exponent_bit_width):

@value
def max_value(
exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
saturating):
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
nan_values=None,
inf_values=None,
saturating=True):
return get_max_value(
exponent_bit_width,
mantissa_bit_width,
Expand Down Expand Up @@ -67,15 +71,11 @@ class Fp8e4m3Mixin(ExtendedInjector):
bit_width = 8
exponent_bit_width = 4
mantissa_bit_width = 3
nan_values = (('111',))
inf_values = None
saturating = True


class Fp8e5m2Mixin(ExtendedInjector):
bit_width = 8
exponent_bit_width = 5
mantissa_bit_width = 2
nan_values = ('01', '11', '10')
inf_values = (('00',))
saturating = True
150 changes: 150 additions & 0 deletions src/brevitas/quant/experimental/float_quant_ocp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.experimental.float_base import FloatActBase
from brevitas.quant.experimental.float_base import FloatWeightBase
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from brevitas.quant.experimental.float_base import ScaledFloatActBase
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
    """
    OCP-style FP8 E4M3: a single NaN mantissa pattern ('111') and no
    infinity encodings (inf_values is None).
    """
    # Was (('111',)): the extra parentheses were redundant — both spellings
    # evaluate to the same 1-tuple ('111',). Use the plain tuple for clarity.
    nan_values = ('111',)
    inf_values = None


class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
    """
    OCP-style FP8 E5M2: three NaN mantissa patterns and one infinity
    mantissa pattern ('00').
    """
    nan_values = ('01', '11', '10')
    # Was (('00',)): the extra parentheses were redundant — both spellings
    # evaluate to the same 1-tuple ('00',). Use the plain tuple for clarity.
    inf_values = ('00',)


class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer.
    """
    # Fixed docstring: it previously said "E3M4", but this class uses the
    # E4M3 format (4 exponent bits, 3 mantissa bits) via Fp8e4m3OCPMixin.
    pass


class Fp8e5m2OCPWeight(Fp8e5m2OCPMixin, FloatWeightBase):
    """FP8 signed E5M2 weight quantizer (OCP NaN/inf encodings)."""


class Fp8e4m3OCPAct(Fp8e4m3OCPMixin, FloatActBase):
    """FP8 signed E4M3 activation quantizer (OCP NaN/inf encodings)."""


class Fp8e5m2OCPAct(Fp8e5m2OCPMixin, FloatActBase):
    """FP8 signed E5M2 activation quantizer (OCP NaN/inf encodings)."""


class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor absmax-based scaling.
    """
    # Fixed docstring: it previously said "E3M4"; the format is E4M3.
    scaling_per_output_channel = False


class Fp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
    """FP8 signed E5M2 weight quantizer, absmax scaling over the whole tensor."""

    # One scale for the entire weight tensor.
    scaling_per_output_channel = False


class Fp8e4m3OCPActPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatActBase):
    """FP8 signed E4M3 activation quantizer, static percentile scaling per tensor."""

    # One scale for the entire activation tensor.
    scaling_per_output_channel = False


class Fp8e5m2OCPActPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatActBase):
    """FP8 signed E5M2 activation quantizer, static percentile scaling per tensor."""

    # One scale for the entire activation tensor.
    scaling_per_output_channel = False


class Fp8e4m3OCPWeightPerChannelFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel absmax-based scaling.
    """
    # Fixed docstring: it previously said "E3M4"; the format is E4M3.
    scaling_per_output_channel = True


class Fp8e5m2OCPWeightPerChannelFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
    """FP8 signed E5M2 weight quantizer, absmax scaling per output channel."""

    # A separate scale per output channel.
    scaling_per_output_channel = True


class Fp8e4m3OCPActPerChannelFloat2d(Fp8e4m3OCPMixin, ScaledFloatActBase):
    """FP8 signed E4M3 activation quantizer, static percentile scaling per channel."""

    scaling_per_output_channel = True
    # Move the channel axis first when gathering scaling stats on NCHW inputs.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2OCPActPerChannelFloat2d(Fp8e5m2OCPMixin, ScaledFloatActBase):
    """FP8 signed E5M2 activation quantizer, static percentile scaling per channel."""

    scaling_per_output_channel = True
    # Move the channel axis first when gathering scaling stats on NCHW inputs.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3OCPActPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """FP8 signed E4M3 activation quantizer, static MSE-optimized per-tensor scale."""

    # One scale for the entire activation tensor.
    scaling_per_output_channel = False


class Fp8e5m2OCPActPerTensorFloatMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """FP8 signed E5M2 activation quantizer, static MSE-optimized per-tensor scale."""

    # One scale for the entire activation tensor.
    scaling_per_output_channel = False


class Fp8e4m3OCPActPerChannelFloat2dMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """FP8 signed E4M3 activation quantizer, static MSE-optimized per-channel scales."""

    scaling_per_output_channel = True
    # Move the channel axis first when gathering scaling stats on NCHW inputs.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2OCPActPerChannelFloat2dMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """FP8 signed E5M2 activation quantizer, static MSE-optimized per-channel scales."""

    scaling_per_output_channel = True
    # Move the channel axis first when gathering scaling stats on NCHW inputs.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel MSE-based scaling.
    """
    # Fixed docstring: it previously said "E3M4"; the format is E4M3.
    scaling_per_output_channel = True


class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor MSE-based scaling.
    """
    # Fixed docstring: it previously said "E3M4"; the format is E4M3.
    # NOTE(review): no E5M2 MSE weight variants exist alongside these two —
    # confirm whether Fp8e5m2OCPWeightPer{Tensor,Channel}FloatMSE were
    # intentionally omitted.
    scaling_per_output_channel = False
46 changes: 11 additions & 35 deletions tests/brevitas/core/minifloat_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,34 @@
import pytest_cases
from pytest_cases import fixture_union

from brevitas.inject.enum import BitWidthImplType
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


class Fp8e4m3Base(ScaledFloatWeightBase):
bit_width = 8
exponent_bit_width = 4
mantissa_bit_width = 3
nan_values = None
inf_values = None
saturating = True
bit_width_impl_type = BitWidthImplType.CONST
# hypothesis extra
hypothesis_internal_is_this_a_mock_check = False


class Fp8e5m2Base(ScaledFloatWeightBase):
bit_width = 8
exponent_bit_width = 5
mantissa_bit_width = 2
nan_values = None
inf_values = None
saturating = True
bit_width_impl_type = BitWidthImplType.CONST
# hypothesis extra
hypothesis_internal_is_this_a_mock_check = False
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight


@pytest_cases.fixture
@pytest_cases.parametrize('sat', [True, False])
def fp8e4m3_regular(sat):
def fp8e4m3(sat):

class Fp8e4m3(Fp8e4m3Base):
class Fp8e4m3(Fp8e4m3OCPWeight):
saturating = sat
nan_values = tuple(('111',))
inf_values = None
# for hypothesis and DI
hypothesis_internal_is_this_a_mock_check = True

return Fp8e4m3


@pytest_cases.fixture
@pytest_cases.parametrize('sat', [True, False])
def fp8e5m2_regular(sat):
def fp8e5m2(sat):

class Fp8e5m2(Fp8e5m2Base):
class Fp8e5m2(Fp8e5m2OCPWeight):
saturating = sat
nan_values = ('01', '11', '10')
inf_values = tuple(('00',))
# for hypothesis and DI
hypothesis_internal_is_this_a_mock_check = True

return Fp8e5m2


list_of_fixtures = ['fp8e4m3_regular', 'fp8e5m2_regular']
list_of_fixtures = ['fp8e4m3', 'fp8e5m2']

fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures)
5 changes: 4 additions & 1 deletion tests/brevitas/core/test_minifloat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@

from brevitas.quant.experimental.float import Fp8e4m3Weight
from brevitas.quant.experimental.float import Fp8e5m2Weight
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
from tests.brevitas.hyp_helper import float_tensor_random_shape_st

from .minifloat_fixtures import *

FORMATS = {Fp8e5m2Weight: 57344., Fp8e4m3Weight: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.}
FORMATS = {
Fp8e5m2OCPWeight: 57344., Fp8e4m3OCPWeight: 448., Fp8e4m3Weight: 480., Fp8e5m2Weight: 114688.}


@pytest.mark.parametrize(
Expand Down

0 comments on commit cf35b91

Please sign in to comment.