From d7657047af782b93e6c3e419b2072f4ec78930d5 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sat, 7 Sep 2024 11:16:57 +0100
Subject: [PATCH] Inference handler

---
 src/brevitas/export/inference/handler.py | 35 +++++++++
 src/brevitas/export/inference/manager.py | 95 ++++++++++++++++++++++++
 2 files changed, 130 insertions(+)
 create mode 100644 src/brevitas/export/inference/handler.py
 create mode 100644 src/brevitas/export/inference/manager.py

diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
new file mode 100644
index 000000000..6ff5bf258
--- /dev/null
+++ b/src/brevitas/export/inference/handler.py
@@ -0,0 +1,35 @@
+from typing import Tuple
+
+import torch
+
+from brevitas.function.ops import max_int
+from brevitas.function.ops import min_int
+from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
+from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
+from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
+
+
+class IntInferencetHandler(torch.nn.Module):
+    handled_layer = (
+        ActQuantProxyFromInjector, WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
+
+    def attach_debug_info(self, module):
+        pass
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.scale = module.scale()
+            self.zero_point = module.zero_point().to(self.scale.device)
+            self.bit_width = module.bit_width()
+            self.min_int = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
+            self.max_int = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
+
+    def quant(self, x):
+        return torch.clamp(
+            torch.round(x / self.scale + self.zero_point), self.min_int, self.max_int)
+
+    def dequant(self, x):
+        return (x - self.zero_point) * self.scale
+
+    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
+        return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width
diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py
new file mode 100644
index 000000000..b1db6501d
--- /dev/null
+++ b/src/brevitas/export/inference/manager.py
@@ -0,0 +1,95 @@
+from torch.nn import Module
+import torch.nn as nn
+
+from brevitas.export.inference.handler import IntInferencetHandler
+from brevitas.export.manager import _set_proxy_export_handler
+from brevitas.export.manager import _set_proxy_export_mode
+from brevitas.export.manager import _set_recurrent_layer_export_handler
+from brevitas.export.manager import _set_recurrent_layer_export_mode
+from brevitas.export.manager import BaseManager
+from brevitas.graph.calibrate import disable_return_quant_tensor
+from brevitas.graph.calibrate import restore_return_quant_tensor
+
+
+def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
+    cache_var = 'cache_inference_quant_' + attr
+    cache_var_metadata_only = cache_var + '_metadata_only'
+    if hasattr(m, cache_var):
+        setattr(m, cache_var, enabled)
+        setattr(m, cache_var_metadata_only, metadata_only)
+
+
+def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
+    _override_caching_mode(m, 'bias', enabled, metadata_only)
+
+
+def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
+    _override_caching_mode(m, 'act', enabled, metadata_only)
+
+
+def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
+    _override_caching_mode(m, 'weight', enabled, metadata_only)
+
+
+class inference_mode:
+
+    def __init__(self, model, cache_quant_weight=False, enabled=True):
+        self.model = model
+        self.enabled = enabled
+        self.cache_quant_weight = cache_quant_weight
+        self.export_manager = InferenceManager
+        self.hook_list = []
+
+    def __enter__(self):
+        if self.enabled:
+            # Register the hook and store it in the list so that it can be removed by the hook itself when called
+            handle = self.model.register_forward_hook(self.hook)
+            self.hook_list.append(handle)
+
+            # Enable bias caching for everything. Optionally, store the fully fake-quantized weights
+            self.model.apply(
+                lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
+            self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
+            self.model.apply(
+                lambda m: _override_weight_caching_mode(
+                    m, enabled=True, metadata_only=not self.cache_quant_weight))
+
+    def __exit__(self, type, value, traceback):
+        # Disable all caching
+        # Deactivate export mode
+        # Restore return quant tensor
+        self.model.apply(
+            lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
+        self.model.apply(
+            lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
+        if self.cache_quant_weight:
+            self.model.apply(
+                lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
+        InferenceManager.set_export_mode(self.model, enabled=False)
+        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
+
+    def hook(self, module, inp, out):
+        # After one forward pass with caching enabled, we can:
+        # - Set the model in export mode
+        # - Attach export handlers
+        # - Disable return quant tensor since all quant metadata is cached
+        assert len(self.hook_list) == 1
+        self.hook_list[0].remove()
+        self.model.apply(InferenceManager.set_export_handler)
+        InferenceManager.set_export_mode(self.model, enabled=True)
+        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+
+
+# Inheritance from BaseManager is not technically needed
+class InferenceManager(BaseManager):
+    handlers = [IntInferencetHandler]
+
+    @classmethod
+    def set_export_mode(cls, model: Module, enabled: bool):
+        _set_proxy_export_mode(model, enabled)
+        _set_recurrent_layer_export_mode(model, enabled)
+
+    @classmethod
+    def set_export_handler(cls, module: Module):
+        _set_proxy_export_handler(cls, module)
+        _set_recurrent_layer_export_handler(cls, module)
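
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the inference_mode context manager introduced above could be driven. The toy model, tensor shapes, and variable names are hypothetical; qnn.QuantLinear and qnn.QuantReLU are assumed to be the usual Brevitas quant layers, while inference_mode, InferenceManager, and IntInferencetHandler are the names defined in this patch.

import torch
import brevitas.nn as qnn  # assumed Brevitas quant layers

from brevitas.export.inference.manager import inference_mode  # added by this patch

# Hypothetical toy model built from Brevitas quant layers
model = torch.nn.Sequential(
    qnn.QuantLinear(16, 32, bias=True),
    qnn.QuantReLU(),
    qnn.QuantLinear(32, 8, bias=True))
model.eval()

x = torch.randn(4, 16)

with torch.no_grad(), inference_mode(model):
    # First call: the forward hook caches scale/zero-point/bit-width metadata,
    # attaches IntInferencetHandler to every quant proxy, switches the model to
    # export mode and disables returning QuantTensors.
    model(x)
    # Subsequent calls reuse the cached metadata through the handler's
    # quant/dequant (round, clamp, rescale) path.
    y = model(x)

In this flow the handler's forward amounts to dequant(quant(x)), i.e. (round(x / scale + zero_point).clamp(min_int, max_int) - zero_point) * scale, so after the first pass the model reproduces fake quantization from precomputed metadata instead of recomputing it per call.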