diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 65af11e61..929024c63 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -10,8 +10,7 @@ from brevitas.core.function_wrapper import RoundSte from brevitas.core.scaling import ConstScaling from brevitas.core.utils import StatelessBuffer -from brevitas.function.ops import max_float -from brevitas.function.ops_ste import floor_ste +from brevitas.utils.torch_utils import float_internal_scale class FloatQuant(brevitas.jit.ScriptModule): @@ -64,13 +63,6 @@ def __init__( dtype = torch.get_default_dtype() self.eps = torch.finfo(dtype).tiny - @brevitas.jit.script_method - def internal_scale(self, x): - internal_scale = floor_ste(torch.log2(torch.abs(x) + self.eps)) - self.mantissa_bit_width() - internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min()) - internal_scale = torch.exp2(internal_scale) - return internal_scale - @brevitas.jit.script_method def quantize(self, x: torch.Tensor): scaling_impl_value = self.scaling_impl(x) @@ -78,7 +70,8 @@ def quantize(self, x: torch.Tensor): self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) scale = scaling_impl_value / float_scaling_impl_value scaled_x = x / scale - internal_scale = self.internal_scale(scaled_x) + internal_scale = float_internal_scale( + scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min()) val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale) return val_fp_quant, scale diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index 6136a4cdc..bf3f69ed4 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -4,6 +4,7 @@ from abc import ABC from abc import abstractmethod import math +from warnings import warn import torch from torch import Tensor @@ -12,7 +13,8 @@ from brevitas.function.ops import max_int from brevitas.function.ops import min_int -__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin'] +__all__ = [ + 'BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin'] class BaseHandler(Module, ABC): @@ -38,6 +40,13 @@ def quant_axis(cls, scale): return None +class FloatClipMixin(ABC): + + @classmethod + def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width): + return None + + class ClipMixin(ABC): @classmethod @@ -112,6 +121,18 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor): return -cls.validate_scalar_int_exponent(scale) +class FloatZeroPointHandlerMixin(ABC): + + @classmethod + def zero_point_with_dtype(cls, exponent_bit_width, mantissa_bit_width, zero_point): + if exponent_bit_width == 4 and mantissa_bit_width == 3: + return zero_point.type(torch.float8_e4m3fn) + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + return zero_point.type(torch.float8_e5m2) + else: + return zero_point.type(torch.float32) + + class ZeroPointHandlerMixin(ABC): @classmethod diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index c4659ac87..9a91d1f5a 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -10,16 +10,21 @@ from brevitas.export.common import to_0dim_if_scalar from brevitas.export.common import to_item_if_0dim +from brevitas.proxy import ActFloatQuantProxyFromInjector from brevitas.proxy import ActQuantProxyFromInjector from 
brevitas.proxy import BiasQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector +from brevitas.proxy import WeightFloatQuantProxyFromInjector from brevitas.proxy import WeightQuantProxyFromInjector +from brevitas.proxy.float_parameter_quant import BiasFloatQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector from .base import BitWidthHandlerMixin from .base import ClipMixin +from .base import FloatClipMixin +from .base import FloatZeroPointHandlerMixin from .base import QuantAxisMixin from .base import ZeroPointHandlerMixin @@ -66,6 +71,25 @@ def clip_fn(self, x, min_val, max_val): pass +class FloatQMixin(ABC): + + @abstractmethod + def quantize_fn(self, x, scale, zero_point, dtype, axis): + pass + + @classmethod + def signed_dtype(cls, exponent_bit_width, mantissa_bit_width): + if exponent_bit_width is None or mantissa_bit_width is None: + return None + if exponent_bit_width == 4 and mantissa_bit_width == 3: + dtype = torch.float8_e4m3fn + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + dtype = torch.float8_e5m2 + else: + dtype = torch.float32 + return dtype + + class QMixin(BitWidthHandlerMixin, ABC): @classmethod @@ -110,6 +134,33 @@ def quantize_fn(self, x, dtype): pass +class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, + FloatClipMixin, + FloatZeroPointHandlerMixin, + CDQCastMixin, + ABC): + + def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): + scale_orig_shape = scale.shape + axis = cls.quant_axis(scale) + if cls.flatten_dequantize_params: + scale = scale.flatten() + scale = to_0dim_if_scalar(scale) + if cls.flatten_dequantize_params: + zero_point = zero_point.flatten() + zp = to_0dim_if_scalar(zero_point) + zp = zp.expand_as(scale) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) + return { + 'scale': scale, + 'zero_point': zp, + 'axis': axis, + # We save only the scale original shape + # as zero-point is being expanded to the same + # size as the scale + 'scale_orig_shape': scale_orig_shape} + + class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC): def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): @@ -133,6 +184,122 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): 'scale_orig_shape': scale_orig_shape} +class FloatQCDQCastWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): + handled_layer = WeightFloatQuantProxyFromInjector + + def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): + # compute axis before redefining scale + axis = cls.quant_axis(scale) + scale = to_0dim_if_scalar(scale.flatten()) + zp = to_0dim_if_scalar(zero_point.flatten()) + # expand_as must go after 0-dim check + zp = zp.expand_as(scale) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) + if cls.itemize_quantize_scalar_params: + scale = to_item_if_0dim(scale) + zp = to_item_if_0dim(zp) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width) + return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} + + def prepare_quantize_from_floating_point(self, module): + quant_weight = module.tracked_module_list[0].quant_weight() + scale = quant_weight.scale + self.scale_dtype = scale.dtype + if 
self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16: + scale = self.cast_fn(scale, torch.float32) + self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs( + scale, + quant_weight.zero_point, + quant_weight.exponent_bit_width, + quant_weight.mantissa_bit_width) + + def prepare_quantize_from_minifloat(self, module): + raise NotImplementedError + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.validate(module) + if self._export_q_node: + self.prepare_quantize_from_floating_point(module) + else: + self.prepare_quantize_from_minifloat(module) + # Get the first quant weight as representative + quant_weight = module.tracked_module_list[0].quant_weight() + + # (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype + # of the scale and we cast it to float32. + # The original dtype is then restored during the forward pass + scale = quant_weight.scale + self.scale_dtype = scale.dtype + if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16: + scale = self.cast_fn(scale, torch.float32) + + self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width + self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width + self.symbolic_kwargs['exponent_bias'] = quant_weight.exponent_bias + self.symbolic_kwargs['saturating'] = quant_weight.saturating + self.symbolic_kwargs['inf_values'] = quant_weight.inf_values + self.symbolic_kwargs['nan_values'] = quant_weight.nan_values + self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( + module.is_narrow_range, + module.is_signed, + quant_weight.exponent_bit_width, + quant_weight.mantissa_bit_width) + self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( + scale, + quant_weight.zero_point, + quant_weight.exponent_bit_width, + quant_weight.mantissa_bit_width) + else: + self.symbolic_kwargs = None + + def quantize_from_floating_point(self, x: Tensor): + # Workaround for equal_cpu RuntimeError + quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] + # Before quantization, cast input to float32 + if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: + x = self.cast_fn(x, torch.float32) + x = self.quantize_fn(x, *quantize_symbolic_kwargs.values()) + return x + + def quantize_from_minifloat(self, x: Tensor): + raise NotImplementedError + + def symbolic_execution(self, x: Tensor): + assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' + + # Copy dict to allow for popping kwargs even on shared quantizers + dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs']) + scale = dequantize_symbolic_kwargs['scale'] + zero_point = dequantize_symbolic_kwargs['zero_point'] + + if self._export_q_node: + x = self.quantize_from_floating_point(x) + else: + x = self.quantize_from_minifloat(x) + clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] + exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] + mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] + exponent_bias = self.symbolic_kwargs['exponent_bias'] + saturating = self.symbolic_kwargs['saturating'] + inf_values = self.symbolic_kwargs['inf_values'] + nan_values = self.symbolic_kwargs['nan_values'] + scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') + # Workaround to trick the tracer into believing all return values are used + self.assert_ge_zero(scale, 
exponent_bit_width, mantissa_bit_width, exponent_bias) + if clip_symbolic_kwargs is not None: + x = self.clip_fn(x, *clip_symbolic_kwargs.values()) + x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) + # After dequantization, cast both input and scale to the correct dtype + if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: + x = self.cast_fn(x, self.scale_dtype) + scale = self.cast_fn(scale, self.scale_dtype) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) + return x, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values + + class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin): handled_layer = WeightQuantProxyFromInjector @@ -251,6 +418,96 @@ def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_ return super().symbolic_execution(x) +class FloatQCDQCastActQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin, ABC): + handled_layer = ActFloatQuantProxyFromInjector + + def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): + # compute axis before redefining scale + axis = cls.quant_axis(scale) + scale = to_0dim_if_scalar(scale.flatten()) + zp = to_0dim_if_scalar(zero_point.flatten()) + # expand_as must go after 0-dim check + zp = zp.expand_as(scale) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) + if cls.itemize_quantize_scalar_params: + scale = to_item_if_0dim(scale) + zp = to_item_if_0dim(zp) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width) + return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.validate(module) + self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width() + self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width() + self.symbolic_kwargs['exponent_bias'] = module.exponent_bias() + self.symbolic_kwargs['saturating'] = module.saturating() + self.symbolic_kwargs['inf_values'] = module.inf_values() + self.symbolic_kwargs['nan_values'] = module.nan_values() + + # (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype + # of the scale and we cast it to float32. 
+ # The original dtype is then restored during the forward pass + scale = module.scale() + self.scale_dtype = scale.dtype + if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16: + scale = self.cast_fn(scale, torch.float32) + + self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs( + scale, + module.zero_point(), + module.exponent_bit_width(), + module.mantissa_bit_width()) + self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( + scale, + module.zero_point(), + module.exponent_bit_width(), + module.mantissa_bit_width()) + self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( + module.is_narrow_range, + module.is_signed, + module.exponent_bit_width(), + module.mantissa_bit_width()) + + else: + self.symbolic_kwargs = None + + def symbolic_execution(self, x: Tensor): + assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' + + # Copy dict to allow for popping kwargs even on shared quantizers + dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs']) + scale = dequantize_symbolic_kwargs['scale'] + zero_point = dequantize_symbolic_kwargs['zero_point'] + scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') + + quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] + clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] + exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] + mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] + exponent_bias = self.symbolic_kwargs['exponent_bias'] + saturating = self.symbolic_kwargs['saturating'] + inf_values = self.symbolic_kwargs['inf_values'] + nan_values = self.symbolic_kwargs['nan_values'] + + self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width, exponent_bias) + # If original dtype of the input is (b)float16, cast the input to float32 + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: + x = self.cast_fn(x, torch.float32) + x = self.quantize_fn(x, *quantize_symbolic_kwargs.values()) + if clip_symbolic_kwargs is not None: + x = self.clip_fn(x, *clip_symbolic_kwargs.values()) + x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) + # After dequantization, cast both output and scale to the correct dtype + if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: + x = self.cast_fn(x, self.scale_dtype) + scale = self.cast_fn(scale, self.scale_dtype) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) + return x, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values + + class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC): handled_layer = ActQuantProxyFromInjector @@ -357,6 +614,82 @@ def symbolic_execution(self, x: Tensor): return x, scale, zero_point, bit_width +class FloatCDQCastBiasQuantProxyHandlerMixin(DQCastMixin, + QuantAxisMixin, + FloatZeroPointHandlerMixin, + ABC): + # TODO: We do not have any bias quantizer so this is not wired to anything. + # Currently we do not support Minifloat -> DQ export for minifloat. 
+    # This has to be rewritten to be QDQ
+    handled_layer = BiasFloatQuantProxyFromInjector
+
+    def validate(self, module):
+        if module.bit_width() is not None:
+            assert module.bit_width() > 1., 'Binary quant not supported'
+        assert module.is_signed, 'Unsigned bias not supported.'
+        assert module.rounding_mode == 'ROUND', 'Only round to nearest even supported.'
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.validate(module)
+            int_biases = {
+                tm.bias.data_ptr(): tm.quant_bias().minifloat(float_datatype=False)
+                for tm in module.tracked_module_list}
+            self.symbolic_kwargs = {
+                'int_biases': int_biases,
+                'scale': module.scale(),
+                'zero_point': module.zero_point(),
+                'exponent_bit_width': module.exponent_bit_width(),
+                'mantissa_bit_width': module.mantissa_bit_width(),
+                'exponent_bias': module.exponent_bias(),
+                'saturating': module.saturating(),
+                'inf_values': module.inf_values(),
+                'nan_values': module.nan_values()}
+
+        else:
+            self.symbolic_kwargs = None
+
+    def symbolic_execution(self, x: Tensor, input_scale=None):
+        assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
+        int_bias = self.symbolic_kwargs['int_biases'][x.data_ptr()]
+        scale = self.symbolic_kwargs['scale']
+        zero_point = self.symbolic_kwargs['zero_point']
+        exponent_bit_width = self.symbolic_kwargs['exponent_bit_width']
+        mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width']
+        exponent_bias = self.symbolic_kwargs['exponent_bias']
+        saturating = self.symbolic_kwargs['saturating']
+        inf_values = self.symbolic_kwargs['inf_values']
+        nan_values = self.symbolic_kwargs['nan_values']
+
+        assert scale is not None or input_scale is not None, 'Input scale required for bias export'
+        if input_scale is not None:
+            scale = input_scale
+        scale_orig_shape = scale.shape
+
+        quant_axis = self.quant_axis(scale)
+        if self.flatten_dequantize_params:
+            scale = scale.flatten()
+            zero_point = zero_point.flatten()
+        scale = to_0dim_if_scalar(scale)
+        zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
+        # FloatZeroPointHandlerMixin.zero_point_with_dtype takes only
+        # (exponent_bit_width, mantissa_bit_width, zero_point); no signed flag
+        zero_point = self.zero_point_with_dtype(
+            exponent_bit_width, mantissa_bit_width, zero_point)
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
+        scale_dtype = scale.dtype
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            scale = self.cast_fn(scale, torch.float32)
+        y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
+        # After dequantization, cast both output and scale to the correct dtype
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            y = self.cast_fn(y, scale_dtype)
+            scale = self.cast_fn(scale, scale_dtype)
+        # Restore the original shapes to guarantee correct shape propagation downstream
+        scale = scale.view(scale_orig_shape)
+        zero_point = zero_point.view_as(scale)
+        return y, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values
+
+
 class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
     handled_layer = BiasQuantProxyFromInjector
 
@@ -380,19 +713,17 @@ def prepare_for_export(self, module):
         else:
             self.symbolic_kwargs = None
 
-    def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
+    def symbolic_execution(self, x: Tensor, input_scale=None):
         assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
         int_bias = self.symbolic_kwargs['int_biases'][x.data_ptr()]
         scale = 
self.symbolic_kwargs['scale'] bit_width = self.symbolic_kwargs['bit_width'] zero_point = self.symbolic_kwargs['zero_point'] assert scale is not None or input_scale is not None, 'Input scale required for bias export' - assert bit_width is not None or input_bit_width is not None, 'Input bit width required for bias export' if input_scale is not None: scale = input_scale scale_orig_shape = scale.shape - if input_bit_width is not None: - bit_width = input_bit_width + quant_axis = self.quant_axis(scale) if self.flatten_dequantize_params: scale = scale.flatten() diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index 1bacb461e..ae3270cc9 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -30,6 +30,43 @@ from ..manager import ExportContext +# workaround for fp8 not having many operators implemented +class PatchFp8Ops(): + + def __init__(self): + self.lib = None + + def __enter__(self): + if torch_version >= version.parse('2.1.0'): + self.lib = torch.library.Library("aten", "IMPL") + + def equal_cpu(self, other): + if (isinstance(self, Tensor) and + self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)) or ( + isinstance(other, Tensor) and + other.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)): + self = self.to(torch.float32) + other = other.to(torch.float32) + return torch.equal(self, other) + else: + res = True + if not isinstance(self, Tensor): + self = torch.tensor(self) + if not isinstance(other, Tensor): + other = torch.tensor(other) + if self.dim() > 0: + for x, y in zip(self.flatten(), other.flatten()): + res &= x == y + else: + res = self.item() == other.item() + return torch.tensor([res]) + + self.lib.impl("equal", equal_cpu, "CPU") + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.lib = None + + class ONNXBaseManager(BaseManager, ABC): model_transforms = [] @@ -127,7 +164,9 @@ def export_onnx( else: model_bytes = BytesIO() export_target = model_bytes - torch.onnx.export(module, args, export_target, **onnx_export_kwargs) + + with PatchFp8Ops(): + torch.onnx.export(module, args, export_target, **onnx_export_kwargs) # restore the model to previous properties module.apply(lambda m: _restore_act_caching_mode(m)) diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index a8f3c507b..9f4184071 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import ABC +from warnings import warn import torch @@ -10,6 +11,9 @@ from brevitas.export.common.handler.qcdq import DQCastMixin from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DynamicQMixin +from brevitas.export.common.handler.qcdq import FloatQCDQCastActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import FloatQCDQCastWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import FloatQMixin from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import \ @@ -47,12 +51,47 @@ def validate(self, module): assert module.bit_width() > 1., 'Binary quant not supported' +class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC): + + def is_ocp(self, module): + is_e4m3 = 
module.mantissa_bit_width() == 3 and module.exponent_bit_width() == 4
+
+        is_ocp_e4m3 = is_e4m3 and module.inf_values() is None and module.nan_values() == (('111',))
+
+        is_e5m2 = module.mantissa_bit_width() == 2 and module.exponent_bit_width() == 5
+
+        is_ocp_e5m2 = is_e5m2 and module.inf_values() == (
+            ('00',)) and module.nan_values() == ('01', '11', '10')
+
+        return is_ocp_e4m3 or is_ocp_e5m2
+
+    def validate(self, module):
+        assert self.is_ocp(module), 'Only OCP Standard is supported for FP8 export'
+
+
+class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC):
+
+    def clip_fn(self, x, min_val, max_val):
+        raise NotImplementedError
+
+
 class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC):
 
     def clip_fn(self, x, min_val, max_val):
         return IntClipFn.apply(x, min_val, max_val)
 
 
+class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):
+
+    def validate(self, module):
+        if getattr(self, '_export_q_node', True):
+            assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
+        super().validate(module)
+
+    def quantize_fn(self, x, scale, zero_point, dtype, axis):
+        return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
+
+
 class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):
 
     @classmethod
@@ -112,6 +151,12 @@ def quantize_fn(self, x, dtype):
         return DynamicQuantizeLinearFn.apply(x, dtype)
 
 
+class StdFloatQCDQCastONNXWeightQuantProxyHandler(StdFloatQCDQCastONNXMixin,
+                                                  FloatQCDQCastWeightQuantProxyHandlerMixin,
+                                                  ONNXBaseHandler):
+    _export_q_node = False
+
+
 class StdQCDQCastONNXWeightQuantProxyHandler(StdQCDQCastONNXMixin,
                                              QCDQCastWeightQuantProxyHandlerMixin,
                                              ONNXBaseHandler):
@@ -130,6 +175,12 @@ class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
     _export_q_node = False
 
 
+class StdFloatQCDQCastONNXActQuantProxyHandler(StdFloatQCDQCastONNXMixin,
+                                               FloatQCDQCastActQuantProxyHandlerMixin,
+                                               ONNXBaseHandler):
+    pass
+
+
 class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
                                           QCDQCastActQuantProxyHandlerMixin,
                                           ONNXBaseHandler):
diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py
index b1b05f4ad..e43d97e6d 100644
--- a/src/brevitas/export/onnx/standard/qcdq/manager.py
+++ b/src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -17,6 +17,8 @@ from ..manager import StdONNXBaseManager
 from .handler import StdCDQCastONNXBiasQuantProxyHandler
 from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
+from .handler import StdFloatQCDQCastONNXActQuantProxyHandler
+from .handler import StdFloatQCDQCastONNXWeightQuantProxyHandler
 from .handler import StdQCDQCastONNXActQuantProxyHandler
 from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
 from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
@@ -36,8 +38,10 @@ class StdQCDQONNXManager(StdONNXBaseManager):
     handlers = [
         StdQCDQCastONNXWeightQuantProxyHandler,
+        StdFloatQCDQCastONNXWeightQuantProxyHandler,
         StdCDQCastONNXBiasQuantProxyHandler,
         StdQCDQCastONNXActQuantProxyHandler,
+        StdFloatQCDQCastONNXActQuantProxyHandler,
         StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
         StdDynamicQDQCastONNXActQuantProxyHandler,
         StdQCDQCastONNXTruncQuantProxyHandler,
diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py
index 10717774c..6751ab69c 100644
--- a/src/brevitas/function/ops.py
+++ b/src/brevitas/function/ops.py
@@ -5,13 +5,11 @@
 Implementation of various core operations often performed as part of quantization.
The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler. """ -from typing import List, Optional, Tuple import torch from torch import Tensor import brevitas -from brevitas.utils.float_quant_utils import get_minifloat_value @brevitas.jit.script diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py index 57770749d..ebdc6403c 100644 --- a/src/brevitas/proxy/__init__.py +++ b/src/brevitas/proxy/__init__.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from .float_parameter_quant import WeightFloatQuantProxyFromInjector +from .float_runtime_quant import ActFloatQuantProxyFromInjector from .parameter_quant import BiasQuantProxyFromInjector from .parameter_quant import BiasQuantProxyFromInjectorBase from .parameter_quant import DecoupledWeightQuantProxyFromInjector diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 5fc1f2411..4151bc555 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -14,43 +14,28 @@ class ActFloatQuantProxyFromInjector(ActQuantProxyFromInjectorBase): def scale(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.scale - elif self._cached_act is not None: - return self._cached_act.scale - elif self._cached_act is None: - return None + return self.retrieve_attribute('scale', force_eval) def zero_point(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.zero_point - elif self._cached_act is not None: - return self._cached_act.zero_point - elif self._cached_act is None: - return None + return self.retrieve_attribute('zero_point', force_eval) - def bit_width(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.bit_width - elif self._cached_act is not None: - return self._cached_act.bit_width - elif self._cached_act is None: - return None + def exponent_bit_width(self, force_eval=True): + return self.retrieve_attribute('exponent_bit_width', force_eval) + + def mantissa_bit_width(self, force_eval=True): + return self.retrieve_attribute('mantissa_bit_width', force_eval) + + def exponent_bias(self, force_eval=True): + return self.retrieve_attribute('exponent_bias', force_eval) + + def saturating(self, force_eval=True): + return self.retrieve_attribute('saturating', force_eval) + + def inf_values(self, force_eval=True): + return self.retrieve_attribute('inf_values', force_eval) + + def nan_values(self, force_eval=True): + return self.retrieve_attribute('nan_values', force_eval) def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]: out = x @@ -68,7 +53,8 @@ def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuan y = self.fused_activation_quant_proxy(y) # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): + # We exclude the last two values (inf_values and nan_values) + if isinstance(y, 
tuple) and not any(map(lambda f: f is None, y[:-2])): out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 2457298b1..4ef93cad6 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -95,6 +95,23 @@ def __init__(self, quant_layer, quant_injector): self.cache_inference_quant_act = False self.cache_quant_io_metadata_only = True + def internal_forward(self, force_eval): + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out + + def retrieve_attribute(self, attribute, force_eval): + if self.is_quant_enabled: + out = self.internal_forward(force_eval) + return getattr(out, attribute) + elif self._cached_act is not None: + return getattr(self._cached_act, attribute) + elif self._cached_act is None: + return None + @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant @@ -132,43 +149,13 @@ def init_tensor_quant(self): class ActQuantProxyFromInjector(ActQuantProxyFromInjectorBase): def scale(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.scale - elif self._cached_act is not None: - return self._cached_act.scale - elif self._cached_act is None: - return None + return self.retrieve_attribute('scale', force_eval) def zero_point(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.zero_point - elif self._cached_act is not None: - return self._cached_act.zero_point - elif self._cached_act is None: - return None + return self.retrieve_attribute('zero_point', force_eval) def bit_width(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.bit_width - elif self._cached_act is not None: - return self._cached_act.bit_width - elif self._cached_act is None: - return None + return self.retrieve_attribute('bit_width', force_eval) def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]: out = x diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 46e46e41e..1b7191037 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -7,10 +7,8 @@ from brevitas.core.scaling.float_scaling import FloatScaling from brevitas.inject import ExtendedInjector from brevitas.inject import value -from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector -from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector -from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector -from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.proxy import ActFloatQuantProxyFromInjector +from brevitas.proxy import WeightFloatQuantProxyFromInjector from brevitas.quant.solver import ActQuantSolver from brevitas.quant.solver import WeightQuantSolver from 
brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index c2bb99900..b06466d2d 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -3,10 +3,10 @@ import torch -from brevitas.function.ops_ste import round_ste from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import FloatQuantTensorBase from brevitas.quant_tensor import QuantTensor +from brevitas.utils.torch_utils import float_internal_scale from .float_torch_handler import FLOAT_QUANT_TENSOR_FN_HANDLER from .torch_handler import QUANT_TENSOR_FN_HANDLER @@ -94,13 +94,13 @@ def tensor(self): def _pre_round_float_value(self): value = self.value scale = self.scale - zero_point = self.zero_point if self.scale.dtype == torch.bfloat16: value = self.value.type(torch.float32) scale = self.scale.type(torch.float32) - zero_point = self.zero_point.type(torch.float32) minifloat_value = value / scale - minifloat_value = minifloat_value + zero_point + fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width + int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + minifloat_value = minifloat_value / int_scale return minifloat_value @property @@ -130,10 +130,13 @@ def device(self): return value_device def minifloat(self, float_datatype=True): + # TODO: Check if OCP and cast to proper data-type if matching assert float_datatype, "Minifloat quant returns only higher precision dtype" if self.is_valid: - float_value = self._pre_round_float_value + fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width + int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + float_value = torch.round(self._pre_round_float_value) * int_scale return float_value.type(self.scale.dtype) else: raise RuntimeError(f"FloatQuantTensor not valid.") diff --git a/src/brevitas/quant_tensor/float_torch_handler.py b/src/brevitas/quant_tensor/float_torch_handler.py index 05386733a..7fb4507c1 100644 --- a/src/brevitas/quant_tensor/float_torch_handler.py +++ b/src/brevitas/quant_tensor/float_torch_handler.py @@ -1,14 +1,8 @@ import functools -import math -import warnings import torch import torch.nn.functional as F -from brevitas.function.ops import max_int -from brevitas.function.ops_ste import ceil_ste -from brevitas.utils.torch_utils import compute_channel_view_shape - FLOAT_QUANT_TENSOR_FN_HANDLER = {} diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 9392c001d..f7dbe9ef3 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -7,6 +7,9 @@ import torch from torch.nn import Sequential +import brevitas +from brevitas.function.ops_ste import floor_ste + class TupleSequential(Sequential): @@ -86,3 +89,14 @@ def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int): broadcast_shape = [1] * len(tensor.size()) broadcast_shape[channel_dim] = -1 return tuple(broadcast_shape) + + +@brevitas.jit.script +def float_internal_scale( + x: torch.Tensor, mantissa_bit_width: torch.Tensor, + fp_internal_scale_min: torch.Tensor) -> torch.Tensor: + + internal_scale = floor_ste(torch.log2(torch.abs(x))) - mantissa_bit_width + internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) + internal_scale = torch.exp2(internal_scale) + return internal_scale diff --git 
a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 43c090e6c..021365239 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -13,6 +13,7 @@ from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling from brevitas.function.ops import max_float +from brevitas.utils.torch_utils import float_internal_scale from tests.brevitas.hyp_helper import float_st from tests.brevitas.hyp_helper import float_tensor_random_shape_st from tests.brevitas.hyp_helper import random_minifloat_format @@ -192,7 +193,8 @@ def test_inner_scale(inp, minifloat_format, scale): max_value = max_val if max_available_float is None else torch.min( max_value, max_available_float) # call internal scale - internal_scale = float_quant.internal_scale(scaled_inp) + internal_scale = float_internal_scale( + scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min()) val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale) if signed: val_fp_quant = torch.clip(val_fp_quant, -1. * max_val, max_val) diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py new file mode 100644 index 000000000..b7e017484 --- /dev/null +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -0,0 +1,44 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from packaging import version +import pytest +import torch + +from brevitas import torch_version +from brevitas.export import export_onnx_qcdq +import brevitas.nn as qnn +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat +from tests.marker import jit_disabled_for_export + + +@jit_disabled_for_export() +def test_simple_fp8_export(): + if torch_version < version.parse('2.1.0'): + pytest.skip(f"OCP FP8 types not supported by {torch_version}") + + model = qnn.QuantLinear(3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat) + export_onnx_qcdq(model, torch.randn(1, 3), 'weight_fp8.onnx', export_weight_q_node=True) + assert True + + +@jit_disabled_for_export() +def test_fp8_export_activation(): + if torch_version < version.parse('2.1.0'): + pytest.skip(f"OCP FP8 types not supported by {torch_version}") + + model = qnn.QuantLinear(3, 16, input_quant=Fp8e4m3OCPActPerTensorFloat) + export_onnx_qcdq(model, torch.randn(1, 3), 'act_fp8.onnx', export_weight_q_node=True) + assert True + + +@jit_disabled_for_export() +def test_fp8_export_export_activation(): + if torch_version < version.parse('2.1.0'): + pytest.skip(f"OCP FP8 types not supported by {torch_version}") + + model = qnn.QuantLinear( + 3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat) + export_onnx_qcdq(model, torch.randn(1, 3), 'weight_act_fp8.onnx', export_weight_q_node=True) + assert True diff --git a/tests/brevitas_ort/__init__.py b/tests/brevitas_ort/__init__.py index 78315f9c7..b10a7efee 100644 --- a/tests/brevitas_ort/__init__.py +++ b/tests/brevitas_ort/__init__.py @@ -1,11 +1,2 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
 # SPDX-License-Identifier: BSD-3-Clause
-
-try:
-    import torch
-
-    # Avoid fast algorithms that might introduce extra error during fake-quantization
-    torch.use_deterministic_algorithms(True)
-except:
-    # Introduced in 1.8.0
-    pass
diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py
index 4d1e16679..ceaf789f4 100644
--- a/tests/brevitas_ort/common.py
+++ b/tests/brevitas_ort/common.py
@@ -18,6 +18,8 @@
 from brevitas.nn import QuantConvTranspose2d
 from brevitas.nn import QuantConvTranspose3d
 from brevitas.nn import QuantLinear
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
 from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
 from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
@@ -59,7 +61,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant):
     'symmetric_per_channel_fixed_point':
     (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint),
     'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float':
-    (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat)}
+    (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat),
+    'fp8_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat)}
 
 LSTM_QUANTIZERS = {
     'asymmetric_per_tensor_float':
     (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat),
@@ -120,7 +123,14 @@ def recursive_allclose(ort_output, brevitas_output, tolerance):
 
 
 def is_brevitas_ort_close(
-        model, np_input, export_name, export_type, tolerance=None, first_output_only=False):
+        model,
+        np_input,
+        export_name,
+        export_type,
+        tolerance=None,
+        first_output_only=False,
+        onnx_opset=14,
+        export_q_weight=False):
     input_t = torch.from_numpy(np_input)
     with torch.no_grad():
         brevitas_output = model(input_t)
@@ -143,9 +153,12 @@ def is_brevitas_ort_close(
             ort_output = odict[exported_model.graph.output[0].name]
     else:
         if export_type == 'qcdq':
-            export_onnx_qcdq(model, input_t, export_path=export_name)
-        elif export_type == 'qcdq_opset14':
-            export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
+            export_onnx_qcdq(
+                model,
+                input_t,
+                export_path=export_name,
+                export_weight_q_node=export_q_weight,
+                opset_version=onnx_opset)
         elif export_type == 'qonnx_opset14':
             export_qonnx(model, input_t, opset_version=14, export_path=export_name)
         else:
diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py
index 7361c1231..9bad4e89c 100644
--- a/tests/brevitas_ort/quant_module_cases.py
+++ b/tests/brevitas_ort/quant_module_cases.py
@@ -27,12 +27,23 @@ def case_quant_wbiol(
     set_case_id(request.node.callspec.id, QuantWBIOLCases.case_quant_wbiol)
 
     weight_quant, io_quant = quantizers
+    if weight_quant == Fp8e4m3OCPWeightPerTensorFloat:
+        if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8:
+            pytest.skip('FP8 export requires total bitwidth equal to 8')
+        torch.use_deterministic_algorithms(False)
+    else:
+        torch.use_deterministic_algorithms(True)
+
     if impl is QuantLinear:
         layer_kwargs = {'in_features': IN_CH, 'out_features': OUT_CH}
     else:
         layer_kwargs = {
            'in_channels': IN_CH, 'out_channels': OUT_CH, 'kernel_size': KERNEL_SIZE}
 
+    bias_quantizer = None if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else Int32Bias
+    # Required because of numpy error with FP8 data type. Export itself works fine.
+ return_quant_tensor = False if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else True + class Model(nn.Module): def __init__(self): @@ -46,8 +57,8 @@ def __init__(self): weight_bit_width=weight_bit_width, input_bit_width=input_bit_width, output_bit_width=output_bit_width, - bias_quant=Int32Bias, - return_quant_tensor=True) + bias_quant=bias_quantizer, + return_quant_tensor=return_quant_tensor) self.conv.weight.data.uniform_(-0.01, 0.01) def forward(self, x): diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 2b7b6b1cf..693b9274d 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -3,13 +3,14 @@ from functools import reduce from operator import mul -import os +from packaging.version import parse import pytest from pytest_cases import get_case_id from pytest_cases import parametrize_with_cases import torch +from brevitas import torch_version from tests.marker import requires_pt_ge from .common import * @@ -20,7 +21,7 @@ @parametrize_with_cases('model', cases=QuantWBIOLCases) @pytest.mark.parametrize('export_type', ['qcdq', 'qonnx']) -@requires_pt_ge('1.8.1') +@requires_pt_ge('1.10') def test_ort_wbiol(model, export_type, current_cases): cases_generator_func = current_cases['model'][1] case_id = get_case_id(cases_generator_func) @@ -29,7 +30,8 @@ def test_ort_wbiol(model, export_type, current_cases): quantizer = case_id.split('-')[-6] o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - + onnx_opset = 14 + export_q_weight = False if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') if 'QuantLinear' in impl and 'asymmetric' in quantizer: @@ -37,6 +39,13 @@ def test_ort_wbiol(model, export_type, current_cases): if 'dynamic' in quantizer and ((o_bit_width != "o8" or i_bit_width != "i8") or export_type != "qcdq"): pytest.skip('Dynamic Act Quant supported only for 8bit and QCDQ export') + if export_type == 'qonnx' and 'fp8' in quantizer: + pytest.skip('FP8 export requires QCDQ') + if torch_version < parse('2.1') and 'fp8' in quantizer: + pytest.skip('FP8 requires PyTorch 2.1 or higher') + elif torch_version >= parse('2.1') and 'fp8' in quantizer: + onnx_opset = 19 + export_q_weight = True if impl in ('QuantLinear'): in_size = (1, IN_CH) @@ -55,11 +64,18 @@ def test_ort_wbiol(model, export_type, current_cases): model.eval() export_name = f'qcdq_qop_export_{case_id}.onnx' assert is_brevitas_ort_close( - model, inp, export_name, export_type, tolerance=INT_TOLERANCE, first_output_only=True) + model, + inp, + export_name, + export_type, + tolerance=INT_TOLERANCE, + first_output_only=True, + onnx_opset=onnx_opset, + export_q_weight=export_q_weight) @parametrize_with_cases('model', cases=QuantAvgPoolCases) -@requires_pt_ge('1.8.1') +@requires_pt_ge('1.10') def test_ort_avgpool(model, current_cases): in_size = (1, IN_CH, FEATURES, FEATURES) inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size) @@ -71,7 +87,7 @@ def test_ort_avgpool(model, current_cases): @parametrize_with_cases('model', cases=QuantRecurrentCases) -@pytest.mark.parametrize('export_type', ['qcdq_opset14', 'qonnx_opset14']) +@pytest.mark.parametrize('export_type', ['qcdq', 'qonnx_opset14']) @requires_pt_ge('1.10') def test_ort_lstm(model, export_type, current_cases): cases_generator_func = current_cases['model'][1]
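
For reference, the FP8 export path added by this patch is exercised roughly as in the sketch below. It is adapted from the new tests/brevitas/export/test_onnx_fp8.py and the fp8 branch added to test_quant_module.py (PyTorch >= 2.1 is assumed for the torch.float8 dtypes; opset 19 and export_weight_q_node=True mirror the updated ORT test and are not the only valid settings):

import torch

import brevitas.nn as qnn
from brevitas.export import export_onnx_qcdq
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat

# QuantLinear with OCP FP8 e4m3 weight and input quantizers; these are picked up by the
# new StdFloatQCDQCastONNX*QuantProxyHandler classes registered in StdQCDQONNXManager.
model = qnn.QuantLinear(
    3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat)
model.eval()

# export_weight_q_node=True also emits a QuantizeLinear node for the weights.
# The PatchFp8Ops context manager added to the ONNX manager is applied internally
# to work around missing fp8 operator implementations (e.g. equal_cpu) during tracing.
export_onnx_qcdq(
    model,
    torch.randn(1, 3),
    export_path='weight_act_fp8.onnx',
    export_weight_q_node=True,
    opset_version=19)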