Feat (torch.compile): prototype support for compiled inference (#1006)
Showing 20 changed files with 461 additions and 78 deletions.
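As a sketch of the intended usage, the new quant_inference_mode context manager caches quantization metadata on a first forward pass and then lets the simplified model be compiled. The model construction below is an illustrative assumption, not part of this diff:

import torch
import brevitas.nn as qnn
from brevitas.export.inference import quant_inference_mode

model = qnn.QuantLinear(16, 8).eval()  # hypothetical quantized model

with torch.no_grad(), quant_inference_mode(model):
    model(torch.randn(1, 16))        # first pass caches quant metadata and
                                     # flips the proxies into export mode
    compiled = torch.compile(model)  # prototype: compile the simplified graph
    out = compiled(torch.randn(1, 16))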
brevitas/export/inference/__init__.py (new file)
@@ -0,0 +1,5 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from .manager import InferenceManager
from .manager import quant_inference_mode
brevitas/export/inference/handler.py (new file)
@@ -0,0 +1,153 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch

from brevitas.function.ops import max_float
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.utils.torch_utils import float_internal_scale


class InferenceHandler(torch.nn.Module, ABC):

    def attach_debug_info(self, module):
        pass

    @abstractmethod
    def prepare_for_export(self, module):
        pass

    @abstractmethod
    def quantize(self, x):
        pass

    @abstractmethod
    def dequantize(self, x):
        pass


class IntInferencetHandler(InferenceHandler):
    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)

    def attach_debug_info(self, module):
        pass

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.scale = module.scale()
            self.zero_point = module.zero_point().to(self.scale.device)
            self.bit_width = module.bit_width()
            self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
            self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

    def quantize(self, x):
        return torch.clamp(
            torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)

    def dequantize(self, x):
        return (x - self.zero_point) * self.scale

    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width


class IntWeightInferencetHandler(IntInferencetHandler):
    handled_layer = WeightQuantProxyFromInjector

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.cached_weight = None
            super().prepare_for_export(module)
            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                self.cached_weight = module._cached_weight.value

    def forward(self, x) -> Tuple[torch.Tensor]:
        if self.cached_weight is not None:
            x = self.cached_weight
        else:
            x = self.dequantize(self.quantize(x))
        return x, self.scale, self.zero_point, self.bit_width


class FloatInferencetHandler(InferenceHandler):
    handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.scale = module.scale()
            self.zero_point = module.zero_point().to(self.scale.device)
            self.exponent_bit_width = module.exponent_bit_width()
            self.mantissa_bit_width = module.mantissa_bit_width()
            self.exponent_bias = module.exponent_bias()
            self.saturating = module.is_saturating()
            self.inf_values = module.inf_values()
            self.nan_values = module.nan_values()
            self.eps = torch.finfo(self.scale.dtype).tiny
            if hasattr(module.tensor_quant, 'float_to_int_impl'):
                self.float_to_int_impl = module.tensor_quant.float_to_int_impl
                self.float_clamp_impl = module.tensor_quant.float_clamp_impl
            elif hasattr(module, 'fused_activation_quant_proxy'):
                self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl
                self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl

            self.max_clamp = max_float(
                self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
            self.min_clamp = -self.max_clamp
            self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width
            self.max_value = max_float(
                self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
            self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

    def quantize(self, x):
        # Compute masks
        inf_mask = x.isinf()
        p_max_val_mask = x > self.max_value
        n_max_val_mask = -x > self.max_value

        # Quantize
        x = x / self.scale
        internal_scale = float_internal_scale(
            x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
        x = internal_scale * self.float_to_int_impl(x / internal_scale)

        # Clamp
        x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value)
        if not self.saturating:
            x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)

        return x

    def dequantize(self, x):
        return (x - self.zero_point) * self.scale

    def forward(self, x) -> Tuple[torch.Tensor]:
        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class FloatWeightInferencetHandler(FloatInferencetHandler):
    handled_layer = WeightFloatQuantProxyFromInjector

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.cached_weight = None
            super().prepare_for_export(module)
            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                self.cached_weight = module._cached_weight.value

    def forward(self, x) -> Tuple[torch.Tensor]:
        if self.cached_weight is not None:
            x = self.cached_weight
        else:
            x = self.dequantize(self.quantize(x))
        return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
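For reference, here is a standalone numeric sketch of the affine fake-quantization that IntInferencetHandler.quantize and dequantize perform; the scale, zero-point and clamping bounds are made-up int8 values, not taken from the diff:

import torch

scale = torch.tensor(0.05)        # hypothetical per-tensor scale
zero_point = torch.tensor(0.)     # hypothetical zero-point
min_clamp, max_clamp = -128, 127  # signed 8-bit, non-narrow range

x = torch.tensor([0.123, -1.7, 9.99])
q = torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)
deq = (q - zero_point) * scale
# q   -> tensor([  2., -34., 127.])   (9.99 / 0.05 = 199.8 saturates at 127)
# deq -> tensor([ 0.1000, -1.7000,  6.3500])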
brevitas/export/inference/manager.py (new file)
@@ -0,0 +1,106 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from torch.nn import Module
import torch.nn as nn

from brevitas.export.inference.handler import FloatInferencetHandler
from brevitas.export.inference.handler import FloatWeightInferencetHandler
from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.inference.handler import IntWeightInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import _set_recurrent_layer_export_handler
from brevitas.export.manager import _set_recurrent_layer_export_mode
from brevitas.export.manager import BaseManager
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
    cache_var = 'cache_inference_quant_' + attr
    cache_var_metadata_only = cache_var + '_metadata_only'
    if hasattr(m, cache_var):
        setattr(m, cache_var, enabled)
        setattr(m, cache_var_metadata_only, metadata_only)


def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
    _override_caching_mode(m, 'bias', enabled, metadata_only)


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
    _override_caching_mode(m, 'act', enabled, metadata_only)


def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
    _override_caching_mode(m, 'weight', enabled, metadata_only)


class quant_inference_mode:

    def __init__(self, model, cache_quant_weight=False, enabled=True):
        self.model = model
        self.enabled = enabled
        self.cache_quant_weight = cache_quant_weight
        self.export_manager = InferenceManager
        self.hook_list = []
        self.return_quant_tensor_state = dict()

    def __enter__(self):
        if self.enabled:
            # Register the hook and store it in the list so that it can be removed by the hook itself when called
            handle = self.model.register_forward_hook(self.hook)
            self.hook_list.append(handle)

            # Enable bias caching for everything. Optionally, store the fully fake-quantized weights
            self.model.apply(
                lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
            self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
            self.model.apply(
                lambda m: _override_weight_caching_mode(
                    m, enabled=True, metadata_only=not self.cache_quant_weight))

    def __exit__(self, type, value, traceback):
        # Disable all caching, deactivate export mode, restore return_quant_tensor state
        self.model.apply(
            lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
        self.model.apply(
            lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
        if self.cache_quant_weight:
            self.model.apply(
                lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
        InferenceManager.set_export_mode(self.model, enabled=False)
        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

    def hook(self, module, inp, out):
        # After one forward pass with caching enabled, we can:
        # - Set the model in export mode
        # - Attach export handlers
        # - Disable return quant tensor since all quant metadata is cached
        assert len(self.hook_list) == 1
        self.hook_list[0].remove()
        self.model.apply(InferenceManager.set_export_handler)
        InferenceManager.set_export_mode(self.model, enabled=True)
        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)


# Inheritance from BaseManager is not technically needed
class InferenceManager(BaseManager):
    handlers = [
        IntInferencetHandler,
        FloatInferencetHandler,
        IntWeightInferencetHandler,
        FloatWeightInferencetHandler]

    @classmethod
    def set_export_mode(cls, model: Module, enabled: bool):
        _set_proxy_export_mode(model, enabled)
        _set_recurrent_layer_export_mode(model, enabled)

    @classmethod
    def set_export_handler(cls, module: Module):
        _set_proxy_export_handler(cls, module)
        _set_recurrent_layer_export_handler(cls, module)
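The switch into export mode hinges on a forward hook that fires once and then removes itself. A minimal standalone illustration of that pattern (toy module and names, not Brevitas code):

import torch
import torch.nn as nn

class TinyModel(nn.Module):  # toy stand-in for a quantized model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

model = TinyModel()
handles = []

def one_shot_hook(module, inp, out):
    # Runs after the first forward pass only: remove the handle, then switch modes.
    handles[0].remove()
    print("first caching pass done, switching to export mode")

handles.append(model.register_forward_hook(one_shot_hook))
model(torch.randn(2, 4))  # triggers the hook once
model(torch.randn(2, 4))  # hook no longer fires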