Commit: Inference handler
Giuseppe5 committed Sep 7, 2024
1 parent eb9b9b4 commit d765704
Showing 2 changed files with 130 additions and 0 deletions.
35 changes: 35 additions & 0 deletions src/brevitas/export/inference/handler.py
@@ -0,0 +1,35 @@
from typing import Tuple

import torch

from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector


class IntInferencetHandler(torch.nn.Module):
    handled_layer = (
        ActQuantProxyFromInjector, WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)

    def attach_debug_info(self, module):
        pass

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.scale = module.scale()
            self.zero_point = module.zero_point().to(self.scale.device)
            self.bit_width = module.bit_width()
            self.min_int = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
            self.max_int = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

    def quant(self, x):
        return torch.clamp(
            torch.round(x / self.scale + self.zero_point), self.min_int, self.max_int)

    def dequant(self, x):
        return (x - self.zero_point) * self.scale
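
    # Illustrative round trip (values are assumptions, not from the commit):
    # with scale=0.1, zero_point=0 and an 8-bit signed, non-narrow range
    # (min_int=-128, max_int=127), quant(1.234) = round(12.34) = 12 and
    # dequant(12) = 1.2, while quant(-20.0) clamps round(-200.0) to -128.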

    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width
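
A minimal usage sketch of the round trip the handler performs. The attribute values below are illustrative assumptions, not taken from the commit; in normal use prepare_for_export fills them from a quant proxy:

    import torch
    from brevitas.export.inference.handler import IntInferencetHandler

    h = IntInferencetHandler()
    h.scale = torch.tensor(0.1)
    h.zero_point = torch.tensor(0.)
    h.bit_width = torch.tensor(8.)
    h.min_int, h.max_int = torch.tensor(-128.), torch.tensor(127.)

    x = torch.tensor([1.234, -20.0])
    y, scale, zero_point, bit_width = h(x)  # y == tensor([1.2000, -12.8000])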
95 changes: 95 additions & 0 deletions src/brevitas/export/inference/manager.py
@@ -0,0 +1,95 @@
from torch.nn import Module
import torch.nn as nn

from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import _set_recurrent_layer_export_handler
from brevitas.export.manager import _set_recurrent_layer_export_mode
from brevitas.export.manager import BaseManager
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
    cache_var = 'cache_inference_quant_' + attr
    cache_var_metadata_only = cache_var + '_metadata_only'
    if hasattr(m, cache_var):
        setattr(m, cache_var, enabled)
        setattr(m, cache_var_metadata_only, metadata_only)


def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
    _override_caching_mode(m, 'bias', enabled, metadata_only)


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
    _override_caching_mode(m, 'act', enabled, metadata_only)


def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
    _override_caching_mode(m, 'weight', enabled, metadata_only)
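
# For example, _override_act_caching_mode(m, enabled=True) sets both
# m.cache_inference_quant_act and m.cache_inference_quant_act_metadata_only,
# but only on modules that already expose the cache_inference_quant_act flag.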


class inference_mode:

    def __init__(self, model, cache_quant_weight=False, enabled=True):
        self.model = model
        self.enabled = enabled
        self.cache_quant_weight = cache_quant_weight
        self.export_manager = InferenceManager
        self.hook_list = []
        # Filled in by the hook after the first forward pass
        self.return_quant_tensor_state = None

    def __enter__(self):
        if self.enabled:
            # Register the hook and store the handle so that the hook can
            # remove itself once it has been called
            handle = self.model.register_forward_hook(self.hook)
            self.hook_list.append(handle)

            # Enable metadata-only caching for bias and activations.
            # Optionally, also cache the fully fake-quantized weights
            self.model.apply(
                lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
            self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
            self.model.apply(
                lambda m: _override_weight_caching_mode(
                    m, enabled=True, metadata_only=not self.cache_quant_weight))

    def __exit__(self, type, value, traceback):
        # Disable all caching, deactivate export mode and restore the
        # return_quant_tensor state
        self.model.apply(
            lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
        self.model.apply(
            lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
        if self.cache_quant_weight:
            self.model.apply(
                lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
        InferenceManager.set_export_mode(self.model, enabled=False)
        # Only restore if the hook actually ran, i.e. a forward pass happened
        if self.return_quant_tensor_state is not None:
            restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

    def hook(self, module, inp, out):
        # After one forward pass with caching enabled, we can:
        # - Attach export handlers
        # - Set the model in export mode
        # - Disable return quant tensor since all quant metadata is cached
        assert len(self.hook_list) == 1
        self.hook_list[0].remove()
        self.model.apply(InferenceManager.set_export_handler)
        InferenceManager.set_export_mode(self.model, enabled=True)
        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)


# Inheritance from BaseManager is not technically needed
class InferenceManager(BaseManager):
    handlers = [IntInferencetHandler]

    @classmethod
    def set_export_mode(cls, model: Module, enabled: bool):
        _set_proxy_export_mode(model, enabled)
        _set_recurrent_layer_export_mode(model, enabled)

    @classmethod
    def set_export_handler(cls, module: Module):
        _set_proxy_export_handler(cls, module)
        _set_recurrent_layer_export_handler(cls, module)
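
A hedged usage sketch of inference_mode as added in this commit; the QuantLinear layer is an assumed stand-in for any Brevitas quantized model:

    import torch
    from brevitas.nn import QuantLinear
    from brevitas.export.inference.manager import inference_mode

    model = QuantLinear(16, 8, bias=True)
    model.eval()
    x = torch.randn(2, 16)

    with inference_mode(model):
        model(x)      # first call fires the hook: quant metadata is cached,
                      # export handlers are attached, export mode is enabled
        y = model(x)  # subsequent calls go through IntInferencetHandler
    # on exit, caching and export mode are disabled and the
    # return_quant_tensor state is restored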
