From 13652969e2d51a35e250ab65e63e435dac5e88db Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 27 Feb 2024 15:17:56 +0000 Subject: [PATCH] Feat: functionalize QuantTensor --- src/brevitas/graph/calibrate.py | 1 - src/brevitas/nn/mixin/base.py | 4 +- src/brevitas/nn/quant_avg_pool.py | 44 +- src/brevitas/nn/quant_conv.py | 8 - src/brevitas/nn/quant_convtranspose.py | 27 +- src/brevitas/nn/quant_layer.py | 65 +-- src/brevitas/nn/quant_linear.py | 7 - src/brevitas/nn/quant_scale_bias.py | 23 +- src/brevitas/nn/utils.py | 12 +- src/brevitas/quant_tensor/__init__.py | 432 +----------------- .../quant_tensor/base_quant_tensor.py | 25 + src/brevitas/quant_tensor/int_quant_tensor.py | 420 +++++++++++++++++ src/brevitas/quant_tensor/torch_handler.py | 243 ++++++++++ src/brevitas/utils/torch_utils.py | 6 + tests/brevitas/nn/test_nn_quantizers.py | 9 +- 15 files changed, 763 insertions(+), 563 deletions(-) create mode 100644 src/brevitas/quant_tensor/base_quant_tensor.py create mode 100644 src/brevitas/quant_tensor/int_quant_tensor.py diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 8f690fc9b..2a93f1226 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -13,7 +13,6 @@ from brevitas.nn import QuantHardTanh from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL -from brevitas.nn.utils import compute_channel_view_shape from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 8327bc156..8810e2af2 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -17,9 +17,9 @@ from brevitas.common import ExportMixin from brevitas.inject import ExtendedInjector from brevitas.inject import Injector -from brevitas.nn.utils import compute_channel_view_shape from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor +from brevitas.utils.torch_utils import compute_channel_view_shape from .utils import filter_kwargs @@ -86,7 +86,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: self._set_global_is_quant_layer(False) if self.return_quant_tensor: - assert isinstance(quant_output, QuantTensor) + assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised' return quant_output else: return _unpack_quant_tensor(quant_output) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 554504908..1cc2cc5c2 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -47,38 +47,19 @@ def channelwise_separable(self) -> bool: def requires_export_handler(self): return True - @property - def _avg_scaling(self): - if isinstance(self.kernel_size, tuple): - return self.kernel_size[0] * self.kernel_size[1] - else: - return self.kernel_size * self.kernel_size - def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) if self.export_mode: return self.export_handler(_unpack_quant_tensor(x)) - if isinstance(x, QuantTensor): - x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) - if self.is_trunc_quant_enabled: - # remove avg scaling - rescaled_value = x.value * self._avg_scaling - x = x.set(value=rescaled_value) - x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) - x = self.trunc_quant(x) + if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled: + y = AvgPool2d.forward(self, x) + y = self.trunc_quant(y) else: - assert not self.is_trunc_quant_enabled - x = super(TruncAvgPool2d, self).forward(x) - - return self.pack_output(x) + y = AvgPool2d.forward(self, _unpack_quant_tensor(x)) - def max_acc_bit_width(self, input_bit_width): - max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_uint_output = max_uint_input * self._avg_scaling - max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width + return self.pack_output(y) class TruncAdaptiveAvgPool2d(TruncMixin, QuantLayerMixin, AdaptiveAvgPool2d): @@ -130,18 +111,11 @@ def forward(self, input: Union[Tensor, QuantTensor]): self._set_global_is_quant_layer(False) return out - if isinstance(x, QuantTensor): - y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) - k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) - if self.is_trunc_quant_enabled: - reduce_size = reduce(mul, k_size, 1) - rescaled_value = y.value * reduce_size # remove avg scaling - y = y.set(value=rescaled_value) - y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size)) - y = self.trunc_quant(y) + if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled: + y = AdaptiveAvgPool2d.forward(self, x) + y = self.trunc_quant(y) else: - assert not self.is_trunc_quant_enabled - y = super(TruncAdaptiveAvgPool2d, self).forward(x) + y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x)) return self.pack_output(y) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 74912af67..7e920e158 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -109,14 +109,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option else: return self._conv_forward(x, quant_weight, quant_bias) - def max_acc_bit_width(self, input_bit_width, weight_bit_width): - max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) - group_size = self.in_channels // self.groups - max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size - max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width - class QuantConv2d(QuantWBIOL, Conv2d): diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index 75dd90378..3a5cdfe5a 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -103,7 +103,14 @@ def compute_output_padding(self, inp, output_size): def conv_transpose1d_zeros_pad( self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding): out = conv_transpose1d( - x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + output_padding=output_padding, + groups=self.groups, + dilation=self.dilation) return out def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): @@ -200,7 +207,14 @@ def compute_output_padding(self, inp, output_size): def conv_transpose2d_zeros_pad( self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding): out = conv_transpose2d( - x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + output_padding=output_padding, + groups=self.groups, + dilation=self.dilation) return out def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): @@ -298,7 +312,14 @@ def compute_output_padding(self, inp, output_size): def conv_transpose3d_zeros_pad( self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding): out = conv_transpose3d( - x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + output_padding=output_padding, + groups=self.groups, + dilation=self.dilation) return out def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 8cd2e10a7..0a82afb9b 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -9,11 +9,10 @@ from torch import Tensor from torch.nn import Module -from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor +from brevitas.utils.torch_utils import compute_channel_view_shape from .mixin import * -from .utils import compute_channel_view_shape from .utils import merge_bn from .utils import rename_state_dict_by_prefix @@ -47,7 +46,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): quant_input = self.input_quant(input) # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(_unpack_quant_tensor(quant_input)) + out = self.export_handler(quant_input) self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) @@ -121,7 +120,8 @@ def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Ten def quant_output_scale_impl( self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor): - output_scale_shape = compute_channel_view_shape(inp, channel_dim=1) + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) output_scale = quant_weight_scale.view(output_scale_shape) output_scale = output_scale * quant_input_scale.view(output_scale_shape) return output_scale @@ -140,16 +140,12 @@ def merge_bn_in(self, bn): merge_bn(self, bn, output_channel_dim=self.output_channel_dim) def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - output_scale = None - output_bit_width = None - output_zero_point = None - output_signed = None inp = self.unpack_input(inp) # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(_unpack_quant_tensor(inp)) + out = self.export_handler(inp) self._set_global_is_quant_layer(False) return out @@ -163,58 +159,15 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe self.output_quant.is_quant_enabled) and self.return_quant_tensor: raise RuntimeError("QuantLayer is not correctly configured") + output_scale = None if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): - output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width) output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale) - output_signed = quant_input.signed or quant_weight.signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale) - - output_tensor = self.inner_forward_impl( - _unpack_quant_tensor(quant_input), - _unpack_quant_tensor(quant_weight), - _unpack_quant_tensor(quant_bias)) - - if output_scale is not None: - if (isinstance(quant_bias, QuantTensor) and - quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance( - quant_bias, QuantTensor): - channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 - output_scale_broadcast_shape = compute_channel_view_shape( - inp, channel_dim=channel_dim) - output_zero_point = -_unpack_quant_tensor(quant_bias).view( - output_scale_broadcast_shape) / output_scale - - if output_bit_width is not None and isinstance(quant_bias, QuantTensor): - output_bit_width = torch.where( - quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width) - output_bit_width = output_bit_width + 1 - else: - output_tensor = self.inner_forward_impl( - _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) - - if not self.output_quant.is_quant_enabled and self.return_quant_tensor: - if compute_output_quant_tensor: - if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): - raise RuntimeError( - "Computing zero point of output accumulator not supported yet.") - elif output_zero_point is None: - output_zero_point = quant_input.zero_point - - elif output_zero_point is None: - output_zero_point = torch.zeros(1).type_as(output_tensor) - - if compute_output_quant_tensor: - quant_output = QuantTensor( - output_tensor, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=self.training) else: - quant_output = output_tensor + quant_bias = None + output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias) - quant_output = self.output_quant(quant_output) + quant_output = self.output_quant(output_tensor) return self.pack_output(quant_output) diff --git a/src/brevitas/nn/quant_linear.py b/src/brevitas/nn/quant_linear.py index 46f3191b9..428343804 100644 --- a/src/brevitas/nn/quant_linear.py +++ b/src/brevitas/nn/quant_linear.py @@ -79,10 +79,3 @@ def quant_output_scale_impl( quant_weight_scale = quant_weight_scale.view(weight_broadcast_shape) quant_output_scale = linear(quant_input_scale, quant_weight_scale) return quant_output_scale - - def max_acc_bit_width(self, input_bit_width, weight_bit_width): - max_input_val = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_fc_val = self.weight_quant.max_uint_value(weight_bit_width) - max_output_val = max_input_val * max_fc_val * self.in_features - output_bit_width = ceil_ste(torch.log2(max_output_val)) - return output_bit_width diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py index a97f54ed5..4b10aaa63 100644 --- a/src/brevitas/nn/quant_scale_bias.py +++ b/src/brevitas/nn/quant_scale_bias.py @@ -78,15 +78,20 @@ def channelwise_separable(self) -> bool: def forward(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: return self.forward_impl(inp) - def inner_forward_impl(self, input: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): + def inner_forward_impl( + self, + input: Union[Tensor, QuantTensor], + quant_weight: Union[Tensor, QuantTensor], + quant_bias: Optional[Union[Tensor, QuantTensor]]): quant_weight = quant_weight.view(self.runtime_shape) quant_bias = quant_bias.view(self.runtime_shape) - output_tensor = input * quant_weight + quant_bias - return output_tensor - def max_acc_bit_width(self, input_bit_width, weight_bit_width): - max_input_val = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_weight_val = self.weight_quant.max_uint_value(weight_bit_width) - max_output_val = max_input_val * max_weight_val - output_bit_width = ceil_ste(torch.log2(max_output_val)) - return output_bit_width + def biased_mul(input, weight, bias): + out = input * weight + if bias is not None: + out += bias + return out + + output_tensor = biased_mul(input, quant_weight, quant_bias) + + return output_tensor diff --git a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py index ed5e87302..fccfbc4de 100644 --- a/src/brevitas/nn/utils.py +++ b/src/brevitas/nn/utils.py @@ -2,17 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause import torch -from torch import Tensor from torch.nn import Parameter -from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector -from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector - - -def compute_channel_view_shape(tensor: Tensor, channel_dim: int): - broadcast_shape = [1] * len(tensor.size()) - broadcast_shape[channel_dim] = -1 - return tuple(broadcast_shape) +from brevitas.utils.torch_utils import compute_channel_view_shape def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias): @@ -23,6 +15,8 @@ def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias): def merge_bn(layer, bn, output_channel_dim=0): + from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector + from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector out = mul_add_from_bn( bn_mean=bn.running_mean, bn_var=bn.running_var, diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 569ed71e0..7e58bc551 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -1,432 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from abc import ABC -from typing import NamedTuple, Optional -import warnings - -import torch -from torch import Tensor - -import brevitas.config as config -from brevitas.function.ops import max_int -from brevitas.function.ops import min_int -from brevitas.function.ops_ste import ceil_ste -from brevitas.function.ops_ste import round_ste - -from .torch_handler import QUANT_TENSOR_FN_HANDLER - -IS_VALID_ATOL = 2e-1 -BFLOAT16_IS_VALID_ATOL = 0.5 - - -class QuantTensorBase(NamedTuple): - value: Tensor - scale: Tensor - zero_point: Tensor - bit_width: Tensor - signed_t: Tensor - training_t: Tensor - - -def _unpack_quant_tensor(input_data): - if isinstance(input_data, QuantTensor): - return input_data.value - elif isinstance(input_data, tuple): - return tuple([_unpack_quant_tensor(v) for v in input_data]) - elif isinstance(input_data, list): - return [_unpack_quant_tensor(v) for v in input_data] - elif isinstance(input_data, dict): - return {k: _unpack_quant_tensor(v) for k, v in input_data.items()} - else: - return input_data - - -class QuantTensor(QuantTensorBase): - - def __new__(cls, value, scale, zero_point, bit_width, signed, training): - - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale, dtype=torch.float) - if not isinstance(zero_point, torch.Tensor): - zero_point = torch.tensor(zero_point, dtype=torch.float) - if not isinstance(bit_width, torch.Tensor): - bit_width = torch.tensor(bit_width, dtype=torch.float) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) - quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training) - return quant_tensor - - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - - def __torch_function__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if (func not in QUANT_TENSOR_FN_HANDLER or - not all(issubclass(t, QuantTensor) for t in types)): - args = _unpack_quant_tensor(args) - kwargs = _unpack_quant_tensor(kwargs) - return func(*args, **kwargs) - return QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) - - @property - def tensor(self): - return self.value - - @property - def _pre_round_int_value(self): - value = self.value - scale = self.scale - zero_point = self.zero_point - if self.scale.dtype == torch.bfloat16: - value = self.value.type(torch.float32) - scale = self.scale.type(torch.float32) - zero_point = self.zero_point.type(torch.float32) - int_value = value / scale - int_value = int_value + zero_point - return int_value - - @property - def is_valid(self): - with torch.no_grad(): - pre_round_int_value = self._pre_round_int_value - rounded_int_value = torch.round(pre_round_int_value) - max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) - atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL - is_int = max_abs_diff < atol - if self.bit_width >= 2: - if self.signed: - is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() - is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() - else: - is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() - is_lower_b = (0. <= rounded_int_value).all() - return (is_int & is_upper_b & is_lower_b).item() - else: # binary case - unique_vals = rounded_int_value.unique( - sorted=False, return_counts=False, return_inverse=False) - is_binary = unique_vals.view(-1).size()[0] == 2 - is_signed = (unique_vals < 0.).any().item() - sign_match = is_signed == self.signed - return is_int.item() and is_binary and sign_match - - @property - def device(self): - value_device = self.value.device - is_same_device = True - for t in [self.scale, self.zero_point, self.bit_width]: - is_same_device &= value_device == t.device - if not is_same_device: - raise RuntimeError("Value and metadata are on different devices") - return value_device - - def set(self, **kwargs): - return self._replace(**kwargs) - - def detach_(self): - self.value.detach_() - self.scale.detach_() - self.zero_point.detach_() - self.bit_width.detach_() - - def detach(self): - return QuantTensor( - self.value.detach(), - self.scale.detach(), - self.zero_point.detach(), - self.bit_width.detach(), - self.signed, - self.training) - - def contiguous(self): - return QuantTensor( - self.value.contiguous(), - self.scale.contiguous(), - self.zero_point.contiguous(), - self.bit_width.contiguous(), - self.signed, - self.training) - - def int(self, float_datatype=False): - if self.is_valid: - int_value = round_ste(self._pre_round_int_value) - if float_datatype: - # Values at 8bit and lower can be represented exactly with float16 and bfloat16 - # otherwise (e.g. Int16 bias), we upscale to float32 - if self.bit_width <= 8.: - return int_value.type(self.scale.dtype) - else: - return int_value.type(torch.float32) - else: - if self.bit_width <= 8. and self.signed_t.item(): - return int_value.to(torch.int8) - elif self.bit_width <= 8. and not self.signed_t.item(): - return int_value.to(torch.uint8) - else: - return int_value.to(torch.int32) - else: - raise RuntimeError(f"QuantTensor not valid.") - - @staticmethod - def check_input_type(tensor): - if not isinstance(tensor, QuantTensor): - raise RuntimeError("Tensor is not a QuantTensor") - - @staticmethod - def is_zero_zero_point(tensor): - QuantTensor.check_input_type(tensor) - return (tensor.zero_point == 0.).all() - - def check_scaling_factors_same(self, other): - if self.training: - return True - if not torch.allclose(self.scale, other.scale): - raise RuntimeError("Scaling factors are different") - - def check_zero_points_same(self, other): - if self.training: - return True - if not torch.allclose(self.zero_point, other.zero_point): - raise RuntimeError("Zero points are different") - - def check_bit_width_same(self, other): - if not torch.allclose(self.bit_width, other.bit_width): - raise RuntimeError("Bit widths are different") - - def check_sign_same(self, other): - if not self.signed == other.signed: - raise RuntimeError("Signs are different") - - def view(self, *args, **kwargs): - return self.set(value=self.value.view(*args, **kwargs)) - - def reshape(self, *args, **kwargs): - return self.set(value=self.value.reshape(*args, **kwargs)) - - def flatten(self, *args, **kwargs): - return self.set(value=self.value.flatten(*args, **kwargs)) - - def transpose(self, *args, **kwargs): - value = self.value.transpose(*args, **kwargs) - tensor_meta = { - 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} - for k, tm in tensor_meta.items(): - if len(value.shape) == len(tm.shape): - tensor_meta[k] = tm.transpose(*args, **kwargs) - return self.set(value=value, **tensor_meta) - - def permute(self, *args, **kwargs): - value = self.value.permute(*args, **kwargs) - tensor_meta = { - 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} - for k, tm in tensor_meta.items(): - if len(value.shape) == len(tm.shape): - tensor_meta[k] = tm.permute(*args, **kwargs) - return self.set(value=value, **tensor_meta) - - def size(self, *args, **kwargs): - return self.value.size(*args, **kwargs) - - @property - def shape(self): - return self.value.shape - - def dim(self): - return self.value.dim() - - def add(self, other): - return self + other - - @staticmethod - def cat(tensors, dim, out=None): - if out is not None: - raise RuntimeError("Out not supported.") - if len(tensors) < 2: - return tensors[0] - else: - first_qt = tensors[0] - if all([isinstance(qt, QuantTensor) for qt in tensors]): - for qt in tensors[1:]: - first_qt.check_scaling_factors_same(qt) - first_qt.check_zero_points_same(qt) - first_qt.check_bit_width_same(qt) - first_qt.check_sign_same(qt) - output_value = torch.cat([qt.value for qt in tensors], dim=dim) - output_training = any([qt.training for qt in tensors]) - if output_training: - output_scale = sum([qt.scale for qt in tensors]) / len(tensors) - output_zero_point = sum([qt.zero_point for qt in tensors]) / len(tensors) - output_bit_width = sum([qt.bit_width for qt in tensors]) / len(tensors) - else: # at eval time, they are the same - output_scale = first_qt.scale - output_zero_point = first_qt.zero_point - output_bit_width = first_qt.bit_width - output_signed = first_qt.signed # they are the same - return QuantTensor( - value=output_value, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=output_training) - else: - tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors] - output_value = torch.cat(tensors, dim=dim) - return output_value - - # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types - - def __neg__(self): - neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale - # In case the dtype of self.int is different from the one of the scale - neg_value = neg_value.type(self.scale.dtype) - if self.signed: - return QuantTensor( - value=neg_value, - scale=self.scale, - zero_point=self.zero_point, - bit_width=self.bit_width, - signed=self.signed, - training=self.training) - else: - return QuantTensor( - value=neg_value, - scale=self.scale, - zero_point=self.zero_point, - bit_width=self.bit_width + 1, - signed=True, - training=self.training) - - def to(self, *args, **kwargs): - return QuantTensor( - self.value.to(*args, **kwargs), - self.scale.to(*args, **kwargs), - self.zero_point.to(*args, **kwargs), - self.bit_width.to(*args, **kwargs), - self.signed, - self.training) - - def cuda(self, *args, **kwargs): - return QuantTensor( - self.value.cuda(*args, **kwargs), - self.scale.cuda(*args, **kwargs), - self.zero_point.cuda(*args, **kwargs), - self.bit_width.cuda(*args, **kwargs), - self.signed, - self.training) - - def cpu(self, *args, **kwargs): - return QuantTensor( - self.value.cpu(*args, **kwargs), - self.scale.cpu(*args, **kwargs), - self.zero_point.cpu(*args, **kwargs), - self.bit_width.cpu(*args, **kwargs), - self.signed, - self.training) - - def __add__(self, other): - if isinstance(other, QuantTensor): - self.check_scaling_factors_same(other) - output_value = self.value + other.value - output_scale = (self.scale + other.scale) / 2 - output_zero_point = self.zero_point + other.zero_point - max_val = max_int(signed=self.signed, narrow_range=False, bit_width=self.bit_width) - max_val += max_int(signed=other.signed, narrow_range=False, bit_width=other.bit_width) - min_val = min_int(signed=self.signed, narrow_range=False, bit_width=self.bit_width) - min_val += min_int(signed=other.signed, narrow_range=False, bit_width=other.bit_width) - output_bit_width = ceil_ste(torch.log2(max_val - min_val)) - output_signed = self.signed or other.signed - output_training = self.training or other.training - output = QuantTensor( - value=output_value, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=output_training) - else: - output = self.value + other - return output - - def __radd__(self, other): - return self.__add__(other) - - def __rmul__(self, other): - return self.__mul__(other) - - def __mul__(self, other): - if isinstance(other, QuantTensor): - output_value = self.value * other.value - output_scale = self.scale * other.scale - output_bit_width = self.bit_width + other.bit_width - output_signed = self.signed or other.signed - output_training = self.training or other.training - if self.is_zero_zero_point(self) and self.is_zero_zero_point(other): - output_zero_point = self.zero_point * other.zero_point - else: - raise RuntimeError("Zero-points of mul operands are non-zero, not supported.") - output = QuantTensor( - value=output_value, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=output_training) - else: - output = self.value * other - return output - - def __sub__(self, other): - return self.__add__(-other) - - def __str__(self): - return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" - - def __truediv__(self, other): - if isinstance(other, QuantTensor): - output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid() - max_int_denominator = 2 ** (other.bit_width - int(other.signed)) - output_scale = self.scale / (other.scale * max_int_denominator) - output_bit_width = self.bit_width + other.bit_width - output_signed = self.signed or other.signed - output_training = self.training or other.training - if self.is_zero_zero_point(self) and self.is_zero_zero_point(other): - output_zero_point = self.zero_point * other.zero_point # Output zero_point is a new, zero-valued tensor - else: - raise RuntimeError("Zero-points of div operands are non-zero, not supported.") - output = QuantTensor( - value=output_tensor, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=output_training) - else: - output = self.value / other - return output - - def __abs__(self): - if self.signed: - abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale - # In case the dtype of self.int is different from the one of the scale - abs_value = abs_value.type(self.scale.dtype) - return QuantTensor( - value=abs_value, - scale=self.scale, - zero_point=self.zero_point, - bit_width=self.bit_width - 1, - signed=False, - training=self.training) - else: - return self - - def __pos__(self): - return self +from .base_quant_tensor import _unpack_quant_tensor +from .base_quant_tensor import QuantTensorBase +from .int_quant_tensor import QuantTensor diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py new file mode 100644 index 000000000..e8d1439d7 --- /dev/null +++ b/src/brevitas/quant_tensor/base_quant_tensor.py @@ -0,0 +1,25 @@ +from typing import NamedTuple + +from torch import Tensor + + +class QuantTensorBase(NamedTuple): + value: Tensor + scale: Tensor + zero_point: Tensor + bit_width: Tensor + signed_t: Tensor + training_t: Tensor + + +def _unpack_quant_tensor(input_data): + if isinstance(input_data, QuantTensorBase): + return input_data.value + elif isinstance(input_data, tuple): + return tuple([_unpack_quant_tensor(v) for v in input_data]) + elif isinstance(input_data, list): + return [_unpack_quant_tensor(v) for v in input_data] + elif isinstance(input_data, dict): + return {k: _unpack_quant_tensor(v) for k, v in input_data.items()} + else: + return input_data diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py new file mode 100644 index 000000000..6539a327c --- /dev/null +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -0,0 +1,420 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch + +from brevitas.function.ops import max_int +from brevitas.function.ops import min_int +from brevitas.function.ops_ste import ceil_ste +from brevitas.function.ops_ste import round_ste +from brevitas.quant_tensor import _unpack_quant_tensor +from brevitas.quant_tensor import QuantTensorBase + +from .torch_handler import QUANT_TENSOR_FN_HANDLER + +IS_VALID_ATOL = 2e-1 +BFLOAT16_IS_VALID_ATOL = 0.5 + + +class QuantTensor(QuantTensorBase): + + def __new__(cls, value, scale, zero_point, bit_width, signed, training): + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.float) + if not isinstance(bit_width, torch.Tensor): + bit_width = torch.tensor(bit_width, dtype=torch.float) + if not isinstance(signed, torch.Tensor): + signed = torch.tensor(signed, dtype=torch.bool) + if not isinstance(training, torch.Tensor): + training = torch.tensor(training, dtype=torch.bool) + quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training) + return quant_tensor + + @property + def signed(self): + if self.signed_t is not None: + return self.signed_t.item() + else: + return None + + @property + def training(self): + if self.training_t is not None: + return self.training_t.item() + else: + return None + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func not in QUANT_TENSOR_FN_HANDLER: + args = _unpack_quant_tensor(args) + kwargs = _unpack_quant_tensor(kwargs) + return func(*args, **kwargs) + return QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) + + @property + def tensor(self): + return self.value + + @property + def _pre_round_int_value(self): + value = self.value + scale = self.scale + zero_point = self.zero_point + if self.scale.dtype == torch.bfloat16: + value = self.value.type(torch.float32) + scale = self.scale.type(torch.float32) + zero_point = self.zero_point.type(torch.float32) + int_value = value / scale + int_value = int_value + zero_point + return int_value + + @property + def is_valid(self): + with torch.no_grad(): + pre_round_int_value = self._pre_round_int_value + rounded_int_value = torch.round(pre_round_int_value) + max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) + atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + is_int = max_abs_diff < atol + if self.bit_width >= 2: + if self.signed: + is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() + is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() + else: + is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() + is_lower_b = (0. <= rounded_int_value).all() + return (is_int & is_upper_b & is_lower_b).item() + else: # binary case + unique_vals = rounded_int_value.unique( + sorted=False, return_counts=False, return_inverse=False) + is_binary = unique_vals.view(-1).size()[0] == 2 + is_signed = (unique_vals < 0.).any().item() + sign_match = is_signed == self.signed + return is_int.item() and is_binary and sign_match + + @property + def device(self): + value_device = self.value.device + is_same_device = True + for t in [self.scale, self.zero_point, self.bit_width]: + if t is not None: + is_same_device &= value_device == t.device + if not is_same_device: + raise RuntimeError("Value and metadata are on different devices") + return value_device + + def set(self, **kwargs): + return self._replace(**kwargs) + + def detach_(self): + self.value.detach_() + self.scale.detach_() + self.zero_point.detach_() + self.bit_width.detach_() + + def detach(self): + return QuantTensor( + self.value.detach(), + self.scale.detach(), + self.zero_point.detach(), + self.bit_width.detach(), + self.signed, + self.training) + + def contiguous(self): + return QuantTensor( + self.value.contiguous(), + self.scale.contiguous(), + self.zero_point.contiguous(), + self.bit_width.contiguous(), + self.signed, + self.training) + + def int(self, float_datatype=False): + if self.is_valid: + int_value = round_ste(self._pre_round_int_value) + if float_datatype: + # Values at 8bit and lower can be represented exactly with float16 and bfloat16 + # otherwise (e.g. Int16 bias), we upscale to float32 + if self.bit_width <= 8.: + return int_value.type(self.scale.dtype) + else: + return int_value.type(torch.float32) + else: + if self.bit_width <= 8. and self.signed_t.item(): + return int_value.to(torch.int8) + elif self.bit_width <= 8. and not self.signed_t.item(): + return int_value.to(torch.uint8) + else: + return int_value.to(torch.int32) + else: + raise RuntimeError(f"QuantTensor not valid.") + + @staticmethod + def check_input_type(tensor): + if not isinstance(tensor, QuantTensor): + raise RuntimeError("Tensor is not a QuantTensor") + + @staticmethod + def is_zero_zero_point(tensor): + QuantTensor.check_input_type(tensor) + return (tensor.zero_point == 0.).all() + + def check_scaling_factors_same(self, other): + if self.training is not None and self.training: + return True + if not torch.allclose(self.scale, other.scale): + raise RuntimeError("Scaling factors are different") + + def check_zero_points_same(self, other): + if self.training is not None and self.training: + return True + if not torch.allclose(self.zero_point, other.zero_point): + raise RuntimeError("Zero points are different") + + def check_bit_width_same(self, other): + if not torch.allclose(self.bit_width, other.bit_width): + raise RuntimeError("Bit widths are different") + + def check_sign_same(self, other): + if not self.signed == other.signed: + raise RuntimeError("Signs are different") + + def view(self, *args, **kwargs): + return self.set(value=self.value.view(*args, **kwargs)) + + def reshape(self, *args, **kwargs): + return self.set(value=self.value.reshape(*args, **kwargs)) + + def flatten(self, *args, **kwargs): + return self.set(value=self.value.flatten(*args, **kwargs)) + + def transpose(self, *args, **kwargs): + value = self.value.transpose(*args, **kwargs) + tensor_meta = { + 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} + for k, tm in tensor_meta.items(): + if tm is not None and len(value.shape) == len(tm.shape): + tensor_meta[k] = tm.transpose(*args, **kwargs) + return self.set(value=value, **tensor_meta) + + def permute(self, *args, **kwargs): + value = self.value.permute(*args, **kwargs) + tensor_meta = { + 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} + for k, tm in tensor_meta.items(): + if tm is not None and len(value.shape) == len(tm.shape): + tensor_meta[k] = tm.permute(*args, **kwargs) + return self.set(value=value, **tensor_meta) + + def size(self, *args, **kwargs): + return self.value.size(*args, **kwargs) + + @property + def shape(self): + return self.value.shape + + def dim(self): + return self.value.dim() + + def add(self, other): + return self + other + + @staticmethod + def cat(tensors, dim, out=None): + if out is not None: + raise RuntimeError("Out not supported.") + if len(tensors) < 2: + return tensors[0] + else: + first_qt = tensors[0] + if all([isinstance(qt, QuantTensor) for qt in tensors]): + for qt in tensors[1:]: + first_qt.check_scaling_factors_same(qt) + first_qt.check_zero_points_same(qt) + first_qt.check_bit_width_same(qt) + first_qt.check_sign_same(qt) + output_value = torch.cat([qt.value for qt in tensors], dim=dim) + output_training = any([qt.training for qt in tensors]) + if output_training: + output_scale = sum([qt.scale for qt in tensors]) / len(tensors) + output_zero_point = sum([qt.zero_point for qt in tensors]) / len(tensors) + output_bit_width = sum([qt.bit_width for qt in tensors]) / len(tensors) + else: # at eval time, they are the same + output_scale = first_qt.scale + output_zero_point = first_qt.zero_point + output_bit_width = first_qt.bit_width + output_signed = first_qt.signed # they are the same + return QuantTensor( + value=output_value, + scale=output_scale, + zero_point=output_zero_point, + bit_width=output_bit_width, + signed=output_signed, + training=output_training) + else: + tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors] + output_value = torch.cat(tensors, dim=dim) + return output_value + + # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types + + def __neg__(self): + neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale + # In case the dtype of self.int is different from the one of the scale + neg_value = neg_value.type(self.scale.dtype) + if self.signed: + return QuantTensor( + value=neg_value, + scale=self.scale, + zero_point=self.zero_point, + bit_width=self.bit_width, + signed=self.signed, + training=self.training) + else: + return QuantTensor( + value=neg_value, + scale=self.scale, + zero_point=self.zero_point, + bit_width=self.bit_width + 1, + signed=True, + training=self.training) + + def to(self, *args, **kwargs): + return QuantTensor( + self.value.to(*args, **kwargs), + self.scale.to(*args, **kwargs), + self.zero_point.to(*args, **kwargs), + self.bit_width.to(*args, **kwargs), + self.signed, + self.training) + + def cuda(self, *args, **kwargs): + return QuantTensor( + self.value.cuda(*args, **kwargs), + self.scale.cuda(*args, **kwargs), + self.zero_point.cuda(*args, **kwargs), + self.bit_width.cuda(*args, **kwargs), + self.signed, + self.training) + + def cpu(self, *args, **kwargs): + return QuantTensor( + self.value.cpu(*args, **kwargs), + self.scale.cpu(*args, **kwargs), + self.zero_point.cpu(*args, **kwargs), + self.bit_width.cpu(*args, **kwargs), + self.signed, + self.training) + + def __add__(self, other): + if isinstance(other, QuantTensor): + self.check_scaling_factors_same(other) + output_value = self.value + other.value + output_scale = (self.scale + other.scale) / 2 + output_zero_point = self.zero_point + other.zero_point + max_val = max_int(signed=self.signed, narrow_range=False, bit_width=self.bit_width) + max_val += max_int(signed=other.signed, narrow_range=False, bit_width=other.bit_width) + min_val = min_int(signed=self.signed, narrow_range=False, bit_width=self.bit_width) + min_val += min_int(signed=other.signed, narrow_range=False, bit_width=other.bit_width) + output_bit_width = ceil_ste(torch.log2(max_val - min_val)) + output_signed = self.signed or other.signed + output_training = self.training or other.training + output = QuantTensor( + value=output_value, + scale=output_scale, + zero_point=output_zero_point, + bit_width=output_bit_width, + signed=output_signed, + training=output_training) + else: + # When adding a QT with a normal Tensor, we use the zero_point as a way to preserve + # and return a QT. + output = QuantTensor( + value=self.value + other, + scale=self.scale, + zero_point=self.zero_point - other / self.scale, + bit_width=self.bit_width, + signed=self.signed, + training=self.training) + return output + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __mul__(self, other): + if isinstance(other, QuantTensor): + output_value = self.value * other.value + output_scale = self.scale * other.scale + output_bit_width = self.bit_width + other.bit_width + output_signed = self.signed or other.signed + output_training = self.training or other.training + if self.is_zero_zero_point(self) and self.is_zero_zero_point(other): + output_zero_point = self.zero_point * other.zero_point + else: + raise RuntimeError("Zero-points of mul operands are non-zero, not supported.") + output = QuantTensor( + value=output_value, + scale=output_scale, + zero_point=output_zero_point, + bit_width=output_bit_width, + signed=output_signed, + training=output_training) + else: + output = self.value * other + return output + + def __sub__(self, other): + return self.__add__(-other) + + def __str__(self): + return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + + def __truediv__(self, other): + if isinstance(other, QuantTensor): + output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid() + max_int_denominator = 2 ** (other.bit_width - int(other.signed)) + output_scale = self.scale / (other.scale * max_int_denominator) + output_bit_width = self.bit_width + other.bit_width + output_signed = self.signed or other.signed + output_training = self.training or other.training + if self.is_zero_zero_point(self) and self.is_zero_zero_point(other): + output_zero_point = self.zero_point * other.zero_point # Output zero_point is a new, zero-valued tensor + else: + raise RuntimeError("Zero-points of div operands are non-zero, not supported.") + output = QuantTensor( + value=output_tensor, + scale=output_scale, + zero_point=output_zero_point, + bit_width=output_bit_width, + signed=output_signed, + training=output_training) + else: + output = self.value / other + return output + + def __abs__(self): + if self.signed: + abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale + # In case the dtype of self.int is different from the one of the scale + abs_value = abs_value.type(self.scale.dtype) + return QuantTensor( + value=abs_value, + scale=self.scale, + zero_point=self.zero_point, + bit_width=self.bit_width - 1, + signed=False, + training=self.training) + else: + return self + + def __pos__(self): + return self diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 0f598db53..f0d9e95f8 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -2,11 +2,16 @@ # SPDX-License-Identifier: BSD-3-Clause import functools +import math +import warnings import torch import torch.nn.functional as F import brevitas +from brevitas.function.ops import max_int +from brevitas.function.ops_ste import ceil_ste +from brevitas.utils.torch_utils import compute_channel_view_shape QUANT_TENSOR_FN_HANDLER = {} @@ -156,3 +161,241 @@ def pixel_shuffle_handler(*args, **kwargs): @implements(F.pixel_unshuffle) def pixel_unshuffle_handler(*args, **kwargs): return quant_invariant_handler(F.pixel_unshuffle, *args, **kwargs) + + +@implements(F.conv1d) +def conv1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv2d) +def conv2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv3d) +def conv3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv3d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv_transpose1d) +def conv_transpose1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv_transpose2d) +def conv_transpose2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv_transpose3d) +def conv_transpose3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose3d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.linear) +def linear_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.linear, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.avg_pool2d) +def avg_pool2d_handler( + quant_input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override): + from brevitas.quant_tensor import _unpack_quant_tensor + + x = F.avg_pool2d( + _unpack_quant_tensor(quant_input), + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override) + + max_acc_bit_width = IMPLS[F.avg_pool2d] + # remove avg scaling + if isinstance(kernel_size, tuple): + avg_scaling = kernel_size[0] * kernel_size[1] + else: + avg_scaling = kernel_size * kernel_size + rescaled_value = x * avg_scaling + quant_input = quant_input.set(value=rescaled_value) + quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling)) + return quant_input + + +@implements(F.adaptive_avg_pool2d) +def adaptive_avg_pool2d_handler(quant_input, output_shape): + from functools import reduce + from operator import mul + + from brevitas.nn.quant_avg_pool import TruncAdaptiveAvgPool2d + from brevitas.quant_tensor import _unpack_quant_tensor + + x = F.adaptive_avg_pool2d(_unpack_quant_tensor(quant_input), output_shape) + k_size, stride = TruncAdaptiveAvgPool2d.compute_kernel_size_stride(quant_input.value.shape[2:], x.shape[2:]) + + max_acc_bit_width = IMPLS[F.avg_pool2d] + reduce_size = reduce(mul, k_size, 1) + rescaled_value = x * reduce_size # remove avg scaling + + quant_input = quant_input.set(value=rescaled_value) + quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size)) + return quant_input + + +def quant_layer(cls, quant_input, quant_weight, bias, *args, **kwargs): + from brevitas.quant_tensor import _unpack_quant_tensor + from brevitas.quant_tensor import QuantTensor + + output_scale = None + output_bit_width = None + output_zero_point = None + output_signed = None + max_acc_bit_width = IMPLS[cls] + + compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance( + quant_weight, QuantTensor) + + if bias is None: + output = cls( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + None, + *args, + **kwargs) + else: + output = cls( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + _unpack_quant_tensor(bias), + *args, + **kwargs) + + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): + output_bit_width = max_acc_bit_width( + quant_input.bit_width, + quant_weight.bit_width, + quant_weight.value.shape, + *args, + **kwargs) + output_scale = quant_output_scale_impl( + cls, quant_input.value, quant_input.scale, quant_weight.scale) + output_signed = quant_input.signed or quant_weight.signed + output_training = quant_input.training or quant_weight.training + + if bias is not None: + if output_scale is not None: + if (isinstance(bias, QuantTensor) and + not torch.allclose(bias.scale, output_scale)) or not isinstance(bias, + QuantTensor): + channel_dim = -1 if isinstance(cls, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + quant_input, channel_dim=channel_dim) + output_zero_point = -_unpack_quant_tensor(bias).view( + output_scale_broadcast_shape) / output_scale + if output_bit_width is not None and isinstance(bias, QuantTensor): + output_bit_width = torch.where( + bias.bit_width > output_bit_width, bias.bit_width, output_bit_width) + output_bit_width = output_bit_width + 1 + + if compute_output_quant_tensor: + if (isinstance(quant_input, QuantTensor) and + (quant_input.zero_point != 0.0).any()) or (isinstance(quant_weight, QuantTensor) and + (quant_weight.zero_point != 0.0).any()): + warnings.warn("Computing zero point of output accumulator not supported yet.") + compute_output_quant_tensor = False + + if compute_output_quant_tensor: + if output_zero_point is None: + output_zero_point = torch.zeros(1).type_as(output) + + return create_quant_tensor( + output, + output_scale, + output_bit_width, + output_zero_point, + output_signed, + output_training) + else: + return output + + +def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training): + from brevitas.quant_tensor import QuantTensor + return QuantTensor( + tensor, + scale=scale, + zero_point=zero_point, + bit_width=bit_width, + signed=signed, + training=training) + + +def quant_output_scale_impl(cls, inp, quant_input_scale, quant_weight_scale): + channel_dim = -1 if cls == F.linear else 1 + output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) + output_scale = quant_weight_scale.view(output_scale_shape) + output_scale = output_scale * quant_input_scale.view(output_scale_shape) + return output_scale + + +def max_acc_bit_width_convNd(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + in_channel = weight_shape[1] + kernel_size = math.prod(weight_shape[2:]) + max_uint_output = max_uint_input * max_kernel_val * kernel_size * in_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_linear(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + in_channel = weight_shape[1] + max_uint_output = max_uint_input * max_kernel_val * in_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_convtransposeNd( + input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + stride = kwargs['stride'] + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + out_channel = weight_shape[1] + kernel_shape = weight_shape[2:] + + patch_size = 0 + for s, k in zip(stride, kernel_shape): + patch_size += max(math.ceil(k / s), 1) + + max_uint_output = max_uint_input * max_kernel_val * patch_size * out_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_avg_pool2d(input_bit_width, avg_scaling): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_uint_output = max_uint_input * avg_scaling + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +IMPLS = { + F.linear: max_acc_bit_width_linear, + F.conv1d: max_acc_bit_width_convNd, + F.conv2d: max_acc_bit_width_convNd, + F.conv3d: max_acc_bit_width_convNd, + F.conv_transpose1d: max_acc_bit_width_convtransposeNd, + F.conv_transpose2d: max_acc_bit_width_convtransposeNd, + F.conv_transpose3d: max_acc_bit_width_convtransposeNd, + F.avg_pool2d: max_acc_bit_width_avg_pool2d} diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index ec7d6fac4..9392c001d 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -80,3 +80,9 @@ def kthvalue( if x.dtype != dtype: x = x.type(dtype) return (x, indices) + + +def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int): + broadcast_shape = [1] * len(tensor.size()) + broadcast_shape[channel_dim] = -1 + return tuple(broadcast_shape) diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index bbee8daca..55dc42be2 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -53,8 +53,9 @@ def test_quant_wbiol(model_input, current_cases): return elif kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor'] and kwargs['io_quant'] is None \ and kwargs['input_quantized']: - with pytest.raises(RuntimeError, - match='Computing zero point of output accumulator not supported yet.'): + with pytest.raises( + AssertionError, + match='QuantLayer is not correctly configured, check if warnings were raised'): output = model(input) return else: @@ -188,8 +189,8 @@ def test_quant_mha(model_input, current_cases): elif kwargs['weight_quant'] is not None and kwargs['io_quant'] is None: if kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor']: with pytest.raises( - RuntimeError, - match='Computing zero point of output accumulator not supported yet.'): + AssertionError, + match='QuantLayer is not correctly configured, check if warnings were raised'): output, _ = model(inp, inp, inp) return output, _ = model(inp, inp, inp)