diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 92228b7a3..f0b669903 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -105,6 +105,46 @@ def __exit__(self, type, value, traceback): self.model, is_training=self.previous_training_state, quantization_enabled=True) restore_return_quant_tensor(self.model, self.return_quant_tensor_state) +def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=True): + if hasattr(m, 'cache_inference_quant_bias'): + if not hasattr(m, "cache_inference_quant_bias_backup"): + m.cache_inference_quant_bias_backup = m.cache_inference_quant_bias + m.cache_inference_quant_bias = enabled + m.cache_inference_quant_bias_metadata_only = metadata_only + + +def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=True): + if hasattr(m, 'cache_inference_quant_act'): + if not hasattr(m, "cache_inference_quant_act_backup"): + m.cache_inference_quant_act_backup = m.cache_inference_quant_act + m.cache_inference_quant_act = enabled + m.cache_inference_quant_act_metadata_only = metadata_only + +def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=False): + if hasattr(m, 'cache_inference_quant_weight'): + if not hasattr(m, "cache_inference_quant_weight_backup"): + m.cache_inference_quant_weight_backup = m.cache_inference_quant_weight + m.cache_inference_quant_weight = enabled + m.cache_inference_quant_weight_metadata_only = metadata_only + + +class inference_mode: + def __init__(self, model, cache_quant_weight=False, enabled=True): + self.model = model + self.enabled=enabled + self.cache_quant_weight = cache_quant_weight + + def __enter__(self): + if self.enabled: + self.model.apply(lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True)) + self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True)) + if self.cache_quant_weight: + self.model.apply(lambda m: _override_weight_caching_mode(m, enabled=True)) + + + def __exit__(self, type, value, traceback): + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) + class bias_correction_mode: diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 167852508..f4266e284 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -69,8 +69,6 @@ def __init__(self, return_quant_tensor: bool): def channelwise_separable(self) -> bool: pass - def _set_global_is_quant_layer(self, value): - config._IS_INSIDE_QUANT_LAYER = value def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]): quant_tensor_classes = [ @@ -81,7 +79,6 @@ def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]): return None def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - self._set_global_is_quant_layer(True) # Hack to recognize a QuantTensor that has decayed to a tuple # when used as input to tracing (e.g. 
during ONNX export) if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and @@ -89,7 +86,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe qt_class = self.get_quant_tensor_class(inp) if qt_class is not None: inp = qt_class(*inp) - if not torch._C._get_tracing_state(): + if not torch._C._get_tracing_state() and not torch.compiler.is_compiling(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) else: @@ -97,7 +94,6 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe return inp def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - self._set_global_is_quant_layer(False) if self.return_quant_tensor: assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised' return quant_output diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 68038fa20..9d4efb223 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -1,5 +1,5 @@ from typing import Optional, Union -from warnings import warn +from brevitas.quant_tensor import _unpack_quant_tensor import torch from torch import Tensor @@ -84,46 +84,36 @@ def is_fnuz(self): ) is None and self.exponent_bias() == 16 return is_fnuz_e4m3 or is_fnuz_e5m2 - def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: - if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) - return FloatQuantTensor( - out, - scale, - zero_point, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - self.is_signed, - self.training) - else: # quantization disabled - return x - class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) - return FloatQuantTensor( - out, - scale, - zero_point, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - self.is_signed, - self.training) + if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: + out = self._cached_weight.quant_tensor + if torch.compiler.is_compiling(): + out = _unpack_quant_tensor(out) + else: + impl = self.export_handler if self.export_mode else self.tensor_quant + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) + if not torch.compiler.is_compiling(): + out = FloatQuantTensor( + out, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = _CachedIOFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - return x + out = x + return out class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): diff --git 
a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 021aefd12..fba9006e6 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -74,33 +74,36 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTens y = self.fused_activation_quant_proxy.activation_impl(y) else: y = self.fused_activation_quant_proxy(y) - # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - # We exclude the last two values (inf_values and nan_values) - if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, FloatQuantTensor): - out = FloatQuantTensor( - y, - x.scale, - x.zero_point, - x.exponent_bit_width, - x.mantissa_bit_width, - x.exponent_bias, - x.saturating, - x.inf_values, - x.nan_values, - x.signed, - self.training) + if torch.compiler.is_compiling(): + y = y[0] + else: + # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor + # We exclude the last two values (inf_values and nan_values) + if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): + out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, FloatQuantTensor): + out = FloatQuantTensor( + y, + x.scale, + x.zero_point, + x.exponent_bit_width, + x.mantissa_bit_width, + x.exponent_bias, + x.saturating, + x.inf_values, + x.nan_values, + x.signed, + self.training) + else: + out = y else: + if isinstance(y, tuple): + y = y[0] out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y else: # If fused activation quant proxy is not enabled, return the input out = x diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index cd38d9906..5e3419fcd 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -4,7 +4,8 @@ from torch import Tensor from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase -from brevitas.quant_tensor import GroupwiseFloatQuantTensor +from brevitas.quant_tensor import GroupwiseFloatQuantTensor, _unpack_quant_tensor +from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): @@ -24,22 +25,31 @@ def group_size(self): def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]: if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - x = self.view_impl(x) - out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) - return GroupwiseFloatQuantTensor( - out, - scale, - zero_point, - self.group_size, - self.group_dim, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - self.is_signed, - self.training) + if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: + out = self._cached_weight.quant_tensor + if 
torch.compiler.is_compiling(): + out = _unpack_quant_tensor(out) + else: + impl = self.export_handler if self.export_mode else self.tensor_quant + x = self.view_impl(x) + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) + if not torch.compiler.is_compiling(): + out = GroupwiseFloatQuantTensor( + out, + scale, + zero_point, + self.group_size, + self.group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = _CachedIOGroupwiseFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - return x + out = x + return out diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 4ab182d20..7514471f9 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -6,6 +6,7 @@ from brevitas.quant_tensor import GroupwiseFloatQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat +import torch class GroupwiseActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase): @@ -35,46 +36,49 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloat # If y is an empty GroupwiseFloatQuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor # We exclude the last two values (inf_values and nan_values) - if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y - out = GroupwiseFloatQuantTensor( - value, - scale, - zero_point, - self.group_size, - self.group_dim, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - signed=self.is_signed, - training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, GroupwiseFloatQuantTensor): + if torch.compiler.is_compiling(): + y = y[0] + else: + if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): + value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y out = GroupwiseFloatQuantTensor( - y, - x.scale, - x.zero_point, + value, + scale, + zero_point, self.group_size, self.group_dim, - x.exponent_bit_width, - x.mantissa_bit_width, - x.exponent_bias, - x.saturating, - x.inf_values, - x.nan_values, - x.signed, - self.training) + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + signed=self.is_signed, + training=self.training) + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, GroupwiseFloatQuantTensor): + out = GroupwiseFloatQuantTensor( + y, + x.scale, + x.zero_point, + self.group_size, + self.group_dim, + x.exponent_bit_width, + x.mantissa_bit_width, + x.exponent_bias, + x.saturating, + x.inf_values, + x.nan_values, + x.signed, + self.training) + else: + out = y else: + if isinstance(y, tuple): + y = y[0] out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y 
else: # If fused activation quant proxy is not enabled, return the input out = x diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 035ee9729..f41862660 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -4,7 +4,8 @@ from torch import Tensor from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector -from brevitas.quant_tensor import GroupwiseIntQuantTensor +from brevitas.quant_tensor import GroupwiseIntQuantTensor, _unpack_quant_tensor +from brevitas.utils.quant_utils import _CachedIOGroupwiseInt class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): @@ -24,17 +25,26 @@ def group_size(self): def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]: if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - x = self.view_impl(x) - out, scale, zero_point, bit_width = impl(x) - return GroupwiseIntQuantTensor( - out, - scale, - zero_point, - self.group_size, - self.group_dim, - bit_width, - self.is_signed, - self.training) + if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: + out = self._cached_weight.quant_tensor + if torch.compiler.is_compiling(): + out = _unpack_quant_tensor(out) + else: + impl = self.export_handler if self.export_mode else self.tensor_quant + x = self.view_impl(x) + out, scale, zero_point, bit_width = impl(x) + if not torch.compiler.is_compiling(): + out = GroupwiseIntQuantTensor( + out, + scale, + zero_point, + self.group_size, + self.group_dim, + bit_width, + self.is_signed, + self.training) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = _CachedIOGroupwiseInt(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - return x + out = x + return out diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index e9788e89b..959a799de 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -6,6 +6,7 @@ from brevitas.quant_tensor import GroupwiseIntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOGroupwiseInt +import torch class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector): @@ -32,39 +33,42 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseIntQu y = self.fused_activation_quant_proxy.activation_impl(y) else: y = self.fused_activation_quant_proxy(y) - # If y is an empty GroupwiseIntQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - # We exclude the last two values (inf_values and nan_values) - if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - value, scale, zero_point, bit_width, = y - out = GroupwiseIntQuantTensor( - value, - scale, - zero_point, - self.group_size, - self.group_dim, - bit_width, - signed=self.is_signed, - training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, GroupwiseIntQuantTensor): + if torch.compiler.is_compiling(): + y = y[0] + else: + # If y is an empty GroupwiseIntQuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple 
Tensor + # We exclude the last two values (inf_values and nan_values) + if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): + value, scale, zero_point, bit_width, = y out = GroupwiseIntQuantTensor( - y, - x.scale, - x.zero_point, + value, + scale, + zero_point, self.group_size, self.group_dim, - x.bit_width, - x.signed, - self.training) + bit_width, + signed=self.is_signed, + training=self.training) + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, GroupwiseIntQuantTensor): + out = GroupwiseIntQuantTensor( + y, + x.scale, + x.zero_point, + self.group_size, + self.group_dim, + x.bit_width, + x.signed, + self.training) + else: + out = y else: + if isinstance(y, tuple): + y = y[0] out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y else: # If fused activation quant proxy is not enabled, return the input out = x diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index fc4e75cb9..6bd423170 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -17,7 +17,7 @@ from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector from brevitas.quant_tensor import IntQuantTensor -from brevitas.quant_tensor import QuantTensor +from brevitas.quant_tensor import QuantTensor, _unpack_quant_tensor from brevitas.utils.quant_utils import _CachedIO from brevitas.utils.torch_utils import compute_channel_view_shape @@ -84,6 +84,12 @@ class WeightQuantProxyFromInjectorBase(ParameterQuantProxyFromInjector, WeightQuantProxyProtocol, ABC): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self._cached_weight = None + self.cache_inference_quant_weight = False + self.cache_inference_quant_weight_metadata_only = False + @property def tracked_parameter_list(self): return [m.weight for m in self.tracked_module_list if m.weight is not None] @@ -99,23 +105,18 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: super().__init__(quant_layer, quant_injector) self._cached_bias = None self.cache_inference_quant_bias = False - + self.cache_inference_quant_bias_metadata_only = False + self.requires_input_scale = self.quant_injector.requires_input_scale and self.is_quant_enabled + @property def tracked_parameter_list(self): return [m.bias for m in self.tracked_module_list if m.bias is not None] - @property - def requires_input_scale(self) -> bool: - if self.is_quant_enabled: - return self.quant_injector.requires_input_scale - else: - return False - def get_cached(self, attr): if self._cached_bias is None: - warn( - "No quant bias cache found, set cache_inference_quant_bias=True and run an " - "inference pass first") + # warn( + # "No quant bias cache found, set cache_inference_quant_bias=True and run an " + # "inference pass first") return None if self.training: warn("Cached quant bias scale is being used in training mode.") @@ -124,10 +125,6 @@ def get_cached(self, attr): class WeightQuantProxyFromInjector(WeightQuantProxyFromInjectorBase): - def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: - super().__init__(quant_layer, quant_injector) - self._cached_weight = None - self.cache_inference_quant_weight = False @property def tracked_parameter_list(self): @@ -157,19 +154,20 @@ def bit_width(self): def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: if 
self.is_quant_enabled: - if self._cached_weight is not None: + if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: out = self._cached_weight.quant_tensor + if torch.compiler.is_compiling(): + out = _unpack_quant_tensor(out) else: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width = impl(x) - out = IntQuantTensor( - out, scale, zero_point, bit_width, self.is_signed, self.training) + if not torch.compiler.is_compiling(): + out = IntQuantTensor( + out, scale, zero_point, bit_width, self.is_signed, self.training) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = _CachedIO(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled out = x - if isinstance( - out, IntQuantTensor - ) and not self.training and self.cache_inference_quant_weight and self._cached_weight is None: - self._cached_weight = _CachedIO(out.detach(), metadata_only=False) return out @@ -191,11 +189,20 @@ def pre_zero_point(self): def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x) - return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) + if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: + out = self._cached_weight.quant_tensor + if torch.compiler.is_compiling(): + out = _unpack_quant_tensor(out) + else: + impl = self.export_handler if self.export_mode else self.tensor_quant + out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x) + if not torch.compiler.is_compiling(): + out = IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = _CachedIO(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - return x + out = x + return out class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector): @@ -304,24 +311,27 @@ def forward( out = x input_scale = self.compute_bias_scale(input, weight) if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - if self.requires_input_scale and input_scale is None: - input_scale = self.scale() - if input_scale is None: - raise RuntimeError("Input scale required") - - if self.requires_input_scale: - input_scale = input_scale.view(-1) - out, out_scale, out_zp, out_bit_width = impl(x, input_scale) + if self._cached_bias is not None and not self.cache_inference_quant_bias_metadata_only: + out = self._cached_bias.value else: - out, out_scale, out_zp, out_bit_width = impl(x) - - out = IntQuantTensor( - out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + impl = self.export_handler if self.export_mode else self.tensor_quant + if self.requires_input_scale and input_scale is None: + input_scale = self.scale() + if input_scale is None: + raise RuntimeError("Input scale required") + elif self.requires_input_scale and input_scale is not None: + input_scale = input_scale.view(-1) + + if self.requires_input_scale: + out, out_scale, out_zp, out_bit_width = impl(x, input_scale) + else: + out, out_scale, out_zp, out_bit_width = impl(x) + if not 
torch.compiler.is_compiling(): + out = IntQuantTensor( + out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + if not self.training and self.cache_inference_quant_bias and self._cached_bias is None: + cached_bias = _CachedIO(out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) + self._cached_bias = cached_bias else: out = x - if isinstance(out, - IntQuantTensor) and not self.training and self.cache_inference_quant_bias: - cached_bias = _CachedIO(out.detach(), metadata_only=False) - self._cached_bias = cached_bias return out diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 3a680035e..857748732 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -28,12 +28,6 @@ def _is_groupwise(quant_injector): return False -def _is_signed(quant_injector): - if 'signed' in quant_injector: - return quant_injector.signed - return None - - def _is_narrow_range(quant_injector): if 'narrow_range' in quant_injector: return quant_injector.narrow_range @@ -88,6 +82,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.tracked_module_list = [] self.add_tracked_module(quant_layer) self.disable_quant = False + # Torch.compile compatibility requires this + self.is_signed = quant_injector.signed if 'signed' in quant_injector else None @property def requires_export_handler(self): @@ -108,9 +104,6 @@ def init_tensor_quant(self): def is_quant_enabled(self): return not self.disable_quant and self.tensor_quant is not None - @property - def is_signed(self): - return _is_signed(self.quant_injector) @property def is_groupwise(self): diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index a89bc9abb..a8e4b1137 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -4,6 +4,7 @@ from abc import ABC from typing import Optional, Tuple, Union +import torch from torch import nn from torch import Tensor from torch.nn import Identity @@ -17,7 +18,6 @@ from .quant_proxy import QuantProxyFromInjector from .quant_proxy import QuantProxyProtocol - __all__ = [ 'ActQuantProxyProtocol', 'AccQuantProxyProtocol', @@ -116,12 +116,6 @@ def retrieve_attribute(self, attribute, force_eval): def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant - @property - def is_signed(self): - if self._cached_act is not None: - return self._cached_act.signed - return super().is_signed - @is_quant_enabled.setter def is_quant_enabled(self, is_quant_enabled): self._is_quant_enabled = is_quant_enabled @@ -173,20 +167,23 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, IntQuantTensor y = self.fused_activation_quant_proxy(y) # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): - out = IntQuantTensor(*y, signed=self.is_signed, training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, IntQuantTensor): - out = IntQuantTensor( - y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) - else: - out = y + if torch.compiler.is_compiling(): + out = y[0] else: - if isinstance(y, tuple): - y = y[0] - out = y + if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): + out = IntQuantTensor(*y, signed=self.is_signed, training=self.training) + elif 
self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, IntQuantTensor): + out = IntQuantTensor( + y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + else: + out = y + else: + if isinstance(y, tuple): + y = y[0] + out = y else: # If fused activation quant proxy is not enabled, return the input out = x diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index f160877a0..6fd519b41 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -16,13 +16,14 @@ class _CachedIO: def __init__(self, quant_tensor: IntQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): @@ -42,13 +43,14 @@ class _CachedIOFloat: def __init__(self, quant_tensor: FloatQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): @@ -88,13 +90,14 @@ class _CachedIOGroupwiseFloat: def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): @@ -142,13 +145,14 @@ class _CachedIOGroupwiseInt: def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 7e2bf6ee5..5b2dece46 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -7,6 +7,7 @@ import random import warnings +from brevitas.graph.calibrate import inference_mode import numpy as np import torch import torch.backends.cudnn as cudnn @@ -365,6 +366,7 @@ def main(): # Get the model from torchvision model = get_torchvision_model(args.model_name) model = model.to(dtype) + model.eval() # Preprocess the model for quantization if args.target_backend == 'flexml':
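
Usage sketch (illustrative only, not part of the diff above): the changes add caching-mode override helpers and an inference_mode context manager in brevitas.graph.calibrate, and guard QuantTensor construction in the proxies with torch.compiler.is_compiling() so that compiled graphs see plain tensors while reusing metadata cached during an eager pass. The snippet below shows how these pieces are expected to fit together; model, calib_batch and test_batch are placeholder names, and the exact wiring inside ptq_evaluate.py is not shown in this diff.

import torch
from brevitas.graph.calibrate import inference_mode

model.eval()  # the proxies only cache quant metadata outside training mode

# One eager forward pass inside inference_mode populates the cached bias/act
# metadata (and, optionally, the fully quantized weights).
with torch.no_grad(), inference_mode(model, cache_quant_weight=True):
    model(calib_batch)

# Under torch.compile the proxies skip QuantTensor construction and return
# plain tensors, reusing the caches populated above.
compiled_model = torch.compile(model)
with torch.no_grad():
    out = compiled_model(test_batch)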