diff --git a/src/brevitas/fx/value_tracer.py b/src/brevitas/fx/value_tracer.py
index 6cab89767..55ea0d93b 100644
--- a/src/brevitas/fx/value_tracer.py
+++ b/src/brevitas/fx/value_tracer.py
@@ -57,7 +57,7 @@
 import torch.utils._pytree as pytree
 
 from brevitas import torch_version
-from brevitas.quant_tensor import QuantTensorBase
+from brevitas.quant_tensor import QuantTensor
 
 from . import *
 from . import _assert_is_none
@@ -82,7 +82,7 @@ from . import ScopeContextManager
 
 _UNSET = object()
 
-extended_base_types = base_types + (QuantTensorBase,)
+extended_base_types = base_types + (QuantTensor,)
 
 FRAME_FILES = [
     'fx/brevitas_tracer.py',
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 893ff5e30..82e2a6ac4 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -15,6 +15,7 @@
 from brevitas import config
 from brevitas.function import max_int
 from brevitas.inject import BaseInjector as Injector
+from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
 
@@ -103,11 +104,11 @@ def bit_width(self):
         bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
         return bit_width
 
-    def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width = impl(x)
-            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
 
@@ -128,11 +129,11 @@ def pre_zero_point(self):
         out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
         return pre_zero_point
 
-    def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x)
-            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
 
@@ -157,11 +158,12 @@ def pre_zero_point(self):
         raise NotImplementedError
 
     def forward(
-            self,
-            x: torch.Tensor,
-            quant_input: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]:
+            self,
+            x: torch.Tensor,
+            quant_input: Optional[Union[Tensor,
+                                        IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]:
         if isinstance(quant_input,
-                      QuantTensor) and not self.training and self.cache_inference_quant_act:
+                      IntQuantTensor) and not self.training and self.cache_inference_quant_act:
             cached_inp = _CachedIO(quant_input.detach(), self.cache_quant_io_metadata_only)
             self._cached_act = cached_inp
 
@@ -170,14 +172,14 @@ def forward(
                 assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass"
                 quant_input = self._cached_act
             else:
-                assert isinstance(quant_input, QuantTensor), "Input must be quantized"
+                assert isinstance(quant_input, IntQuantTensor), "Input must be quantized"
             input_bit_width = quant_input.bit_width
             input_is_signed = quant_input.signed
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
-            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
 
@@ -236,7 +238,7 @@ def bit_width(self):
 
     def forward(self,
                 x: Tensor,
-                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+                input_scale: Optional[Tensor] = None) -> Union[Tensor, IntQuantTensor]:
         out = x
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
@@ -251,10 +253,12 @@ def forward(self,
             else:
                 out, out_scale, out_zp, out_bit_width = impl(x)
-            out = QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
+            out = IntQuantTensor(
+                out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         else:
             out = x
-        if isinstance(out, QuantTensor) and not self.training and self.cache_inference_quant_bias:
+        if isinstance(out,
+                      IntQuantTensor) and not self.training and self.cache_inference_quant_bias:
             cached_bias = _CachedIO(out.detach(), metadata_only=False)
             self._cached_bias = cached_bias
         return out
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 4dd8417a9..4ec268ede 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -10,6 +10,7 @@
 from typing_extensions import runtime_checkable
 
 import brevitas
+from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
 
@@ -166,11 +167,11 @@ def bit_width(self, force_eval=True):
         elif self._cached_act is None:
             return None
 
-    def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]:
         out = x
         if self.fused_activation_quant_proxy is not None:
             y = x
-            if isinstance(y, QuantTensor):
+            if isinstance(y, IntQuantTensor):
                 y = y.value
 
             if self.export_mode:
@@ -180,15 +181,15 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
                 y = self.fused_activation_quant_proxy.activation_impl(y)
             else:
                 y = self.fused_activation_quant_proxy(y)
-            # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
+            # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
             # otherwise return a simple Tensor
             if isinstance(y, tuple) and not any(map(lambda f: f is None, y)):
-                out = QuantTensor(*y, signed=self.is_signed, training=self.training)
+                out = IntQuantTensor(*y, signed=self.is_signed, training=self.training)
             elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
                 if isinstance(y, tuple):
                     y = y[0]
-                if isinstance(x, QuantTensor):
-                    out = QuantTensor(
+                if isinstance(x, IntQuantTensor):
+                    out = IntQuantTensor(
                         y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
                 else:
                     out = y
@@ -199,7 +200,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         else:
             # If fused activation quant proxy is not enabled, return the input
             out = x
-        if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
+        if not self.training and self.cache_inference_quant_act and isinstance(out, IntQuantTensor):
            cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only)
            self._cached_act = cached_out
         return out
@@ -216,11 +217,11 @@ def zero_point(self, force_eval=True):
 
 
 class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):
 
-    def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
-            return QuantTensor(
+            return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         return x
 
@@ -232,11 +233,11 @@ def bit_width(self):
             return None
         zhs = self._zero_hw_sentinel()
         # Signed might or might not be defined. We just care about retrieving the bitwidth
-        empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
+        empty_imp = IntQuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
         bit_width = self.__call__(empty_imp).bit_width
         return bit_width
 
-    def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             if self.export_mode:
                 out_tuple = self.export_handler(
@@ -244,7 +245,8 @@ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
             else:
                 out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
-            return QuantTensor(out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
+            return IntQuantTensor(
+                out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
         else:
             return x
 
diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py
index 7e58bc551..04891713e 100644
--- a/src/brevitas/quant_tensor/__init__.py
+++ b/src/brevitas/quant_tensor/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from .base_quant_tensor import *
 from .base_quant_tensor import _unpack_quant_tensor
-from .base_quant_tensor import QuantTensorBase
-from .int_quant_tensor import QuantTensor
+from .int_quant_tensor import *
diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py
index 9791b9cbd..ad2c50cbb 100644
--- a/src/brevitas/quant_tensor/base_quant_tensor.py
+++ b/src/brevitas/quant_tensor/base_quant_tensor.py
@@ -3,7 +3,11 @@
 from torch import Tensor
 
 
-class QuantTensorBase(NamedTuple):
+class QuantTensor:
+    pass
+
+
+class IntTensorBase(NamedTuple):
     value: Tensor
     scale: Optional[Tensor]
     zero_point: Optional[Tensor]
@@ -13,7 +17,7 @@
 
 
 def _unpack_quant_tensor(input_data):
-    if isinstance(input_data, QuantTensorBase):
+    if isinstance(input_data, QuantTensor):
         return input_data.value
     elif isinstance(input_data, tuple):
         return tuple([_unpack_quant_tensor(v) for v in input_data])
diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py
index f923e516e..4122f9bf9 100644
--- a/src/brevitas/quant_tensor/int_quant_tensor.py
+++ b/src/brevitas/quant_tensor/int_quant_tensor.py
@@ -8,7 +8,8 @@
 from brevitas.function.ops_ste import ceil_ste
 from brevitas.function.ops_ste import round_ste
 from brevitas.quant_tensor import _unpack_quant_tensor
-from brevitas.quant_tensor import QuantTensorBase
+from brevitas.quant_tensor import IntTensorBase
+from brevitas.quant_tensor import QuantTensor
 
 from .torch_handler import QUANT_TENSOR_FN_HANDLER
 
@@ -16,7 +17,7 @@
 BFLOAT16_IS_VALID_ATOL = 0.5
 
 
-class QuantTensor(QuantTensorBase):
+class IntQuantTensor(IntTensorBase, QuantTensor):
 
     def __new__(cls, value, scale, zero_point, bit_width, signed, training):
 
@@ -118,7 +119,7 @@ def detach_(self):
         self.bit_width.detach_()
 
     def detach(self):
-        return QuantTensor(
+        return IntQuantTensor(
             self.value.detach(),
             self.scale.detach(),
             self.zero_point.detach(),
@@ -127,7 +128,7 @@ def detach(self):
             self.training)
 
     def contiguous(self):
-        return QuantTensor(
+        return IntQuantTensor(
             self.value.contiguous(),
             self.scale.contiguous(),
             self.zero_point.contiguous(),
@@ -153,16 +154,16 @@ def int(self, float_datatype=False):
             else:
                 return int_value.to(torch.int32)
         else:
-            raise RuntimeError(f"QuantTensor not valid.")
+            raise RuntimeError(f"IntQuantTensor not valid.")
 
     @staticmethod
     def check_input_type(tensor):
-        if not isinstance(tensor, QuantTensor):
-            raise RuntimeError("Tensor is not a QuantTensor")
+        if not isinstance(tensor, IntQuantTensor):
+            raise RuntimeError("Tensor is not a IntQuantTensor")
 
     @staticmethod
     def is_zero_zero_point(tensor):
-        QuantTensor.check_input_type(tensor)
+        IntQuantTensor.check_input_type(tensor)
         return (tensor.zero_point == 0.).all()
 
     def check_scaling_factors_same(self, other):
@@ -233,7 +234,7 @@ def cat(tensors, dim, out=None):
             return tensors[0]
         else:
             first_qt = tensors[0]
-            if all([isinstance(qt, QuantTensor) for qt in tensors]):
+            if all([isinstance(qt, IntQuantTensor) for qt in tensors]):
                 for qt in tensors[1:]:
                     first_qt.check_scaling_factors_same(qt)
                     first_qt.check_zero_points_same(qt)
@@ -250,7 +251,7 @@ def cat(tensors, dim, out=None):
                 output_zero_point = first_qt.zero_point
                 output_bit_width = first_qt.bit_width
                 output_signed = first_qt.signed  # they are the same
-                return QuantTensor(
+                return IntQuantTensor(
                     value=output_value,
                     scale=output_scale,
                     zero_point=output_zero_point,
@@ -258,7 +259,7 @@ def cat(tensors, dim, out=None):
                     bit_width=output_bit_width,
                     signed=output_signed,
                     training=output_training)
             else:
-                tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors]
+                tensors = [qt.value if isinstance(qt, IntQuantTensor) else qt for qt in tensors]
                 output_value = torch.cat(tensors, dim=dim)
                 return output_value
 
@@ -269,7 +270,7 @@ def __neg__(self):
         # In case the dtype of self.int is different from the one of the scale
         neg_value = neg_value.type(self.scale.dtype)
         if self.signed:
-            return QuantTensor(
+            return IntQuantTensor(
                 value=neg_value,
                 scale=self.scale,
                 zero_point=self.zero_point,
@@ -277,7 +278,7 @@ def __neg__(self):
                 signed=self.signed,
                 training=self.training)
         else:
-            return QuantTensor(
+            return IntQuantTensor(
                 value=neg_value,
                 scale=self.scale,
                 zero_point=self.zero_point,
@@ -286,7 +287,7 @@ def __neg__(self):
                 training=self.training)
 
     def to(self, *args, **kwargs):
-        return QuantTensor(
+        return IntQuantTensor(
             self.value.to(*args, **kwargs),
             self.scale.to(*args, **kwargs),
             self.zero_point.to(*args, **kwargs),
@@ -295,7 +296,7 @@ def to(self, *args, **kwargs):
             self.training)
 
     def cuda(self, *args, **kwargs):
-        return QuantTensor(
+        return IntQuantTensor(
             self.value.cuda(*args, **kwargs),
             self.scale.cuda(*args, **kwargs),
             self.zero_point.cuda(*args, **kwargs),
@@ -304,7 +305,7 @@ def cuda(self, *args, **kwargs):
             self.training)
 
     def cpu(self, *args, **kwargs):
-        return QuantTensor(
+        return IntQuantTensor(
             self.value.cpu(*args, **kwargs),
             self.scale.cpu(*args, **kwargs),
             self.zero_point.cpu(*args, **kwargs),
@@ -313,7 +314,7 @@ def cpu(self, *args, **kwargs):
             self.training)
 
     def __add__(self, other):
-        if isinstance(other, QuantTensor):
+        if isinstance(other, IntQuantTensor):
             self.check_scaling_factors_same(other)
             output_value = self.value + other.value
             output_scale = (self.scale + other.scale) / 2
@@ -325,14 +326,14 @@ def __add__(self, other):
             output_bit_width = ceil_ste(torch.log2(max_val - min_val))
             output_signed = self.signed or other.signed
             output_training = self.training or other.training
-            output = QuantTensor(
+            output = IntQuantTensor(
                 value=output_value,
                 scale=output_scale,
                 zero_point=output_zero_point,
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
+        elif isinstance(other, IntQuantTensor):
             output = self.value + other.value
         else:
             output = self.value + other
@@ -345,7 +346,7 @@ def __rmul__(self, other):
         return self.__mul__(other)
 
     def __mul__(self, other):
-        if isinstance(other, QuantTensor):
+        if isinstance(other, IntQuantTensor):
             output_value = self.value * other.value
             output_scale = self.scale * other.scale
             output_bit_width = self.bit_width + other.bit_width
@@ -355,14 +356,14 @@ def __mul__(self, other):
                 output_zero_point = self.zero_point * other.zero_point
             else:
                 raise RuntimeError("Zero-points of mul operands are non-zero, not supported.")
-            output = QuantTensor(
+            output = IntQuantTensor(
                 value=output_value,
                 scale=output_scale,
                 zero_point=output_zero_point,
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
+        elif isinstance(other, IntQuantTensor):
             output = self.value * other.value
         else:
             output = self.value * other
@@ -372,10 +373,10 @@ def __sub__(self, other):
         return self.__add__(-other)
 
     def __str__(self):
-        return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"
+        return f"IntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"
 
     def __truediv__(self, other):
-        if isinstance(other, QuantTensor):
+        if isinstance(other, IntQuantTensor):
             output_tensor = self.value / other.value  # Note, output tensor not guaranteed to pass self.is_valid()
             max_int_denominator = 2 ** (other.bit_width - int(other.signed))
             output_scale = self.scale / (other.scale * max_int_denominator)
@@ -386,14 +387,14 @@ def __truediv__(self, other):
                 output_zero_point = self.zero_point * other.zero_point  # Output zero_point is a new, zero-valued tensor
             else:
                 raise RuntimeError("Zero-points of div operands are non-zero, not supported.")
-            output = QuantTensor(
+            output = IntQuantTensor(
                 value=output_tensor,
                 scale=output_scale,
                 zero_point=output_zero_point,
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
+        elif isinstance(other, IntQuantTensor):
             output = self.value / other.value
         else:
             output = self.value / other
@@ -404,7 +405,7 @@ def __abs__(self):
             abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale
             # In case the dtype of self.int is different from the one of the scale
             abs_value = abs_value.type(self.scale.dtype)
-            return QuantTensor(
+            return IntQuantTensor(
                 value=abs_value,
                 scale=self.scale,
                 zero_point=self.zero_point,
@@ -421,7 +422,7 @@ def __pos__(self):
     def max_acc_bit_width(cls, *args):
 
         def _max_int_or_tensor(args):
-            if isinstance(args, QuantTensor):
+            if isinstance(args, IntQuantTensor):
                 return max_int(bit_width=args.bit_width, signed=False, narrow_range=False)
             else:
                 return args
diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py
index 92795c839..3ad57da91 100644
--- a/src/brevitas/quant_tensor/torch_handler.py
+++ b/src/brevitas/quant_tensor/torch_handler.py
@@ -246,7 +246,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape):
 
     rescaled_value = x * reduce_size  # remove avg scaling
     quant_input = quant_input.set(value=rescaled_value)
-    quant_input = quant_input.set(bit_width=max_acc_bit_width(x.bit_width, reduce_size))
+    quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size))
 
     return quant_input
 
@@ -324,8 +324,8 @@ def quant_layer(
 
 
 def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training):
-    from brevitas.quant_tensor import QuantTensor
-    return QuantTensor(
+    from brevitas.quant_tensor import IntQuantTensor
+    return IntQuantTensor(
         tensor,
         scale=scale,
         zero_point=zero_point,
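
Note on usage (not part of the patch): after this change, `QuantTensor` is an empty marker base class, `IntTensorBase` carries the `NamedTuple` fields, and `IntQuantTensor` is the concrete integer implementation that the proxies now construct. Generic code keeps working by type-checking against the `QuantTensor` base, as `_unpack_quant_tensor` does above. The following is a minimal sketch of the resulting API, assuming only the constructor signature and the exports visible in the hunks:

```python
import torch

from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor

# Build a signed 8-bit IntQuantTensor; `value` holds dequantized values,
# i.e. integers scaled by `scale` and shifted by `zero_point`.
qt = IntQuantTensor(
    value=torch.tensor([0.1, -0.2]),
    scale=torch.tensor(0.1),
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed=True,
    training=False)

# Code previously written against the old QuantTensor class can migrate to
# isinstance checks on the new marker base, which any subclass satisfies.
assert isinstance(qt, QuantTensor)
assert isinstance(qt, IntQuantTensor)

# _unpack_quant_tensor now dispatches on the QuantTensor base as well.
assert torch.equal(_unpack_quant_tensor(qt), qt.value)
```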