From 335e51458261e90a9a7515867aae2de894af38c5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 18 Jun 2024 13:37:25 +0100 Subject: [PATCH] Feat (minifloat): support for FNUZ variants --- src/brevitas/export/common/handler/base.py | 19 +- src/brevitas/export/common/handler/qcdq.py | 53 ++++-- .../export/onnx/standard/qcdq/handler.py | 14 +- src/brevitas/proxy/float_parameter_quant.py | 22 +++ src/brevitas/proxy/float_runtime_quant.py | 22 +++ .../quant/experimental/float_quant_fnuz.py | 163 ++++++++++++++++++ tests/brevitas_ort/common.py | 2 +- tests/brevitas_ort/quant_module_cases.py | 7 +- 8 files changed, 261 insertions(+), 41 deletions(-) create mode 100644 src/brevitas/quant/experimental/float_quant_fnuz.py diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index bf3f69ed4..0841b7f21 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -124,13 +124,20 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor): class FloatZeroPointHandlerMixin(ABC): @classmethod - def zero_point_with_dtype(cls, exponent_bit_width, mantissa_bit_width, zero_point): - if exponent_bit_width == 4 and mantissa_bit_width == 3: - return zero_point.type(torch.float8_e4m3fn) - elif exponent_bit_width == 5 and mantissa_bit_width == 2: - return zero_point.type(torch.float8_e5m2) + def zero_point_with_dtype( + cls, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zero_point): + if is_ocp: + if exponent_bit_width == 4 and mantissa_bit_width == 3: + return zero_point.type(torch.float8_e4m3fn) + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + return zero_point.type(torch.float8_e5m2) + elif is_fnuz: + if exponent_bit_width == 4 and mantissa_bit_width == 3: + return zero_point.type(torch.float8_e4m3fnuz) + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + return zero_point.type(torch.float8_e5m2fnuz) else: - return zero_point.type(torch.float32) + raise NotImplementedError class ZeroPointHandlerMixin(ABC): diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 9a91d1f5a..44061ce42 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -78,15 +78,21 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis): pass @classmethod - def signed_dtype(cls, exponent_bit_width, mantissa_bit_width): + def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz): if exponent_bit_width is None or mantissa_bit_width is None: return None - if exponent_bit_width == 4 and mantissa_bit_width == 3: - dtype = torch.float8_e4m3fn - elif exponent_bit_width == 5 and mantissa_bit_width == 2: - dtype = torch.float8_e5m2 + if is_ocp: + if exponent_bit_width == 4 and mantissa_bit_width == 3: + dtype = torch.float8_e4m3fn + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + dtype = torch.float8_e5m2 + elif is_fnuz: + if exponent_bit_width == 4 and mantissa_bit_width == 3: + dtype = torch.float8_e4m3fnuz + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + dtype = torch.float8_e5m2fnuz else: - dtype = torch.float32 + raise NotImplementedError return dtype @@ -140,7 +146,8 @@ class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, CDQCastMixin, ABC): - def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): + def dequantize_symbolic_kwargs( + cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz): scale_orig_shape = scale.shape axis = cls.quant_axis(scale) if cls.flatten_dequantize_params: @@ -150,7 +157,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, manti zero_point = zero_point.flatten() zp = to_0dim_if_scalar(zero_point) zp = zp.expand_as(scale) - zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zp) return { 'scale': scale, 'zero_point': zp, @@ -187,18 +194,19 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): class FloatQCDQCastWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): handled_layer = WeightFloatQuantProxyFromInjector - def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): + def quantize_symbolic_kwargs( + cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz): # compute axis before redefining scale axis = cls.quant_axis(scale) scale = to_0dim_if_scalar(scale.flatten()) zp = to_0dim_if_scalar(zero_point.flatten()) # expand_as must go after 0-dim check zp = zp.expand_as(scale) - zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zp) if cls.itemize_quantize_scalar_params: scale = to_item_if_0dim(scale) zp = to_item_if_0dim(zp) - dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz) return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} def prepare_quantize_from_floating_point(self, module): @@ -211,7 +219,9 @@ def prepare_quantize_from_floating_point(self, module): scale, quant_weight.zero_point, quant_weight.exponent_bit_width, - quant_weight.mantissa_bit_width) + quant_weight.mantissa_bit_width, + module.is_ocp, + module.is_fnuz) def prepare_quantize_from_minifloat(self, module): raise NotImplementedError @@ -249,7 +259,9 @@ def prepare_for_export(self, module): scale, quant_weight.zero_point, quant_weight.exponent_bit_width, - quant_weight.mantissa_bit_width) + quant_weight.mantissa_bit_width, + module.is_ocp, + module.is_fnuz) else: self.symbolic_kwargs = None @@ -421,18 +433,19 @@ def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_ class FloatQCDQCastActQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin, ABC): handled_layer = ActFloatQuantProxyFromInjector - def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): + def quantize_symbolic_kwargs( + cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz): # compute axis before redefining scale axis = cls.quant_axis(scale) scale = to_0dim_if_scalar(scale.flatten()) zp = to_0dim_if_scalar(zero_point.flatten()) # expand_as must go after 0-dim check zp = zp.expand_as(scale) - zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zp) if cls.itemize_quantize_scalar_params: scale = to_item_if_0dim(scale) zp = to_item_if_0dim(zp) - dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz) return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} def prepare_for_export(self, module): @@ -457,12 +470,16 @@ def prepare_for_export(self, module): scale, module.zero_point(), module.exponent_bit_width(), - module.mantissa_bit_width()) + module.mantissa_bit_width(), + module.is_ocp, + module.is_fnuz) self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( scale, module.zero_point(), module.exponent_bit_width(), - module.mantissa_bit_width()) + module.mantissa_bit_width(), + module.is_ocp, + module.is_fnuz) self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( module.is_narrow_range, module.is_signed, diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 9f4184071..97af39049 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -53,20 +53,8 @@ def validate(self, module): class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC): - def is_ocp(self, module): - is_e4m3 = module.mantissa_bit_width() == 3 and module.exponent_bit_width() == 4 - - is_ocp_e4m3 = is_e4m3 and module.inf_values() is None and module.nan_values() == (('111',)) - - is_e5m2 = module.mantissa_bit_width() == 5 and module.exponent_bit_width() == 2 - - is_ocp_e5m2 = is_e5m2 and module.inf_values() == ( - ('00',)) and module.nan_values() == ('01', '11', '10') - - return is_ocp_e4m3 or is_ocp_e5m2 - def validate(self, module): - assert self.is_ocp(module), 'Only OCP Standard is supported for FP8 export' + assert module.is_ocp or module.is_fnuz, 'Only OCP/FNUZ Standard are supported for FP8 export' class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC): diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 835caf647..b59a37696 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -62,6 +62,28 @@ def nan_values(self): nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values return nan_values + @property + def is_ocp(self): + is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4 + is_ocp_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values() == (('111',)) + + is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2 + is_ocp_e5m2 = is_e5m2 and self.inf_values() == ( + ('00',)) and self.nan_values() == ('01', '11', '10') + + return is_ocp_e4m3 or is_ocp_e5m2 + + @property + def is_fnuz(self): + is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4 + is_fnuz_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values( + ) is None and self.exponent_bias() == 8 + + is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2 + is_fnuz_e5m2 = is_e5m2 and self.inf_values() is None and self.nan_values( + ) is None and self.exponent_bias() == 16 + return is_fnuz_e4m3 or is_fnuz_e5m2 + def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 4151bc555..28f1e1b5e 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -37,6 +37,28 @@ def inf_values(self, force_eval=True): def nan_values(self, force_eval=True): return self.retrieve_attribute('nan_values', force_eval) + @property + def is_ocp(self): + is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4 + is_ocp_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values() == (('111',)) + + is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2 + is_ocp_e5m2 = is_e5m2 and self.inf_values() == ( + ('00',)) and self.nan_values() == ('01', '11', '10') + + return is_ocp_e4m3 or is_ocp_e5m2 + + @property + def is_fnuz(self): + is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4 + is_fnuz_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values( + ) is None and self.exponent_bias() == 8 + + is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2 + is_fnuz_e5m2 = is_e5m2 and self.inf_values() is None and self.nan_values( + ) is None and self.exponent_bias() == 16 + return is_fnuz_e4m3 or is_fnuz_e5m2 + def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]: out = x if self.fused_activation_quant_proxy is not None: diff --git a/src/brevitas/quant/experimental/float_quant_fnuz.py b/src/brevitas/quant/experimental/float_quant_fnuz.py new file mode 100644 index 000000000..7d7035cb6 --- /dev/null +++ b/src/brevitas/quant/experimental/float_quant_fnuz.py @@ -0,0 +1,163 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from dependencies import value + +from brevitas.quant.base import MSESymmetricScale +from brevitas.quant.experimental.float_base import FloatActBase +from brevitas.quant.experimental.float_base import FloatWeightBase +from brevitas.quant.experimental.float_base import Fp8e4m3Mixin +from brevitas.quant.experimental.float_base import Fp8e5m2Mixin +from brevitas.quant.experimental.float_base import ScaledFloatActBase +from brevitas.quant.experimental.float_base import ScaledFloatWeightBase + + +class Fp8e4m3FNUZMixin(Fp8e4m3Mixin): + nan_values = None + inf_values = None + + @value + def exponent_bias(exponent_bit_width): + return 2 ** (exponent_bit_width - 1) + + +class Fp8e5m2FNUZMixin(Fp8e5m2Mixin): + nan_values = None + inf_values = None + + @value + def exponent_bias(exponent_bit_width): + return 2 ** (exponent_bit_width - 1) + + +class Fp8e4m3FNUZWeight(Fp8e4m3FNUZMixin, FloatWeightBase): + """ + FP8 signed E3M4 weight quantizer. + """ + pass + + +class Fp8e5m2FNUZWeight(Fp8e5m2FNUZMixin, FloatWeightBase): + """ + FP8 signed E5M2 weight quantizer. + """ + pass + + +class Fp8e4m3FNUZAct(Fp8e4m3FNUZMixin, FloatActBase): + """ + FP8 signed E4M3 activation quantizer. + """ + pass + + +class Fp8e5m2FNUZAct(Fp8e5m2FNUZMixin, FloatActBase): + """ + FP8 signed E5M2 activation quantizer. + """ + pass + + +class Fp8e4m3FNUZWeightPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e5m2FNUZWeightPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase): + """ + FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e4m3FNUZActPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e5m2FNUZActPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e4m3FNUZWeightPerChannelFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling. + """ + scaling_per_output_channel = True + + +class Fp8e5m2FNUZWeightPerChannelFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase): + """ + FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling. + """ + scaling_per_output_channel = True + + +class Fp8e4m3FNUZActPerChannelFloat2d(Fp8e4m3FNUZMixin, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e5m2FNUZActPerChannelFloat2d(Fp8e5m2FNUZMixin, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e4m3FNUZActPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e5m2FNUZActPerTensorFloatMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e4m3FNUZActPerChannelFloat2dMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e5m2FNUZActPerChannelFloat2dMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e4m3FNUZWeightPerChannelFloatMSE(Fp8e4m3FNUZMixin, + MSESymmetricScale, + ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. + """ + scaling_per_output_channel = True + + +class Fp8e4m3FNUZWeightPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, + ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. + """ + scaling_per_output_channel = False diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index ceaf789f4..b863bf618 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -62,7 +62,7 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint), 'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float': (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat), - 'fp8_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat)} + 'fp8_ocp_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat)} LSTM_QUANTIZERS = { 'asymmetric_per_tensor_float': (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat), diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py index 9bad4e89c..88022aaa6 100644 --- a/tests/brevitas_ort/quant_module_cases.py +++ b/tests/brevitas_ort/quant_module_cases.py @@ -27,7 +27,8 @@ def case_quant_wbiol( set_case_id(request.node.callspec.id, QuantWBIOLCases.case_quant_wbiol) weight_quant, io_quant = quantizers - if weight_quant == Fp8e4m3OCPWeightPerTensorFloat: + is_fp8 = weight_quant == Fp8e4m3OCPWeightPerTensorFloat + if is_fp8: if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8: pytest.skip('FP8 export requires total bitwidth equal to 8') torch.use_deterministic_algorithms(False) @@ -40,9 +41,9 @@ def case_quant_wbiol( layer_kwargs = { 'in_channels': IN_CH, 'out_channels': OUT_CH, 'kernel_size': KERNEL_SIZE} - bias_quantizer = None if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else Int32Bias + bias_quantizer = None if is_fp8 else Int32Bias # Required because of numpy error with FP8 data type. Export iself works fine. - return_quant_tensor = False if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else True + return_quant_tensor = False if is_fp8 else True class Model(nn.Module):