Showing 2 changed files with 130 additions and 0 deletions.
@@ -0,0 +1,35 @@
from typing import Tuple

import torch

from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector


class IntInferencetHandler(torch.nn.Module):
    # Proxy types this inference handler can take over at export time
    handled_layer = (
        ActQuantProxyFromInjector, WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)

    def attach_debug_info(self, module):
        pass

    def prepare_for_export(self, module):
        # Cache the quant metadata of the proxy on the handler
        if module.is_quant_enabled:
            self.scale = module.scale()
            self.zero_point = module.zero_point().to(self.scale.device)
            self.bit_width = module.bit_width()
            self.min_int = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
            self.max_int = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

    def quant(self, x):
        # Affine quantization: scale, shift by the zero-point, round, clamp to the integer range
        return torch.clamp(
            torch.round(x / self.scale + self.zero_point), self.min_int, self.max_int)

    def dequant(self, x):
        # Map integer values back to the real domain
        return (x - self.zero_point) * self.scale

    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
        # Fake-quantize the input and return it together with the cached quant metadata
        return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width
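For reference, a minimal sketch of the fake-quantization round-trip that quant and dequant implement, using made-up scale, zero-point, and integer-range values rather than ones produced by an actual Brevitas proxy:

import torch

# Made-up 8-bit unsigned affine parameters, not taken from a real proxy
scale, zero_point = torch.tensor(0.1), torch.tensor(128.)
min_int, max_int = 0., 255.

x = torch.tensor([-1.0, 0.05, 3.0])
q = torch.clamp(torch.round(x / scale + zero_point), min_int, max_int)  # quant
x_hat = (q - zero_point) * scale                                        # dequant
print(q)      # -> tensor([118., 128., 158.])
print(x_hat)  # -> tensor([-1.0000, 0.0000, 3.0000]); 0.05 is lost to rounding

The second file in the commit wires this handler into an inference_mode context manager and an InferenceManager: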
@@ -0,0 +1,95 @@
from torch.nn import Module
import torch.nn as nn

from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import _set_recurrent_layer_export_handler
from brevitas.export.manager import _set_recurrent_layer_export_mode
from brevitas.export.manager import BaseManager
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
    # Flip the cache_inference_quant_<attr> flags on modules that expose them
    cache_var = 'cache_inference_quant_' + attr
    cache_var_metadata_only = cache_var + '_metadata_only'
    if hasattr(m, cache_var):
        setattr(m, cache_var, enabled)
        setattr(m, cache_var_metadata_only, metadata_only)


def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
    _override_caching_mode(m, 'bias', enabled, metadata_only)


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
    _override_caching_mode(m, 'act', enabled, metadata_only)


def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
    _override_caching_mode(m, 'weight', enabled, metadata_only)


class inference_mode:

    def __init__(self, model, cache_quant_weight=False, enabled=True):
        self.model = model
        self.enabled = enabled
        self.cache_quant_weight = cache_quant_weight
        self.export_manager = InferenceManager
        self.hook_list = []

    def __enter__(self):
        if self.enabled:
            # Register the hook and store it in the list so that it can be removed
            # by the hook itself when called
            handle = self.model.register_forward_hook(self.hook)
            self.hook_list.append(handle)

            # Enable caching of bias, act, and weight quant metadata.
            # Optionally, also store the fully fake-quantized weights
            self.model.apply(
                lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
            self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
            self.model.apply(
                lambda m: _override_weight_caching_mode(
                    m, enabled=True, metadata_only=not self.cache_quant_weight))

    def __exit__(self, type, value, traceback):
        # Disable all caching, deactivate export mode, restore the return_quant_tensor state
        self.model.apply(
            lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
        self.model.apply(
            lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
        if self.cache_quant_weight:
            self.model.apply(
                lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
        InferenceManager.set_export_mode(self.model, enabled=False)
        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

    def hook(self, module, inp, out):
        # After one forward pass with caching enabled, we can:
        # - Set the model in export mode
        # - Attach export handlers
        # - Disable return quant tensor since all quant metadata is cached
        assert len(self.hook_list) == 1
        self.hook_list[0].remove()
        self.model.apply(InferenceManager.set_export_handler)
        InferenceManager.set_export_mode(self.model, enabled=True)
        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)


# Inheritance from BaseManager is not technically needed
class InferenceManager(BaseManager):
    handlers = [IntInferencetHandler]

    @classmethod
    def set_export_mode(cls, model: Module, enabled: bool):
        _set_proxy_export_mode(model, enabled)
        _set_recurrent_layer_export_mode(model, enabled)

    @classmethod
    def set_export_handler(cls, module: Module):
        _set_proxy_export_handler(cls, module)
        _set_recurrent_layer_export_handler(cls, module)
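For context, a hedged sketch of how this context manager is meant to be driven: one forward pass inside the with block fires the hook, caches the quant metadata, attaches the export handlers, and enables export mode for subsequent calls. The model, input, and the assumption that inference_mode is importable from the module added above are placeholders, not part of this commit; QuantLinear from brevitas.nn is used only as a small example of a quantized model.

import torch
from brevitas.nn import QuantLinear  # assumed example layer; any Brevitas quantized model works

# Placeholder model and input, not part of this commit
model = torch.nn.Sequential(QuantLinear(16, 8, bias=True)).eval()
inp = torch.randn(2, 16)

with torch.no_grad(), inference_mode(model):
    model(inp)        # first pass: hook fires, metadata is cached, export mode is enabled
    out = model(inp)  # later passes run through the cached-metadata inference handlers
# on exit: caching is disabled, export mode is turned off,
# and the return_quant_tensor state is restored

The hook removes itself after the first call, so the caching overrides only pay their cost once per entry into the context manager.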