From b28ac0faa7826fa6155dcc72bc4ee3725b5ef8e4 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 00:10:32 +0800 Subject: [PATCH] Feat (torch.compile): prototype support for compiled inference (#1006) --- noxfile.py | 2 +- src/brevitas/__init__.py | 6 + src/brevitas/core/function_wrapper/clamp.py | 49 +++--- src/brevitas/export/common/handler/qcdq.py | 4 +- src/brevitas/export/inference/__init__.py | 5 + src/brevitas/export/inference/handler.py | 153 ++++++++++++++++++ src/brevitas/export/inference/manager.py | 106 ++++++++++++ src/brevitas/export/manager.py | 8 +- src/brevitas/graph/calibrate.py | 2 +- src/brevitas/nn/mixin/base.py | 6 +- src/brevitas/proxy/float_parameter_quant.py | 8 + src/brevitas/proxy/float_runtime_quant.py | 10 +- .../proxy/groupwise_int_runtime_quant.py | 2 +- src/brevitas/proxy/parameter_quant.py | 54 ++++--- src/brevitas/proxy/quant_proxy.py | 2 +- src/brevitas/proxy/runtime_quant.py | 37 +++-- .../imagenet_classification/ptq/ptq_common.py | 3 +- .../ptq/ptq_evaluate.py | 30 +++- .../imagenet_classification/utils.py | 2 - .../test_torchvision_models.py | 50 +++++- 20 files changed, 461 insertions(+), 78 deletions(-) create mode 100644 src/brevitas/export/inference/__init__.py create mode 100644 src/brevitas/export/inference/handler.py create mode 100644 src/brevitas/export/inference/manager.py diff --git a/noxfile.py b/noxfile.py index 8aad90528..0d6680460 100644 --- a/noxfile.py +++ b/noxfile.py @@ -211,4 +211,4 @@ def tests_brevitas_end_to_end(session, pytorch): install_pytorch(pytorch, session) install_torchvision(pytorch, session) session.install('--upgrade', '-e', '.[test, ort_integration]') - session.run('pytest', '-v', 'tests/brevitas_end_to_end') + session.run('pytest', '-n', 'logical', '-v', 'tests/brevitas_end_to_end') diff --git a/src/brevitas/__init__.py b/src/brevitas/__init__.py index eddc35a02..fe46102a7 100644 --- a/src/brevitas/__init__.py +++ b/src/brevitas/__init__.py @@ -23,6 +23,12 @@ else: torch_version = version.parse(torch.__version__) +try: + # Attempt _dynamo import + is_dynamo_compiling = torch._dynamo.is_compiling +except: + is_dynamo_compiling = lambda: False + try: __version__ = get_distribution(__name__).version except DistributionNotFound: diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 70d1fc23f..163e63a22 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -113,6 +113,29 @@ def __init__( else: self.max_available_float = None + def inf_nan_clamp(self, x, inf_mask, p_max_val_mask, n_max_val_mask): + + # if non-saturating, we need to map values greater than max_val to nan or inf + if self.inf_values is not None: + # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf + x[p_max_val_mask] = torch.tensor(float('inf')) + x[n_max_val_mask] = torch.tensor(float('-inf')) + elif self.nan_values is not None: + # no inf values, so we need to map them to NaN + full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) + x[full_max_val_mask] = torch.tensor(float('nan')) + + # we also map the inf values to NaN in this case + x[inf_mask] = torch.tensor(float('nan')) + else: + raise RuntimeError( + "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" + ) + return x + + def saturating_clamp(self, x, max_value, min_value): + return self.tensor_clamp_impl(x, min_val=min_value, max_val=max_value) + @brevitas.jit.script_method def forward( self, @@ -120,33 +143,21 @@ def forward( exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): - inf_mask = x.isinf() + max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) + min_value = torch.tensor(0.) if not self.signed else -max_value + + # Compute masks + inf_mask = x.isinf() p_max_val_mask = x > max_value n_max_val_mask = -x > max_value - min_float = torch.tensor(0.) if not self.signed else -max_value # first clamp everything to +- max_value, basically the saturating case - x = self.tensor_clamp_impl(x, min_val=min_float, max_val=max_value) + x = self.saturating_clamp(x, max_value, min_value) if not self.saturating: - # if non-saturating, we need to map values greater than max_val to nan or inf - if self.inf_values is not None: - # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf - x[p_max_val_mask] = torch.tensor(float('inf')) - x[n_max_val_mask] = torch.tensor(float('-inf')) - elif self.nan_values is not None: - # no inf values, so we need to map them to NaN - full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) - x[full_max_val_mask] = torch.tensor(float('nan')) - - # we also map the inf values to NaN in this case - x[inf_mask] = torch.tensor(float('nan')) - else: - raise RuntimeError( - "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" - ) + x = self.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask) return x, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 44061ce42..bbc03b630 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -454,7 +454,7 @@ def prepare_for_export(self, module): self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width() self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width() self.symbolic_kwargs['exponent_bias'] = module.exponent_bias() - self.symbolic_kwargs['saturating'] = module.saturating() + self.symbolic_kwargs['saturating'] = module.is_saturating() self.symbolic_kwargs['inf_values'] = module.inf_values() self.symbolic_kwargs['nan_values'] = module.nan_values() @@ -659,7 +659,7 @@ def prepare_for_export(self, module): 'exponent_bit_width': module.exponent_bit_width(), 'mantissa_bit_width': module.mantissa_bit_width(), 'exponent_bias': module.exponent_bias(), - 'saturating': module.saturating(), + 'saturating': module.is_saturating(), 'inf_values': module.inf_values(), 'nan_values': module.nan_values()} diff --git a/src/brevitas/export/inference/__init__.py b/src/brevitas/export/inference/__init__.py new file mode 100644 index 000000000..0e6d113e0 --- /dev/null +++ b/src/brevitas/export/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from .manager import InferenceManager +from .manager import quant_inference_mode diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py new file mode 100644 index 000000000..1416014ec --- /dev/null +++ b/src/brevitas/export/inference/handler.py @@ -0,0 +1,153 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + +from brevitas.function.ops import max_float +from brevitas.function.ops import max_int +from brevitas.function.ops import min_int +from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase +from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector +from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.utils.torch_utils import float_internal_scale + + +class InferenceHandler(torch.nn.Module, ABC): + + def attach_debug_info(self, module): + pass + + @abstractmethod + def prepare_for_export(self, module): + pass + + @abstractmethod + def quantize(self, x): + pass + + @abstractmethod + def dequantize(self, x): + pass + + +class IntInferencetHandler(InferenceHandler): + handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector) + + def attach_debug_info(self, module): + pass + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.scale = module.scale() + self.zero_point = module.zero_point().to(self.scale.device) + self.bit_width = module.bit_width() + self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width) + self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width) + + def quantize(self, x): + return torch.clamp( + torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp) + + def dequantize(self, x): + return (x - self.zero_point) * self.scale + + def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width + + +class IntWeightInferencetHandler(IntInferencetHandler): + handled_layer = WeightQuantProxyFromInjector + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.cached_weight = None + super().prepare_for_export(module) + if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: + self.cached_weight = module._cached_weight.value + + def forward(self, x) -> Tuple[torch.Tensor]: + if self.cached_weight is not None: + x = self.cached_weight + else: + x = self.dequantize(self.quantize(x)) + return x, self.scale, self.zero_point, self.bit_width + + +class FloatInferencetHandler(InferenceHandler): + handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector) + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.scale = module.scale() + self.zero_point = module.zero_point().to(self.scale.device) + self.exponent_bit_width = module.exponent_bit_width() + self.mantissa_bit_width = module.mantissa_bit_width() + self.exponent_bias = module.exponent_bias() + self.saturating = module.is_saturating() + self.inf_values = module.inf_values() + self.nan_values = module.nan_values() + self.eps = torch.finfo(self.scale.dtype).tiny + if hasattr(module.tensor_quant, 'float_to_int_impl'): + self.float_to_int_impl = module.tensor_quant.float_to_int_impl + self.float_clamp_impl = module.tensor_quant.float_clamp_impl + elif hasattr(module, 'fused_activation_quant_proxy'): + self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl + self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl + + self.max_clamp = max_float( + self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) + self.min_clamp = -self.max_clamp + self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width + self.max_value = max_float( + self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) + self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value + + def quantize(self, x): + # Compute masks + inf_mask = x.isinf() + p_max_val_mask = x > self.max_value + n_max_val_mask = -x > self.max_value + + # Quantize + x = x / self.scale + internal_scale = float_internal_scale( + x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps) + x = internal_scale * self.float_to_int_impl(x / internal_scale) + + # Clamp + x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value) + if not self.saturating: + x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask) + + return x + + def dequantize(self, x): + return (x - self.zero_point) * self.scale + + def forward(self, x) -> Tuple[torch.Tensor]: + return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values + + +class FloatWeightInferencetHandler(FloatInferencetHandler): + handled_layer = WeightFloatQuantProxyFromInjector + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.cached_weight = None + super().prepare_for_export(module) + if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: + self.cached_weight = module._cached_weight.value + + def forward(self, x) -> Tuple[torch.Tensor]: + if self.cached_weight is not None: + x = self.cached_weight + else: + x = self.dequantize(self.quantize(x)) + return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py new file mode 100644 index 000000000..936106884 --- /dev/null +++ b/src/brevitas/export/inference/manager.py @@ -0,0 +1,106 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from torch.nn import Module +import torch.nn as nn + +from brevitas.export.inference.handler import FloatInferencetHandler +from brevitas.export.inference.handler import FloatWeightInferencetHandler +from brevitas.export.inference.handler import IntInferencetHandler +from brevitas.export.inference.handler import IntWeightInferencetHandler +from brevitas.export.manager import _set_proxy_export_handler +from brevitas.export.manager import _set_proxy_export_mode +from brevitas.export.manager import _set_recurrent_layer_export_handler +from brevitas.export.manager import _set_recurrent_layer_export_mode +from brevitas.export.manager import BaseManager +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import restore_return_quant_tensor + + +def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True): + cache_var = 'cache_inference_quant_' + attr + cache_var_metadata_only = cache_var + '_metadata_only' + if hasattr(m, cache_var): + setattr(m, cache_var, enabled) + setattr(m, cache_var_metadata_only, metadata_only) + + +def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True): + _override_caching_mode(m, 'bias', enabled, metadata_only) + + +def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True): + _override_caching_mode(m, 'act', enabled, metadata_only) + + +def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False): + _override_caching_mode(m, 'weight', enabled, metadata_only) + + +class quant_inference_mode: + + def __init__(self, model, cache_quant_weight=False, enabled=True): + self.model = model + self.enabled = enabled + self.cache_quant_weight = cache_quant_weight + self.export_manager = InferenceManager + self.hook_list = [] + self.return_quant_tensor_state = dict() + + def __enter__(self): + if self.enabled: + # Register the hook and store it in the list so that it can be removed by the hook itself when called + handle = self.model.register_forward_hook(self.hook) + self.hook_list.append(handle) + + # Enable bias for everything. Optionally, store the fully fake-quantized weights + self.model.apply( + lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True)) + self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True)) + self.model.apply( + lambda m: _override_weight_caching_mode( + m, enabled=True, metadata_only=not self.cache_quant_weight)) + + def __exit__(self, type, value, traceback): + # Disable all caching + # deactivate export mode + # restore return quant tensor + self.model.apply( + lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False)) + self.model.apply( + lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False)) + if self.cache_quant_weight: + self.model.apply( + lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False)) + InferenceManager.set_export_mode(self.model, enabled=False) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) + + def hook(self, module, inp, out): + # After one forward pass with caching enabled, we can: + # - Set the model in export mode + # - Attach export handlers + # - Disable return quant tensor since all quant metadata is cached + assert len(self.hook_list) == 1 + self.hook_list[0].remove() + self.model.apply(InferenceManager.set_export_handler) + InferenceManager.set_export_mode(self.model, enabled=True) + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) + + +# Inheritance from BaseManager is not techincally needed +class InferenceManager(BaseManager): + handlers = [ + IntInferencetHandler, + FloatInferencetHandler, + IntWeightInferencetHandler, + FloatWeightInferencetHandler] + + @classmethod + def set_export_mode(cls, model: Module, enabled: bool): + _set_proxy_export_mode(model, enabled) + _set_recurrent_layer_export_mode(model, enabled) + + @classmethod + def set_export_handler(cls, module: Module): + _set_proxy_export_handler(cls, module) + _set_recurrent_layer_export_handler(cls, module) diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 2805c6174..7b7e7a145 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -166,11 +166,15 @@ def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs): @classmethod def handler_from_module(cls, module: Module, no_inheritance=False): for handler in cls.handlers: + if not isinstance(handler.handled_layer, tuple): + handled_classes = (handler.handled_layer,) + else: + handled_classes = handler.handled_layer if no_inheritance: - if type(module) == handler.handled_layer: + if type(module) in handled_classes: return handler else: - if isinstance(module, handler.handled_layer): + if any([isinstance(module, handler) for handler in handled_classes]): return handler return None diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 92228b7a3..2b1f6833e 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -58,7 +58,7 @@ def disable_return_quant_tensor(model): def restore_return_quant_tensor(model, previous_state): for module in model.modules(): - if hasattr(module, 'return_quant_tensor'): + if hasattr(module, 'return_quant_tensor') and module in previous_state: module.return_quant_tensor = previous_state[module] diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index d64271cb5..a5c4407fd 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -8,12 +8,16 @@ from typing import Optional, Tuple, Union from warnings import warn +import packaging.version +import torch from torch import nn from torch import Tensor import torch.jit from torch.nn.utils.rnn import PackedSequence from brevitas import config +from brevitas import is_dynamo_compiling +from brevitas import torch_version from brevitas.common import ExportMixin from brevitas.inject import ExtendedInjector from brevitas.inject import Injector @@ -85,7 +89,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe qt_class = self.get_quant_tensor_class(inp) if qt_class is not None: inp = qt_class(*inp) - if not torch._C._get_tracing_state(): + if not torch._C._get_tracing_state() and not is_dynamo_compiling(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) else: diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 4e6452792..0d6ffd106 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -4,6 +4,7 @@ from torch import Tensor import torch.nn as nn +from brevitas.core.function_wrapper.misc import Identity from brevitas.inject import BaseInjector as Injector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase @@ -83,6 +84,13 @@ def is_fnuz(self): ) is None and self.exponent_bias() == 16 return is_fnuz_e4m3 or is_fnuz_e5m2 + @property + def input_view_impl(self): + if self.tensor_quant is not None: + return self.tensor_quant.input_view_impl + else: + return Identity() + class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index b38f4ecdb..7350e5e32 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from brevitas.core.function_wrapper.misc import Identity from brevitas.inject import BaseInjector as Injector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase from brevitas.quant_tensor import FloatQuantTensor @@ -27,7 +28,7 @@ def mantissa_bit_width(self, force_eval=True): def exponent_bias(self, force_eval=True): return self.retrieve_attribute('exponent_bias', force_eval) - def saturating(self, force_eval=True): + def is_saturating(self, force_eval=True): return self.retrieve_attribute('saturating', force_eval) def inf_values(self, force_eval=True): @@ -36,6 +37,13 @@ def inf_values(self, force_eval=True): def nan_values(self, force_eval=True): return self.retrieve_attribute('nan_values', force_eval) + @property + def input_view_impl(self): + if self.fused_activation_quant_proxy.tensor_quant is not None: + return self.fused_activation_quant_proxy.tensor_quant.input_view_impl + else: + return Identity() + @property def is_ocp(self): is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4 diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index ec9418e19..453cb3f9b 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -31,7 +31,7 @@ def create_quant_tensor( qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: if x is None: - value, scale, zero_point, bit_width, = qt_args + value, scale, zero_point, bit_width = qt_args out = GroupwiseIntQuantTensor( value, scale, diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 77a806ee8..f28233aed 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -4,7 +4,7 @@ from abc import ABC from abc import ABCMeta from abc import abstractmethod -from typing import Any, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from warnings import warn import torch @@ -14,8 +14,11 @@ from typing_extensions import runtime_checkable from brevitas import config +from brevitas import is_dynamo_compiling +from brevitas.core.function_wrapper.misc import Identity from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -92,6 +95,13 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = None # To be redefined by each class self.quant_tensor_class = None # To be redefined by each class + @property + def input_view_impl(self): + if self.tensor_quant is not None: + return self.tensor_quant.int_quant.input_view_impl + else: + return Identity() + @property def cache_inference_quant_weight(self): return self._cache_inference_quant_weight @@ -118,19 +128,23 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: # If quant is enabled the priority is: # - export mode - # - cached weight # - quantization flow if self.export_mode: out = self.export_handler(x) - out = self.create_quant_tensor(out) - elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: - out = self._cached_weight.quant_tensor + if is_dynamo_compiling(): + out = out[0] + else: + out = self.create_quant_tensor(out) else: out = self.tensor_quant(x) - out = self.create_quant_tensor(out) - if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: - self._cached_weight = self.cache_class( - out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) + if is_dynamo_compiling(): + out = out[0] + else: + out = self.create_quant_tensor(out) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = self.cache_class( + out.detach(), + metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled out = self.apply_input_view(x) return out @@ -151,9 +165,10 @@ def tracked_parameter_list(self): def get_cached(self, attr): if self._cached_bias is None: - warn( - "No quant bias cache found, set cache_inference_quant_bias=True and run an " - "inference pass first") + if not is_dynamo_compiling(): + warn( + "No quant bias cache found, set cache_inference_quant_bias=True and run an " + "inference pass first") return None if self.training: warn("Cached quant bias scale is being used in training mode.") @@ -268,7 +283,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): def scale(self): if not self.is_quant_enabled: return None - if self.requires_input_scale and self.is_quant_enabled and self.is_quant_enabled: + if self.requires_input_scale and self.is_quant_enabled: cache = self.get_cached('scale') return cache zhs = self._zero_hw_sentinel() @@ -335,12 +350,13 @@ def forward( out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) - out = IntQuantTensor( - out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) - if not self.training and self.cache_inference_quant_bias: - cached_bias = _CachedIO( - out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) - self._cached_bias = cached_bias + if not is_dynamo_compiling(): + out = IntQuantTensor( + out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + if not self.training and self.cache_inference_quant_bias: + cached_bias = _CachedIO( + out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) + self._cached_bias = cached_bias else: out = x return out diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 9c4255773..845bfd515 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -122,7 +122,7 @@ def add_tracked_module(self, module: nn.Module) -> None: raise RuntimeError("Trying to add None as a parent module.") def apply_input_view(self, x): - return self.quant_injector.input_view_impl(x) + return self.input_view_impl(x) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 511f914e6..9feb593b4 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -13,6 +13,7 @@ from typing_extensions import runtime_checkable import brevitas +from brevitas import is_dynamo_compiling from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -98,6 +99,14 @@ def __init__(self, quant_layer, quant_injector): self.cache_quant_io_metadata_only = True self.cache_class = None + @property + def input_view_impl(self): + if self.fused_activation_quant_proxy.tensor_quant is not None and not isinstance( + self.fused_activation_quant_proxy.tensor_quant, _TensorQuantDisabledIdentity): + return self.fused_activation_quant_proxy.tensor_quant.int_quant.input_view_impl + else: + return Identity() + def internal_forward(self, force_eval): current_status = self.training if force_eval: @@ -107,14 +116,17 @@ def internal_forward(self, force_eval): return out def retrieve_attribute(self, attribute, force_eval): - if self.is_quant_enabled: + if self._cached_act is not None: + return getattr(self._cached_act, attribute) + elif self.is_quant_enabled: out = self.internal_forward(force_eval) return getattr(out, attribute) - elif self._cached_act is not None: - return getattr(self._cached_act, attribute) elif self._cached_act is None: return None + def apply_input_view(self, x): + return self.input_view_impl(x) + @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant @@ -176,15 +188,18 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - # If the second value (i.e., scale) is None, then quant is disabled - if isinstance(y, tuple) and y[1] is not None: - out = self.create_quant_tensor(y) - elif self.is_passthrough_act and isinstance(x, QuantTensor): - # preserve quant_metadata - y = y[0] - out = self.create_quant_tensor(y, x=x) - else: + if is_dynamo_compiling(): out = y[0] + else: + # If the second value (i.e., scale) is None, then quant is disabled + if y[1] is not None: + out = self.create_quant_tensor(y) + elif self.is_passthrough_act and isinstance(x, QuantTensor): + # preserve scale/zp/bit/sign even without output quant + y = y[0] + out = self.create_quant_tensor(y, x=x) + else: + out = y[0] if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 7d846ce8d..0151c9232 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -190,7 +190,8 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}, 'po2_scale': { 'stats': { - 'per_group': MXInt8Act}}}}, + 'per_group': { + 'sym': MXInt8Act}}}}}, 'float': { 'static': { 'float_scale': { diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 58cc6563f..fd5e5c386 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -18,6 +18,7 @@ from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas.export.inference import quant_inference_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization @@ -267,6 +268,14 @@ def parse_type(v, default_type): 'uint_sym_act_for_unsigned_values', default=True, help='Use unsigned act quant when possible (default: enabled)') +add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)') + + +def generate_ref_input(args, device, dtype): + model_config = get_model_config(args.model_name) + center_crop_shape = model_config['center_crop_shape'] + img_shape = center_crop_shape + return torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype) def main(): @@ -474,23 +483,28 @@ def main(): # Validate the quant_model on the validation dataloader print("Starting validation:") - validate(val_loader, quant_model, stable=dtype != torch.bfloat16) + with torch.no_grad(), quant_inference_mode(quant_model): + param = next(iter(quant_model.parameters())) + device, dtype = param.device, param.dtype + ref_input = generate_ref_input(args, device, dtype) + quant_model(ref_input) + compiled_model = torch.compile(quant_model, fullgraph=True, disable=not args.compile) + validate(val_loader, compiled_model, stable=dtype != torch.bfloat16) if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process - model_config = get_model_config(args.model_name) - center_crop_shape = model_config['center_crop_shape'] - img_shape = center_crop_shape - device, dtype = next(model.parameters()).device, next(model.parameters()).dtype - ref_input = torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype) + param = next(iter(quant_model.parameters())) + device, dtype = param.device, param.dtype + ref_input = generate_ref_input(args, device, dtype) export_name = os.path.join(args.export_dir, config) if args.export_onnx_qcdq: export_name = export_name + '.onnx' - export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version) + export_onnx_qcdq( + quant_model, ref_input, export_name, opset_version=args.onnx_opset_version) if args.export_torch_qcdq: export_name = export_name + '.pt' - export_torch_qcdq(model, ref_input, export_name) + export_torch_qcdq(quant_model, ref_input, export_name) if __name__ == '__main__': diff --git a/src/brevitas_examples/imagenet_classification/utils.py b/src/brevitas_examples/imagenet_classification/utils.py index d506b8a61..460e7d77f 100644 --- a/src/brevitas_examples/imagenet_classification/utils.py +++ b/src/brevitas_examples/imagenet_classification/utils.py @@ -1,5 +1,3 @@ -import csv - import torch import torchvision.datasets as datasets import torchvision.transforms as transforms diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index 0d76ae2db..09f0b9253 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -13,6 +13,7 @@ from brevitas import torch_version from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas.export.inference import quant_inference_mode from brevitas.graph.calibrate import calibration_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize @@ -21,9 +22,13 @@ from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model from tests.marker import requires_pt_ge +TORCH_COMPILE_ATOL = 0.35 BATCH = 1 HEIGHT, WIDTH = 224, 224 IN_CH = 3 + +COMPILE_MODEL_LIST = ['efficientnet_b0', 'resnet18', 'fcn_resnet50'] + MODEL_LIST = [ 'vit_b_32', 'efficientnet_b0', @@ -68,11 +73,7 @@ def quantize_float(model): quant_format='float') -@fixture -@parametrize('model_name', MODEL_LIST) -@parametrize('quantize_fn', [quantize, quantize_flexml, layerwise_quantize]) -def torchvision_model(model_name, quantize_fn): - +def shared_quant_fn(model_name, quantize_fn): inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) if torch_version <= version.parse('1.9.1') and model_name == 'regnet_x_400mf': @@ -112,20 +113,53 @@ def torchvision_model(model_name, quantize_fn): return model -@requires_pt_ge('1.8.1') +@fixture +@parametrize('model_name', MODEL_LIST) +@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml]) +def torchvision_model(model_name, quantize_fn): + return shared_quant_fn(model_name, quantize_fn) + + +@fixture +@parametrize('model_name', COMPILE_MODEL_LIST) +@parametrize('quantize_fn', [quantize_float, quantize]) +def torchvision_model_compile(model_name, quantize_fn): + return shared_quant_fn(model_name, quantize_fn) + + +@requires_pt_ge('2.2') +def test_torchvision_compile(torchvision_model_compile): + torch._dynamo.config.capture_scalar_outputs = True + if torchvision_model_compile is None: + pytest.skip('Model not instantiated') + + inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) + + with torch.no_grad(), quant_inference_mode(torchvision_model_compile): + prehook_non_compiled_out = torchvision_model_compile(inp) + post_hook_non_compiled_out = torchvision_model_compile(inp) + + compiled_model = torch.compile(torchvision_model_compile, fullgraph=True) + compiled_out = compiled_model(inp) + + assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) + assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL) + + def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, request): + test_id = request.node.callspec.id if torchvision_model is None: pytest.skip('Model not instantiated') + inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) - test_id = request.node.callspec.id quantize_fn_name = test_id.split("-")[0] torchvision_model(inp) + if quantize_fn_name != 'quantize_float': export_onnx_qcdq(torchvision_model, args=inp) -@requires_pt_ge('1.9.1') def test_torchvision_graph_quantization_flexml_qcdq_torch(torchvision_model, request): if torchvision_model is None: pytest.skip('Model not instantiated')