diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index f0b669903..04336dbbf 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -105,7 +105,8 @@ def __exit__(self, type, value, traceback):
             self.model, is_training=self.previous_training_state, quantization_enabled=True)
         restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
 
-def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=True):
+
+def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
     if hasattr(m, 'cache_inference_quant_bias'):
         if not hasattr(m, "cache_inference_quant_bias_backup"):
             m.cache_inference_quant_bias_backup = m.cache_inference_quant_bias
@@ -113,14 +114,15 @@ def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=
         m.cache_inference_quant_bias = enabled
         m.cache_inference_quant_bias_metadata_only = metadata_only
 
-def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=True):
+def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
     if hasattr(m, 'cache_inference_quant_act'):
         if not hasattr(m, "cache_inference_quant_act_backup"):
             m.cache_inference_quant_act_backup = m.cache_inference_quant_act
         m.cache_inference_quant_act = enabled
         m.cache_inference_quant_act_metadata_only = metadata_only
 
-def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=False):
+
+def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
     if hasattr(m, 'cache_inference_quant_weight'):
         if not hasattr(m, "cache_inference_quant_weight_backup"):
             m.cache_inference_quant_weight_backup = m.cache_inference_quant_weight
@@ -129,19 +131,20 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only:boo
 
 
 class inference_mode:
+
     def __init__(self, model, cache_quant_weight=False, enabled=True):
         self.model = model
-        self.enabled=enabled
+        self.enabled = enabled
         self.cache_quant_weight = cache_quant_weight
-
+
     def __enter__(self):
         if self.enabled:
-            self.model.apply(lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
+            self.model.apply(
+                lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
             self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
             if self.cache_quant_weight:
                 self.model.apply(lambda m: _override_weight_caching_mode(m, enabled=True))
-
     def __exit__(self, type, value, traceback):
         self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py
index f4266e284..d4baf9e11 100644
--- a/src/brevitas/nn/mixin/base.py
+++ b/src/brevitas/nn/mixin/base.py
@@ -69,7 +69,6 @@ def __init__(self, return_quant_tensor: bool):
     def channelwise_separable(self) -> bool:
         pass
 
-
     def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
         quant_tensor_classes = [
             IntQuantTensor, FloatQuantTensor, GroupwiseIntQuantTensor, GroupwiseFloatQuantTensor]
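For context, a minimal usage sketch of the reworked `inference_mode` context manager (not part of the diff; the model and layer choices below are hypothetical, and `cache_quant_weight=True` exercises the new weight-caching path):

```python
# Minimal sketch, assuming a Brevitas quantized module; QuantLinear is illustrative.
import torch
from brevitas.graph.calibrate import inference_mode
from brevitas.nn import QuantLinear

model = torch.nn.Sequential(QuantLinear(16, 16, bias=True)).eval()

with inference_mode(model, cache_quant_weight=True):
    # A first eval pass populates the bias/act caches (and the weight cache, since
    # cache_quant_weight=True flips the weight caching flag on each proxy).
    model(torch.randn(2, 16))
```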
diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py
index 9d4efb223..9bad29e79 100644
--- a/src/brevitas/proxy/float_parameter_quant.py
+++ b/src/brevitas/proxy/float_parameter_quant.py
@@ -1,5 +1,4 @@
-from typing import Optional, Union
-from brevitas.quant_tensor import _unpack_quant_tensor
+from typing import Any, List, Optional, Union
 
 import torch
 from torch import Tensor
@@ -8,7 +7,9 @@
 from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
+from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import FloatQuantTensor
+from brevitas.quant_tensor.base_quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIOFloat
@@ -87,33 +88,20 @@ def is_fnuz(self):
 
 class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):
 
-    def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
-        if self.is_quant_enabled:
-            if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
-                out = self._cached_weight.quant_tensor
-                if torch.compiler.is_compiling():
-                    out = _unpack_quant_tensor(out)
-            else:
-                impl = self.export_handler if self.export_mode else self.tensor_quant
-                out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
-                if not torch.compiler.is_compiling():
-                    out = FloatQuantTensor(
-                        out,
-                        scale,
-                        zero_point,
-                        exponent_bit_width,
-                        mantissa_bit_width,
-                        exponent_bias,
-                        saturating,
-                        inf_values,
-                        nan_values,
-                        self.is_signed,
-                        self.training)
-                if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
-                    self._cached_weight = _CachedIOFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
-        else:  # quantization disabled
-            out = x
-        return out
+    def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]:
+        out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
+        return FloatQuantTensor(
+            out,
+            scale,
+            zero_point,
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            saturating,
+            inf_values,
+            nan_values,
+            self.is_signed,
+            self.training)
 
 
 class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
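A side note on the hook introduced above: the control flow moves into the shared `WeightQuantProxyFromInjectorBase.forward` (added later in this diff in `parameter_quant.py`), and each subclass only maps the flat tuple returned by its quant implementation onto the right `QuantTensor` flavour. A toy, self-contained illustration of that mapping (plain Python, not Brevitas code; field names mirror the diff):

```python
from collections import namedtuple

# Stand-in for FloatQuantTensor, with the same field order used above.
FakeFloatQuantTensor = namedtuple(
    "FakeFloatQuantTensor",
    "value scale zero_point exponent_bit_width mantissa_bit_width exponent_bias "
    "saturating inf_values nan_values signed training")


def create_quant_tensor(qt_args, signed=True, training=False):
    # qt_args is the flat tuple produced by the tensor_quant implementation.
    return FakeFloatQuantTensor(*qt_args, signed, training)


qt = create_quant_tensor((0.5, 0.1, 0.0, 4, 3, 8, True, None, None))
print(qt.scale)  # 0.1
```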
diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py
index fba9006e6..259cad757 100644
--- a/src/brevitas/proxy/float_runtime_quant.py
+++ b/src/brevitas/proxy/float_runtime_quant.py
@@ -60,108 +60,27 @@ def is_fnuz(self):
         ) is None and self.exponent_bias() == 16
         return is_fnuz_e4m3 or is_fnuz_e5m2
 
-    def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTensor]:
-        out = x
-        if self.fused_activation_quant_proxy is not None:
-            y = x
-            if isinstance(y, QuantTensor):
-                y = y.value
-
-            if self.export_mode:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-                y = self.export_handler(y)
-            elif not self.is_quant_enabled:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-            else:
-                y = self.fused_activation_quant_proxy(y)
-            if torch.compiler.is_compiling():
-                y = y[0]
-            else:
-                # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
-                # otherwise return a simple Tensor
-                # We exclude the last two values (inf_values and nan_values)
-                if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
-                    out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
-                elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
-                    if isinstance(y, tuple):
-                        y = y[0]
-                    if isinstance(x, FloatQuantTensor):
-                        out = FloatQuantTensor(
-                            y,
-                            x.scale,
-                            x.zero_point,
-                            x.exponent_bit_width,
-                            x.mantissa_bit_width,
-                            x.exponent_bias,
-                            x.saturating,
-                            x.inf_values,
-                            x.nan_values,
-                            x.signed,
-                            self.training)
-                    else:
-                        out = y
-                else:
-                    if isinstance(y, tuple):
-                        y = y[0]
-                    out = y
-        else:
-            # If fused activation quant proxy is not enabled, return the input
-            out = x
-        if not self.training and self.cache_inference_quant_act and isinstance(out,
-                                                                               FloatQuantTensor):
-            cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only)
-            self._cached_act = cached_out
-        return out
-
 
 class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase):
 
-    def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]:
-        out = x
-        if self.fused_activation_quant_proxy is not None:
-            y = x
-            if isinstance(y, FloatQuantTensor):
-                y = y.value
-
-            if self.export_mode:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-                y = self.export_handler(y)
-            elif not self.is_quant_enabled:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-            else:
-                y = self.fused_activation_quant_proxy(y)
-            # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
-            # otherwise return a simple Tensor
-            # We exclude the last two values (inf_values and nan_values)
-            if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
-                out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
-            elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
-                if isinstance(y, tuple):
-                    y = y[0]
-                if isinstance(x, FloatQuantTensor):
-                    out = FloatQuantTensor(
-                        y,
-                        x.scale,
-                        x.zero_point,
-                        x.mantissa_bit_width,
-                        x.exponent_bit_width,
-                        x.exponent_bias,
-                        x.saturating,
-                        x.inf_values,
-                        x.nan_values,
-                        x.signed,
-                        self.training)
-                else:
-                    out = y
-            else:
-                if isinstance(y, tuple):
-                    y = y[0]
-                out = y
+    def __init__(self, quant_layer, quant_injector):
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOFloat
+
+    def create_quant_tensor(self, qt_args, x=None):
+        if x is None:
+            out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
         else:
-            # If fused activation quant proxy is not enabled, return the input
-            out = x
-        if not self.training and self.cache_inference_quant_act and isinstance(out,
-                                                                               FloatQuantTensor):
-            cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only)
-            self._cached_act = cached_out
+            out = FloatQuantTensor(
+                qt_args,
+                x.scale,
+                x.zero_point,
+                x.exponent_bit_width,
+                x.mantissa_bit_width,
+                x.exponent_bias,
+                x.saturating,
+                x.inf_values,
+                x.nan_values,
+                x.signed,
+                self.training)
         return out
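For readability, a short pseudocode note on how the two branches of `create_quant_tensor` above are expected to be reached; the actual call sites live in the shared activation `forward` added to `runtime_quant.py` further down in this diff:

```python
# Illustrative pseudocode only (not Brevitas API):
#
#   out = self.create_quant_tensor(y)           # quant enabled: y is the full args tuple
#   out = self.create_quant_tensor(y[0], x=x)   # passthrough: reuse scale/zero-point/format
#                                               # metadata from the incoming FloatQuantTensor x
```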
diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py
index 5e3419fcd..c92af0c8c 100644
--- a/src/brevitas/proxy/groupwise_float_parameter_quant.py
+++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -1,11 +1,9 @@
-from typing import Union
+from typing import Any, List, Union
 
-import torch
 from torch import Tensor
 
 from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase
-from brevitas.quant_tensor import GroupwiseFloatQuantTensor, _unpack_quant_tensor
-from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat
+from brevitas.quant_tensor import GroupwiseFloatQuantTensor
 
 
 class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):
@@ -23,33 +21,19 @@ def group_dim(self):
     def group_size(self):
         return self.quant_injector.group_size
 
-    def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]:
-        if self.is_quant_enabled:
-            if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
-                out = self._cached_weight.quant_tensor
-                if torch.compiler.is_compiling():
-                    out = _unpack_quant_tensor(out)
-            else:
-                impl = self.export_handler if self.export_mode else self.tensor_quant
-                x = self.view_impl(x)
-                out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
-                if not torch.compiler.is_compiling():
-                    out = GroupwiseFloatQuantTensor(
-                        out,
-                        scale,
-                        zero_point,
-                        self.group_size,
-                        self.group_dim,
-                        exponent_bit_width,
-                        mantissa_bit_width,
-                        exponent_bias,
-                        saturating,
-                        inf_values,
-                        nan_values,
-                        self.is_signed,
-                        self.training)
-                if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
-                    self._cached_weight = _CachedIOGroupwiseFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
-        else:  # quantization disabled
-            out = x
-        return out
+    def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseFloatQuantTensor]:
+        out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
+        return GroupwiseFloatQuantTensor(
+            out,
+            scale,
+            zero_point,
+            self.group_size,
+            self.group_dim,
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            saturating,
+            inf_values,
+            nan_values,
+            self.is_signed,
+            self.training)
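As background for the `group_size`/`group_dim` metadata carried by `GroupwiseFloatQuantTensor`, a small plain-PyTorch sketch of what groupwise scaling means (hypothetical shapes, not Brevitas internals):

```python
import torch

w = torch.randn(8, 64)         # e.g. [out_features, in_features]
group_size, group_dim = 16, 1  # hypothetical values
groups = w.view(8, 64 // group_size, group_size)
scale = groups.abs().amax(dim=-1, keepdim=True)  # one scale per group of 16 inputs
print(scale.shape)             # torch.Size([8, 4, 1])
```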
diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py
index 7514471f9..f23301f37 100644
--- a/src/brevitas/proxy/groupwise_float_runtime_quant.py
+++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -1,16 +1,14 @@
-from typing import Union
-
-from torch import Tensor
-
 from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
 from brevitas.quant_tensor import GroupwiseFloatQuantTensor
-from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat
-import torch
 
 
 class GroupwiseActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase):
 
+    def __init__(self, quant_layer, quant_injector):
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOGroupwiseFloat
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
@@ -19,71 +17,36 @@ def group_dim(self):
     def group_size(self):
         return self.quant_injector.group_size
 
-    def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]:
-        out = x
-        if self.fused_activation_quant_proxy is not None:
-            y = x
-            if isinstance(y, QuantTensor):
-                y = y.value
-
-            if self.export_mode:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-                y = self.export_handler(y)
-            elif not self.is_quant_enabled:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-            else:
-                y = self.fused_activation_quant_proxy(y)
-            # If y is an empty GroupwiseFloatQuantTensor, we need to check if this is a passthrough proxy,
-            # otherwise return a simple Tensor
-            # We exclude the last two values (inf_values and nan_values)
-            if torch.compiler.is_compiling():
-                y = y[0]
-            else:
-                if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
-                    value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y
-                    out = GroupwiseFloatQuantTensor(
-                        value,
-                        scale,
-                        zero_point,
-                        self.group_size,
-                        self.group_dim,
-                        exponent_bit_width,
-                        mantissa_bit_width,
-                        exponent_bias,
-                        saturating,
-                        inf_values,
-                        nan_values,
-                        signed=self.is_signed,
-                        training=self.training)
-                elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
-                    if isinstance(y, tuple):
-                        y = y[0]
-                    if isinstance(x, GroupwiseFloatQuantTensor):
-                        out = GroupwiseFloatQuantTensor(
-                            y,
-                            x.scale,
-                            x.zero_point,
-                            self.group_size,
-                            self.group_dim,
-                            x.exponent_bit_width,
-                            x.mantissa_bit_width,
-                            x.exponent_bias,
-                            x.saturating,
-                            x.inf_values,
-                            x.nan_values,
-                            x.signed,
-                            self.training)
-                    else:
-                        out = y
-                else:
-                    if isinstance(y, tuple):
-                        y = y[0]
-                    out = y
+    def create_quant_tensor(self, qt_args, x=None):
+        if x is None:
+            value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
+            out = GroupwiseFloatQuantTensor(
+                value,
+                scale,
+                zero_point,
+                self.group_size,
+                self.group_dim,
+                exponent_bit_width,
+                mantissa_bit_width,
+                exponent_bias,
+                saturating,
+                inf_values,
+                nan_values,
+                signed=self.is_signed,
+                training=self.training)
         else:
-            # If fused activation quant proxy is not enabled, return the input
-            out = x
-        if not self.training and self.cache_inference_quant_act and isinstance(
-                out, GroupwiseFloatQuantTensor):
-            cached_out = _CachedIOGroupwiseFloat(out.detach(), self.cache_quant_io_metadata_only)
-            self._cached_act = cached_out
+            out = GroupwiseFloatQuantTensor(
+                qt_args,
+                x.scale,
+                x.zero_point,
+                self.group_size,
+                self.group_dim,
+                x.exponent_bit_width,
+                x.mantissa_bit_width,
+                x.exponent_bias,
+                x.saturating,
+                x.inf_values,
+                x.nan_values,
+                x.signed,
+                self.training)
         return out
diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py
index f41862660..fca5dd82d 100644
--- a/src/brevitas/proxy/groupwise_int_parameter_quant.py
+++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -1,10 +1,11 @@
-from typing import Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 from torch import Tensor
 
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
-from brevitas.quant_tensor import GroupwiseIntQuantTensor, _unpack_quant_tensor
+from brevitas.quant_tensor import _unpack_quant_tensor
+from brevitas.quant_tensor import GroupwiseIntQuantTensor
 from brevitas.utils.quant_utils import _CachedIOGroupwiseInt
@@ -23,28 +24,14 @@ def group_dim(self):
     def group_size(self):
         return self.quant_injector.group_size
 
-    def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]:
-        if self.is_quant_enabled:
-            if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
-                out = self._cached_weight.quant_tensor
-                if torch.compiler.is_compiling():
-                    out = _unpack_quant_tensor(out)
-            else:
-                impl = self.export_handler if self.export_mode else self.tensor_quant
-                x = self.view_impl(x)
-                out, scale, zero_point, bit_width = impl(x)
-                if torch.compiler.is_compiling():
-                    out = GroupwiseIntQuantTensor(
-                        out,
-                        scale,
-                        zero_point,
-                        self.group_size,
-                        self.group_dim,
-                        bit_width,
-                        self.is_signed,
-                        self.training)
-                if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
-                    self._cached_weight = _CachedIOGroupwiseInt(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
-        else:  # quantization disabled
-            out = x
-        return out
+    def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseIntQuantTensor]:
+        out, scale, zero_point, bit_width = qt_args
+        return GroupwiseIntQuantTensor(
+            out,
+            scale,
+            zero_point,
+            self.group_size,
+            self.group_dim,
+            bit_width,
+            self.is_signed,
+            self.training)
diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py
index 959a799de..456b9228a 100644
--- a/src/brevitas/proxy/groupwise_int_runtime_quant.py
+++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -1,16 +1,20 @@
 from typing import Union
 
+import torch
 from torch import Tensor
 
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.quant_tensor import GroupwiseIntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIOGroupwiseInt
-import torch
 
 
 class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector):
 
+    def __init__(self, quant_layer, quant_injector):
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOGroupwiseInt
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
@@ -19,61 +23,26 @@ def group_dim(self):
     def group_size(self):
         return self.quant_injector.group_size
 
-    def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseIntQuantTensor]:
-        out = x
-        if self.fused_activation_quant_proxy is not None:
-            y = x
-            if isinstance(y, QuantTensor):
-                y = y.value
-
-            if self.export_mode:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-                y = self.export_handler(y)
-            elif not self.is_quant_enabled:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-            else:
-                y = self.fused_activation_quant_proxy(y)
-            if torch.compiler.is_compiling():
-                y = y[0]
-            else:
-                # If y is an empty GroupwiseIntQuantTensor, we need to check if this is a passthrough proxy,
-                # otherwise return a simple Tensor
-                # We exclude the last two values (inf_values and nan_values)
-                if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
-                    value, scale, zero_point, bit_width, = y
-                    out = GroupwiseIntQuantTensor(
-                        value,
-                        scale,
-                        zero_point,
-                        self.group_size,
-                        self.group_dim,
-                        bit_width,
-                        signed=self.is_signed,
-                        training=self.training)
-                elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
-                    if isinstance(y, tuple):
-                        y = y[0]
-                    if isinstance(x, GroupwiseIntQuantTensor):
-                        out = GroupwiseIntQuantTensor(
-                            y,
-                            x.scale,
-                            x.zero_point,
-                            self.group_size,
-                            self.group_dim,
-                            x.bit_width,
-                            x.signed,
-                            self.training)
-                    else:
-                        out = y
-                else:
-                    if isinstance(y, tuple):
-                        y = y[0]
-                    out = y
+    def create_quant_tensor(self, qt_args, x=None):
+        if x is None:
+            value, scale, zero_point, bit_width = qt_args
+            out = GroupwiseIntQuantTensor(
+                value,
+                scale,
+                zero_point,
+                self.group_size,
+                self.group_dim,
+                bit_width,
+                self.is_signed,
+                self.training)
         else:
-            # If fused activation quant proxy is not enabled, return the input
-            out = x
-        if not self.training and self.cache_inference_quant_act and isinstance(
-                out, GroupwiseIntQuantTensor):
-            cached_out = _CachedIOGroupwiseInt(out.detach(), self.cache_quant_io_metadata_only)
-            self._cached_act = cached_out
+            out = GroupwiseIntQuantTensor(
+                qt_args,
+                x.scale,
+                x.zero_point,
+                self.group_size,
+                self.group_dim,
+                x.bit_width,
+                x.signed,
+                self.training)
         return out
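The `cache_class` attribute set in the `__init__` methods above follows a simple pattern: shared base-class code only ever instantiates `self.cache_class`, and each subclass picks the cached-IO wrapper that matches its tensor flavour (`_CachedIO`, `_CachedIOGroupwiseInt`, and so on). A toy sketch of the idea, with stand-in classes rather than Brevitas code:

```python
class FakeCachedIO:  # stand-in for _CachedIO / _CachedIOGroupwiseInt
    def __init__(self, tensor, metadata_only=False):
        self.tensor = tensor
        self.metadata_only = metadata_only


class BaseProxy:
    cache_class = None

    def cache(self, out):
        # Base-class code stays generic; the subclass decides the wrapper type.
        return self.cache_class(out, metadata_only=True)


class GroupwiseProxy(BaseProxy):
    def __init__(self):
        self.cache_class = FakeCachedIO


print(GroupwiseProxy().cache(object()).metadata_only)  # True
```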
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 6bd423170..20395878a 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -4,9 +4,18 @@ from abc import ABC
 from abc import ABCMeta
 from abc import abstractmethod
-from typing import Optional, Union
+from typing import Any, List, Optional, Union
 from warnings import warn
 
+import packaging.version
 import torch
 from torch import Tensor
 import torch.nn as nn
+
+from brevitas import torch_version
+
+if torch_version < packaging.version.parse('2.0'):
+    is_dynamo_compiling = lambda: False
+else:
+    is_dynamo_compiling = torch._dynamo.is_compiling
@@ -16,8 +25,9 @@
 from brevitas import config
 from brevitas.function import max_int
 from brevitas.inject import BaseInjector as Injector
+from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import IntQuantTensor
-from brevitas.quant_tensor import QuantTensor, _unpack_quant_tensor
+from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
 from brevitas.utils.torch_utils import compute_channel_view_shape
@@ -87,8 +97,20 @@ class WeightQuantProxyFromInjectorBase(ParameterQuantProxyFromInjector,
     def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         super().__init__(quant_layer, quant_injector)
         self._cached_weight = None
-        self.cache_inference_quant_weight = False
+        self._cache_inference_quant_weight = False
         self.cache_inference_quant_weight_metadata_only = False
+        self.cache_class = None  # To be redefined by each class
+        self.quant_tensor_class = None  # To be redefined by each class
+
+    @property
+    def cache_inference_quant_weight(self):
+        return self._cache_inference_quant_weight
+
+    @cache_inference_quant_weight.setter
+    def cache_inference_quant_weight(self, value):
+        if not value:
+            self._cached_weight = None
+        self._cache_inference_quant_weight = value
 
     @property
     def tracked_parameter_list(self):
@@ -98,6 +120,32 @@ def tracked_parameter_list(self):
     def requires_quant_input(self):
         return False
 
+    @abstractmethod
+    def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]:
+        raise NotImplementedError
+
+    def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
+        if self.is_quant_enabled:
+            if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
+                out = self._cached_weight.quant_tensor
+                # Test this
+                if is_dynamo_compiling():
+                    out = _unpack_quant_tensor(out)
+            else:
+                impl = self.export_handler if self.export_mode else self.tensor_quant
+                out = impl(x)
+                if not is_dynamo_compiling():
+                    out = self.create_quant_tensor(out)
+                    if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
+                        self._cached_weight = self.cache_class(
+                            out.detach(),
+                            metadata_only=self.cache_inference_quant_weight_metadata_only)
+                else:
+                    out = out[0]
+        else:  # quantization disabled
+            out = x
+        return out
+
 
 class BiasQuantProxyFromInjectorBase(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol, ABC):
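The `is_dynamo_compiling` shim above replaces direct `torch.compiler.is_compiling()` calls so the proxies keep working on pre-2.0 PyTorch. A standalone sketch of the same idea (using `torch.__version__` directly, whereas the diff goes through `brevitas.torch_version`):

```python
import packaging.version
import torch

if packaging.version.parse(torch.__version__).release < (2, 0):
    is_dynamo_compiling = lambda: False
else:
    is_dynamo_compiling = torch._dynamo.is_compiling

print(is_dynamo_compiling())  # False when running eagerly
```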
mode.") @@ -125,6 +174,9 @@ def get_cached(self, attr): class WeightQuantProxyFromInjector(WeightQuantProxyFromInjectorBase): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIO @property def tracked_parameter_list(self): @@ -152,23 +204,8 @@ def bit_width(self): bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width return bit_width - def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: - if self.is_quant_enabled: - if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: - out = self._cached_weight.quant_tensor - if torch.compiler.is_compiling(): - out = _unpack_quant_tensor(out) - else: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, bit_width = impl(x) - if not torch.compiler.is_compiling(): - out = IntQuantTensor( - out, scale, zero_point, bit_width, self.is_signed, self.training) - if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: - self._cached_weight = _CachedIO(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) - else: # quantization disabled - out = x - return out + def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]: + return IntQuantTensor(*qt_args, self.is_signed, self.training) class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): @@ -187,22 +224,9 @@ def pre_zero_point(self): out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_zero_point - def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: - if self.is_quant_enabled: - if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: - out = self._cached_weight.quant_tensor - if torch.compiler.is_compiling(): - out = _unpack_quant_tensor(out) - else: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x) - if not torch.compiler.is_compiling(): - out = IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) - if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: - self._cached_weight = _CachedIO(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) - else: # quantization disabled - out = x - return out + def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]: + out, scale, zero_point, bit_width, pre_scale, pre_zero_point = qt_args + return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector): @@ -326,11 +350,13 @@ def forward( out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) - if not torch.compiler.is_compiling(): + if not is_dynamo_compiling(): out = IntQuantTensor( out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) if not self.training and self.cache_inference_quant_bias and self._cached_bias is not None: - cached_bias = _CachedIO(out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) + cached_bias = _CachedIO( + out.detach(), + metadata_only=self.cache_inference_quant_bias_metadata_only) self._cached_bias = cached_bias else: out = x diff --git a/src/brevitas/proxy/quant_proxy.py 
diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py
index 857748732..1b847b280 100644
--- a/src/brevitas/proxy/quant_proxy.py
+++ b/src/brevitas/proxy/quant_proxy.py
@@ -104,7 +104,6 @@ def init_tensor_quant(self):
     def is_quant_enabled(self):
         return not self.disable_quant and self.tensor_quant is not None
 
-
     @property
     def is_groupwise(self):
         return _is_groupwise(self.quant_injector)
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index a8e4b1137..512dde45e 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -2,9 +2,19 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from abc import ABC
+from abc import abstractmethod
 from typing import Optional, Tuple, Union
 
+import packaging.version
 import torch
+
+from brevitas import torch_version
+
+if torch_version < packaging.version.parse('2.0'):
+    is_dynamo_compiling = lambda: False
+else:
+    is_dynamo_compiling = torch._dynamo.is_compiling
+
 from torch import nn
 from torch import Tensor
 from torch.nn import Identity
@@ -18,6 +28,7 @@
 from .quant_proxy import QuantProxyFromInjector
 from .quant_proxy import QuantProxyProtocol
 
+
 __all__ = [
     'ActQuantProxyProtocol',
     'AccQuantProxyProtocol',
@@ -94,6 +105,7 @@ def __init__(self, quant_layer, quant_injector):
         self._cached_act = None
         self.cache_inference_quant_act = False
         self.cache_quant_io_metadata_only = True
+        self.cache_class = None
 
     def internal_forward(self, force_eval):
         current_status = self.training
@@ -139,9 +151,55 @@ def init_tensor_quant(self):
         else:
             self.fused_activation_quant_proxy = None
 
+    @abstractmethod
+    def create_quant_tensor(self, qt_args, x=None):
+        raise NotImplementedError
+
+    def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
+        # If fused activation quant proxy is not enabled, return the input
+        if self.fused_activation_quant_proxy is None:
+            return x
+
+        y = x
+        if isinstance(y, QuantTensor):
+            y = y.value
+
+        if self.export_mode:
+            y = self.fused_activation_quant_proxy.activation_impl(y)
+            y = self.export_handler(y)
+        elif not self.is_quant_enabled:
+            # A tuple helps later with control flows
+            # The second None value is used later
+            y = (self.fused_activation_quant_proxy.activation_impl(y), None)
+        else:
+            y = self.fused_activation_quant_proxy(y)
+        # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
+        # otherwise return a simple Tensor
+        if is_dynamo_compiling():
+            out = y[0]
+        else:
+            # If the second value (i.e., scale) is None, then quant is disabled
+            if isinstance(y, tuple) and y[1] is not None:
+                out = self.create_quant_tensor(y)
+            elif self.is_passthrough_act and isinstance(x, QuantTensor):
+                # preserve scale/zp/bit/sign even without output quant
+                y = y[0]
+                out = self.create_quant_tensor(y, x=x)
+            else:
+                out = y[0]
+
+        if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
+            cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only)
+            self._cached_act = cached_out
+        return out
+
 
 class ActQuantProxyFromInjector(ActQuantProxyFromInjectorBase):
 
+    def __init__(self, quant_layer, quant_injector):
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIO
+
     def scale(self, force_eval=True):
         return self.retrieve_attribute('scale', force_eval)
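A small behavioural sketch of the dispatch in the new shared `forward` above, whose convention is that the fused proxy returns a tuple whose second element (the scale) is `None` when quantization is disabled (stand-in objects, not Brevitas code):

```python
def dispatch(y, passthrough=False, x=None):
    # Mirrors the branch structure of ActQuantProxyFromInjectorBase.forward above.
    if isinstance(y, tuple) and y[1] is not None:
        return ("wrap as fresh QuantTensor", y)
    if passthrough and x is not None:
        return ("reuse metadata from input", y[0], x)
    return y[0]


print(dispatch((0.5, 0.1, 0.0, 8)))           # quant enabled -> wrapped
print(dispatch((0.5, None)))                  # quant disabled -> bare value 0.5
print(dispatch((0.5, None), True, object()))  # passthrough path
```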
@@ -151,41 +209,12 @@ def zero_point(self, force_eval=True):
         return self.retrieve_attribute('zero_point', force_eval)
 
     def bit_width(self, force_eval=True):
         return self.retrieve_attribute('bit_width', force_eval)
 
-    def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, IntQuantTensor]:
-        out = x
-        if self.fused_activation_quant_proxy is not None:
-            y = x
-            if isinstance(y, QuantTensor):
-                y = y.value
-
-            if self.export_mode:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-                y = self.export_handler(y)
-            elif not self.is_quant_enabled:
-                y = self.fused_activation_quant_proxy.activation_impl(y)
-            else:
-                y = self.fused_activation_quant_proxy(y)
-            # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
-            # otherwise return a simple Tensor
-            if torch.compiler.is_compiling():
-                out = y[0]
-            else:
-                if isinstance(y, tuple) and not any(map(lambda f: f is None, y)):
-                    out = IntQuantTensor(*y, signed=self.is_signed, training=self.training)
-                elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
-                    if isinstance(y, tuple):
-                        y = y[0]
-                    if isinstance(x, IntQuantTensor):
-                        out = IntQuantTensor(
-                            y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
-                    else:
-                        out = y
+    def create_quant_tensor(self, qt_args, x=None):
+        if x is None:
+            out = IntQuantTensor(*qt_args, self.is_signed, self.training)
         else:
-            # If fused activation quant proxy is not enabled, return the input
-            out = x
-        if not self.training and self.cache_inference_quant_act and isinstance(out, IntQuantTensor):
-            cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only)
-            self._cached_act = cached_out
+            out = IntQuantTensor(
+                qt_args, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
         return out
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 5b2dece46..8fd2f655d 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -7,7 +7,6 @@
 import random
 import warnings
 
-from brevitas.graph.calibrate import inference_mode
 import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
@@ -19,6 +18,7 @@
 from brevitas.export import export_onnx_qcdq
 from brevitas.export import export_torch_qcdq
+from brevitas.graph.calibrate import inference_mode
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.quantize import preprocess_for_quantize
 from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
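Finally, an end-to-end sketch of the `torch.compile` path these proxy changes are aimed at (assumes PyTorch >= 2.0 and a Brevitas quantized layer; whether tracing fully succeeds depends on the installed torch and Brevitas versions):

```python
import torch
from brevitas.nn import QuantLinear

layer = QuantLinear(16, 8, bias=False).eval()
compiled = torch.compile(layer)

with torch.no_grad():
    out = compiled(torch.randn(4, 16))
# Under dynamo the proxies return plain tensors instead of QuantTensors,
# which is what the is_dynamo_compiling() branches above are for.
print(out.shape)  # torch.Size([4, 8])
```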