Feat (torch.compile): prototype support for compiled inference (#1006)
Giuseppe5 authored Sep 23, 2024
1 parent 3a9bcc6 commit b28ac0f
Showing 20 changed files with 461 additions and 78 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
@@ -211,4 +211,4 @@ def tests_brevitas_end_to_end(session, pytorch):
install_pytorch(pytorch, session)
install_torchvision(pytorch, session)
session.install('--upgrade', '-e', '.[test, ort_integration]')
-    session.run('pytest', '-v', 'tests/brevitas_end_to_end')
+    session.run('pytest', '-n', 'logical', '-v', 'tests/brevitas_end_to_end')
6 changes: 6 additions & 0 deletions src/brevitas/__init__.py
@@ -23,6 +23,12 @@
else:
torch_version = version.parse(torch.__version__)

+try:
+    # Attempt _dynamo import
+    is_dynamo_compiling = torch._dynamo.is_compiling
+except:
+    is_dynamo_compiling = lambda: False

try:
__version__ = get_distribution(__name__).version
except DistributionNotFound:
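For context, the `is_dynamo_compiling` shim added above gives the rest of the codebase a version-agnostic way to ask whether dynamo (the `torch.compile` tracer) is currently tracing. A minimal usage sketch with a hypothetical helper that is not part of this diff:

import torch
import brevitas

def debug_stats(x):
    # Hypothetical guard (not part of this diff): skip graph-breaking side effects
    # such as printing or .item() calls while dynamo is tracing for torch.compile.
    if not brevitas.is_dynamo_compiling():
        print(f"range: [{x.min().item():.4f}, {x.max().item():.4f}]")
    return x

debug_stats(torch.randn(8))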
49 changes: 30 additions & 19 deletions src/brevitas/core/function_wrapper/clamp.py
@@ -113,40 +113,51 @@ def __init__(
else:
self.max_available_float = None

+    def inf_nan_clamp(self, x, inf_mask, p_max_val_mask, n_max_val_mask):
+
+        # if non-saturating, we need to map values greater than max_val to nan or inf
+        if self.inf_values is not None:
+            # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
+            x[p_max_val_mask] = torch.tensor(float('inf'))
+            x[n_max_val_mask] = torch.tensor(float('-inf'))
+        elif self.nan_values is not None:
+            # no inf values, so we need to map them to NaN
+            full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
+            x[full_max_val_mask] = torch.tensor(float('nan'))
+
+            # we also map the inf values to NaN in this case
+            x[inf_mask] = torch.tensor(float('nan'))
+        else:
+            raise RuntimeError(
+                "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
+            )
+        return x
+
+    def saturating_clamp(self, x, max_value, min_value):
+        return self.tensor_clamp_impl(x, min_val=min_value, max_val=max_value)
+
@brevitas.jit.script_method
def forward(
self,
x: Tensor,
exponent_bit_width: Tensor,
mantissa_bit_width: Tensor,
exponent_bias: Tensor):
-        inf_mask = x.isinf()

max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
max_value = max_value if self.max_available_float is None else torch.min(
max_value, self.max_available_float())
+        min_value = torch.tensor(0.) if not self.signed else -max_value

+        # Compute masks
+        inf_mask = x.isinf()
p_max_val_mask = x > max_value
n_max_val_mask = -x > max_value
-        min_float = torch.tensor(0.) if not self.signed else -max_value

# first clamp everything to +- max_value, basically the saturating case
-        x = self.tensor_clamp_impl(x, min_val=min_float, max_val=max_value)
+        x = self.saturating_clamp(x, max_value, min_value)

if not self.saturating:
-            # if non-saturating, we need to map values greater than max_val to nan or inf
-            if self.inf_values is not None:
-                # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
-                x[p_max_val_mask] = torch.tensor(float('inf'))
-                x[n_max_val_mask] = torch.tensor(float('-inf'))
-            elif self.nan_values is not None:
-                # no inf values, so we need to map them to NaN
-                full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
-                x[full_max_val_mask] = torch.tensor(float('nan'))
-
-                # we also map the inf values to NaN in this case
-                x[inf_mask] = torch.tensor(float('nan'))
-            else:
-                raise RuntimeError(
-                    "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
-                )
+            x = self.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)

return x, self.saturating, self.inf_values, self.nan_values
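The refactor above splits the non-saturating handling out of `forward` so that the same `saturating_clamp` / `inf_nan_clamp` pair can be reused by the inference handlers introduced later in this commit. A self-contained sketch of the mask logic, re-derived here for illustration (a simplified stand-in, not the Brevitas implementation):

import torch

def inf_nan_clamp_demo(x, max_value, has_inf_values):
    # Masks are computed on the raw input, before any clamping
    inf_mask = x.isinf()
    p_max_val_mask = x > max_value
    n_max_val_mask = -x > max_value

    x = x.clamp(-max_value, max_value)  # the saturating step
    if has_inf_values:
        # the target format encodes +-inf: overflowed values become +-inf
        x[p_max_val_mask] = float('inf')
        x[n_max_val_mask] = float('-inf')
    else:
        # no inf encoding: overflowed values and pre-existing infs both become NaN
        x[p_max_val_mask | n_max_val_mask] = float('nan')
        x[inf_mask] = float('nan')
    return x

print(inf_nan_clamp_demo(torch.tensor([1.0, 500.0, float('inf')]), 448.0, has_inf_values=False))
# tensor([1., nan, nan])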
4 changes: 2 additions & 2 deletions src/brevitas/export/common/handler/qcdq.py
@@ -454,7 +454,7 @@ def prepare_for_export(self, module):
self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width()
self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width()
self.symbolic_kwargs['exponent_bias'] = module.exponent_bias()
-        self.symbolic_kwargs['saturating'] = module.saturating()
+        self.symbolic_kwargs['saturating'] = module.is_saturating()
self.symbolic_kwargs['inf_values'] = module.inf_values()
self.symbolic_kwargs['nan_values'] = module.nan_values()

@@ -659,7 +659,7 @@ def prepare_for_export(self, module):
'exponent_bit_width': module.exponent_bit_width(),
'mantissa_bit_width': module.mantissa_bit_width(),
'exponent_bias': module.exponent_bias(),
-            'saturating': module.saturating(),
+            'saturating': module.is_saturating(),
'inf_values': module.inf_values(),
'nan_values': module.nan_values()}

5 changes: 5 additions & 0 deletions src/brevitas/export/inference/__init__.py
@@ -0,0 +1,5 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from .manager import InferenceManager
from .manager import quant_inference_mode
153 changes: 153 additions & 0 deletions src/brevitas/export/inference/handler.py
@@ -0,0 +1,153 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch

from brevitas.function.ops import max_float
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.utils.torch_utils import float_internal_scale


class InferenceHandler(torch.nn.Module, ABC):

def attach_debug_info(self, module):
pass

@abstractmethod
def prepare_for_export(self, module):
pass

@abstractmethod
def quantize(self, x):
pass

@abstractmethod
def dequantize(self, x):
pass


class IntInferencetHandler(InferenceHandler):
handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)

def attach_debug_info(self, module):
pass

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.scale = module.scale()
self.zero_point = module.zero_point().to(self.scale.device)
self.bit_width = module.bit_width()
self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

def quantize(self, x):
return torch.clamp(
torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)

def dequantize(self, x):
return (x - self.zero_point) * self.scale

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width


class IntWeightInferencetHandler(IntInferencetHandler):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.cached_weight = None
super().prepare_for_export(module)
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.value

def forward(self, x) -> Tuple[torch.Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
x = self.dequantize(self.quantize(x))
return x, self.scale, self.zero_point, self.bit_width


class FloatInferencetHandler(InferenceHandler):
handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.scale = module.scale()
self.zero_point = module.zero_point().to(self.scale.device)
self.exponent_bit_width = module.exponent_bit_width()
self.mantissa_bit_width = module.mantissa_bit_width()
self.exponent_bias = module.exponent_bias()
self.saturating = module.is_saturating()
self.inf_values = module.inf_values()
self.nan_values = module.nan_values()
self.eps = torch.finfo(self.scale.dtype).tiny
if hasattr(module.tensor_quant, 'float_to_int_impl'):
self.float_to_int_impl = module.tensor_quant.float_to_int_impl
self.float_clamp_impl = module.tensor_quant.float_clamp_impl
elif hasattr(module, 'fused_activation_quant_proxy'):
self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl
self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl

self.max_clamp = max_float(
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_clamp = -self.max_clamp
self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width
self.max_value = max_float(
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

def quantize(self, x):
# Compute masks
inf_mask = x.isinf()
p_max_val_mask = x > self.max_value
n_max_val_mask = -x > self.max_value

# Quantize
x = x / self.scale
internal_scale = float_internal_scale(
x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
x = internal_scale * self.float_to_int_impl(x / internal_scale)

# Clamp
x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value)
if not self.saturating:
x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)

return x

def dequantize(self, x):
return (x - self.zero_point) * self.scale

def forward(self, x) -> Tuple[torch.Tensor]:
return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class FloatWeightInferencetHandler(FloatInferencetHandler):
handled_layer = WeightFloatQuantProxyFromInjector

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.cached_weight = None
super().prepare_for_export(module)
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.value

def forward(self, x) -> Tuple[torch.Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
x = self.dequantize(self.quantize(x))
return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
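The handlers above reduce quantized inference to a plain fake-quantize round trip over cached metadata (scale, zero-point, bit-width), which is what keeps them friendly to `torch.compile`. A small numeric sketch of the round trip implemented by `IntInferencetHandler.quantize`/`dequantize`; the example scale, zero-point and 8-bit range are assumptions, not values from the diff:

import torch

scale, zero_point = torch.tensor(0.05), torch.tensor(0.)
min_clamp, max_clamp = 0., 255.  # min_int / max_int for an unsigned, non-narrow 8-bit quantizer

x = torch.tensor([0.01, 1.337, 20.0])
q = torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)  # quantize
dq = (q - zero_point) * scale                                               # dequantize
print(q)   # tensor([  0.,  27., 255.])
print(dq)  # tensor([ 0.0000,  1.3500, 12.7500])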
106 changes: 106 additions & 0 deletions src/brevitas/export/inference/manager.py
@@ -0,0 +1,106 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from torch.nn import Module
import torch.nn as nn

from brevitas.export.inference.handler import FloatInferencetHandler
from brevitas.export.inference.handler import FloatWeightInferencetHandler
from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.inference.handler import IntWeightInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import _set_recurrent_layer_export_handler
from brevitas.export.manager import _set_recurrent_layer_export_mode
from brevitas.export.manager import BaseManager
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
cache_var = 'cache_inference_quant_' + attr
cache_var_metadata_only = cache_var + '_metadata_only'
if hasattr(m, cache_var):
setattr(m, cache_var, enabled)
setattr(m, cache_var_metadata_only, metadata_only)


def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
_override_caching_mode(m, 'bias', enabled, metadata_only)


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
_override_caching_mode(m, 'act', enabled, metadata_only)


def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
_override_caching_mode(m, 'weight', enabled, metadata_only)


class quant_inference_mode:

def __init__(self, model, cache_quant_weight=False, enabled=True):
self.model = model
self.enabled = enabled
self.cache_quant_weight = cache_quant_weight
self.export_manager = InferenceManager
self.hook_list = []
self.return_quant_tensor_state = dict()

def __enter__(self):
if self.enabled:
# Register the hook and store it in the list so that it can be removed by the hook itself when called
handle = self.model.register_forward_hook(self.hook)
self.hook_list.append(handle)

# Enable bias for everything. Optionally, store the fully fake-quantized weights
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
self.model.apply(
lambda m: _override_weight_caching_mode(
m, enabled=True, metadata_only=not self.cache_quant_weight))

def __exit__(self, type, value, traceback):
# Disable all caching
# deactivate export mode
# restore return quant tensor
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
self.model.apply(
lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
if self.cache_quant_weight:
self.model.apply(
lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
InferenceManager.set_export_mode(self.model, enabled=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

def hook(self, module, inp, out):
# After one forward pass with caching enabled, we can:
# - Set the model in export mode
# - Attach export handlers
# - Disable return quant tensor since all quant metadata is cached
assert len(self.hook_list) == 1
self.hook_list[0].remove()
self.model.apply(InferenceManager.set_export_handler)
InferenceManager.set_export_mode(self.model, enabled=True)
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)


# Inheritance from BaseManager is not technically needed
class InferenceManager(BaseManager):
handlers = [
IntInferencetHandler,
FloatInferencetHandler,
IntWeightInferencetHandler,
FloatWeightInferencetHandler]

@classmethod
def set_export_mode(cls, model: Module, enabled: bool):
_set_proxy_export_mode(model, enabled)
_set_recurrent_layer_export_mode(model, enabled)

@classmethod
def set_export_handler(cls, module: Module):
_set_proxy_export_handler(cls, module)
_set_recurrent_layer_export_handler(cls, module)
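Putting the pieces together: `quant_inference_mode` needs one ordinary forward pass so the proxies cache their quant metadata; the forward hook then removes itself, attaches the inference handlers and switches the model into export mode, after which subsequent calls run through the lightweight handlers. A minimal usage sketch, where the `QuantLinear` stand-in and `example_input` are assumptions and the `torch.compile` call reflects the commit's stated goal rather than an API shown in this diff:

import torch

from brevitas.export.inference import quant_inference_mode
from brevitas.nn import QuantLinear

model = QuantLinear(16, 16, bias=True)  # stand-in for any quantized Brevitas model
example_input = torch.randn(1, 16)

model.eval()
with torch.no_grad(), quant_inference_mode(model):
    model(example_input)                   # first pass: caches quant metadata, then enables export mode
    compiled_model = torch.compile(model)  # later passes run through the inference handlers
    out = compiled_model(example_input)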
8 changes: 6 additions & 2 deletions src/brevitas/export/manager.py
@@ -166,11 +166,15 @@ def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs):
@classmethod
def handler_from_module(cls, module: Module, no_inheritance=False):
for handler in cls.handlers:
+            if not isinstance(handler.handled_layer, tuple):
+                handled_classes = (handler.handled_layer,)
+            else:
+                handled_classes = handler.handled_layer
if no_inheritance:
-                if type(module) == handler.handled_layer:
+                if type(module) in handled_classes:
return handler
else:
-                if isinstance(module, handler.handled_layer):
+                if any([isinstance(module, handler) for handler in handled_classes]):
return handler
return None
