diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py
index d90a4b976..ec0923fd1 100644
--- a/src/brevitas/export/common/handler/base.py
+++ b/src/brevitas/export/common/handler/base.py
@@ -39,12 +39,6 @@ def quant_axis(cls, scale):
                 return i
         return None
 
-class FloatClipMixin(ABC):
-
-    @classmethod
-    def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width):
-        warn("Not implemented for floating point")
-        return None
 
 class ClipMixin(ABC):
 
     @classmethod
diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index 5bbc1e28e..572d2c241 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -21,7 +21,6 @@
 from .base import BitWidthHandlerMixin
 from .base import ClipMixin
-from .base import FloatClipMixin
 from .base import QuantAxisMixin
 from .base import ZeroPointHandlerMixin
 from .base import FloatZeroPointHandlerMixin
@@ -129,7 +128,7 @@ def quantize_fn(self, x, dtype):
         pass
 
 
-class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatClipMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC):
+class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC):
 
     def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed):
         scale_orig_shape = scale.shape
@@ -221,22 +220,55 @@ def prepare_for_export(self, module):
             if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
                 scale = self.cast_fn(scale, torch.float32)
-            self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
-            self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs(
-                module.is_narrow_range, module.is_signed, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width)
+            self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width
+            self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
                 scale, quant_weight.zero_point, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width, module.is_signed)
         else:
             self.symbolic_kwargs = None
 
-    def quantize_from_floating_point(self, x: Tensor):
-        raise NotImplementedError()
+    def quantize_from_floating_point(self, x: Tensor, zp):
+        quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
+        quantize_symbolic_kwargs['zero_point'] = zp
+        # Before quantization, cast input to float32
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, torch.float32)
+        x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
+        return x
 
     def quantize_from_integer(self, x: Tensor):
-        raise NotImplementedError()
+        # Return the pre-computed integer representation of the weights,
+        # prepared in prepare_for_export and keyed by the weight tensor's data pointer
+        int_weights = self.symbolic_kwargs['int_weights']
+        return int_weights[x.data_ptr()]
 
     def symbolic_execution(self, x: Tensor):
-        raise NotImplementedError()
+        assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
+
+        # Copy dict to allow for popping kwargs even on shared quantizers
+        dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
+        scale = dequantize_symbolic_kwargs['scale']
+        zero_point = dequantize_symbolic_kwargs['zero_point']
+
+        if self._export_q_node:
+            x = self.quantize_from_floating_point(x, zero_point)
+        else:
+            x = self.quantize_from_integer(x)
+
+        exponent_bit_width = self.symbolic_kwargs['exponent_bit_width']
+        mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width']
+        scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
+        # Workaround to trick the tracer into believing all return values are used
+        self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width)
+        x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        # After dequantization, cast both input and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
+        # Restore the original shapes to guarantee correct shape propagation downstream
+        scale = scale.view(scale_orig_shape)
+        zero_point = zero_point.view_as(scale)
+        return x, scale, zero_point, exponent_bit_width, mantissa_bit_width
 
 
 class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin):
     handled_layer = WeightQuantProxyFromInjector
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index da13526d1..9e5b62fd5 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -30,25 +30,6 @@
 from warnings import warn
-
-class StdFloatDQCastONNXMixin(DQCastMixin, ABC):
-
-    def dequantize_fn(self, x, scale, zero_point, axis):
-        raise NotImplementedError()
-
-    def cast_fn(self, x, dtype):
-        raise NotImplementedError()
-
-    @property
-    def flatten_dequantize_params(self):
-        raise NotImplementedError()
-
-    @property
-    def itemize_quantize_scalar_params(self):
-        return False
-
-    def validate(self, module):
-        raise NotImplementedError()
 
 
 class StdDQCastONNXMixin(DQCastMixin, ABC):
@@ -69,6 +50,9 @@ def itemize_quantize_scalar_params(self):
         return False
 
     def validate(self, module):
         assert module.bit_width() > 1., 'Binary quant not supported'
 
+class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC):
+
+    def validate(self, module):
+        pass
 
 class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC):
@@ -82,11 +66,10 @@ def clip_fn(self, x, min_val, max_val):
 
 
 class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):
 
     def validate(self, module):
-        warn("Needs to be implemented")
         pass
 
     def quantize_fn(self, x, scale, zero_point, dtype, axis):
-        raise NotImplementedError()
+        return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
 
 
 class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):
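
For context, below is a minimal usage sketch of the export path these handlers serve. It is not part of the diff: the quantizer name `Fp8e4m3WeightPerTensorFloat` and the exact `export_onnx_qcdq` arguments are assumptions based on the Brevitas API and may differ across Brevitas/ONNX versions.

```python
# Sketch only: exercises FloatCDQCastProxyHandlerMixin / StdFloatQCDQCastONNXMixin above.
# Fp8e4m3WeightPerTensorFloat is an assumed quantizer name; adjust to your Brevitas version.
import torch

from brevitas.export import export_onnx_qcdq
from brevitas.nn import QuantLinear
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat

# Linear layer whose weights are quantized to an e4m3 minifloat format
layer = QuantLinear(16, 8, bias=True, weight_quant=Fp8e4m3WeightPerTensorFloat)
layer.eval()

# Weights are exported as a QuantizeLinear -> DequantizeLinear pair (no Clip node,
# since clipping is not defined for minifloat), with the dtype casts and shape
# restoration handled in symbolic_execution above.
export_onnx_qcdq(layer, args=torch.randn(1, 16), export_path='fp8_weight_qcdq.onnx')
```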