simplification & compatibility
Giuseppe5 committed Aug 25, 2024
1 parent f56332f commit d9a29b1
Showing 12 changed files with 271 additions and 405 deletions.
17 changes: 10 additions & 7 deletions src/brevitas/graph/calibrate.py
@@ -105,22 +105,24 @@ def __exit__(self, type, value, traceback):
self.model, is_training=self.previous_training_state, quantization_enabled=True)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=True):

def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
if hasattr(m, 'cache_inference_quant_bias'):
if not hasattr(m, "cache_inference_quant_bias_backup"):
m.cache_inference_quant_bias_backup = m.cache_inference_quant_bias
m.cache_inference_quant_bias = enabled
m.cache_inference_quant_bias_metadata_only = metadata_only


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=True):
def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
if hasattr(m, 'cache_inference_quant_act'):
if not hasattr(m, "cache_inference_quant_act_backup"):
m.cache_inference_quant_act_backup = m.cache_inference_quant_act
m.cache_inference_quant_act = enabled
m.cache_inference_quant_act_metadata_only = metadata_only

def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only:bool=False):

def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
if hasattr(m, 'cache_inference_quant_weight'):
if not hasattr(m, "cache_inference_quant_weight_backup"):
m.cache_inference_quant_weight_backup = m.cache_inference_quant_weight
@@ -129,19 +131,20 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only:boo


class inference_mode:

def __init__(self, model, cache_quant_weight=False, enabled=True):
self.model = model
self.enabled=enabled
self.enabled = enabled
self.cache_quant_weight = cache_quant_weight

def __enter__(self):
if self.enabled:
self.model.apply(lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
if self.cache_quant_weight:
self.model.apply(lambda m: _override_weight_caching_mode(m, enabled=True))


def __exit__(self, type, value, traceback):
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)

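For context, a minimal usage sketch of the `inference_mode` context manager defined above. The toy model and input are illustrative assumptions, not part of this commit:

```python
import torch
import torch.nn as nn

from brevitas.graph.calibrate import inference_mode
from brevitas.nn import QuantConv2d  # any Brevitas quant layer would do

# Hypothetical toy model used only to exercise the context manager.
model = nn.Sequential(QuantConv2d(3, 8, kernel_size=3), nn.ReLU())
model.eval()

inp = torch.randn(1, 3, 32, 32)

# __enter__ applies the _override_*_caching_mode helpers shown above:
# bias and activation caching are enabled metadata-only, and
# cache_quant_weight=True additionally caches the full quantized weights.
with inference_mode(model, cache_quant_weight=True):
    out = model(inp)
```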
1 change: 0 additions & 1 deletion src/brevitas/nn/mixin/base.py
@@ -69,7 +69,6 @@ def __init__(self, return_quant_tensor: bool):
def channelwise_separable(self) -> bool:
pass


def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
quant_tensor_classes = [
IntQuantTensor, FloatQuantTensor, GroupwiseIntQuantTensor, GroupwiseFloatQuantTensor]
46 changes: 17 additions & 29 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -1,5 +1,4 @@
from typing import Optional, Union
from brevitas.quant_tensor import _unpack_quant_tensor
from typing import Any, List, Optional, Union

import torch
from torch import Tensor
@@ -8,7 +7,9 @@
from brevitas.inject import BaseInjector as Injector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIOFloat


@@ -87,33 +88,20 @@ def is_fnuz(self):

class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
if self.is_quant_enabled:
if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
out = self._cached_weight.quant_tensor
if torch.compiler.is_compiling():
out = _unpack_quant_tensor(out)
else:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
if not torch.compiler.is_compiling():
out = FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
self._cached_weight = _CachedIOFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
out = x
return out
def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)


class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
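The refactor above replaces each subclass `forward` with a `create_quant_tensor` hook, which implies the shared caching / export / torch.compile handling now lives in a common weight-proxy base. A rough sketch of that pattern with hypothetical names (this is not the actual Brevitas base class; caching and export handling are omitted for brevity):

```python
from typing import Any, List

import torch
from torch import Tensor


class WeightQuantProxySketch:
    """Hypothetical stand-in for the shared weight-proxy forward, showing where
    the new create_quant_tensor hook plugs in."""

    def __init__(self, tensor_quant, is_quant_enabled: bool = True):
        self.tensor_quant = tensor_quant
        self.is_quant_enabled = is_quant_enabled

    def create_quant_tensor(self, qt_args: List[Any]):
        # Subclasses (int, float, groupwise float, ...) override this to build
        # the matching QuantTensor flavour from the tuple returned by tensor_quant.
        raise NotImplementedError

    def forward(self, x: Tensor):
        if not self.is_quant_enabled:
            return x  # quantization disabled: pass the raw weight through
        qt_args = self.tensor_quant(x)
        if torch.compiler.is_compiling():
            return qt_args[0]  # under torch.compile, return only the quantized value
        return self.create_quant_tensor(list(qt_args))
```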
119 changes: 19 additions & 100 deletions src/brevitas/proxy/float_runtime_quant.py
@@ -60,108 +60,27 @@ def is_fnuz(self):
) is None and self.exponent_bias() == 16
return is_fnuz_e4m3 or is_fnuz_e5m2

def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTensor]:
out = x
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
elif not self.is_quant_enabled:
y = self.fused_activation_quant_proxy.activation_impl(y)
else:
y = self.fused_activation_quant_proxy(y)
if torch.compiler.is_compiling():
y = y[0]
else:
# If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
# We exclude the last two values (inf_values and nan_values)
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, FloatQuantTensor):
out = FloatQuantTensor(
y,
x.scale,
x.zero_point,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
else:
# If fused activation quant proxy is not enabled, return the input
out = x
if not self.training and self.cache_inference_quant_act and isinstance(out,
FloatQuantTensor):
cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only)
self._cached_act = cached_out
return out


class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase):

def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]:
out = x
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, FloatQuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
elif not self.is_quant_enabled:
y = self.fused_activation_quant_proxy.activation_impl(y)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
# We exclude the last two values (inf_values and nan_values)
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, FloatQuantTensor):
out = FloatQuantTensor(
y,
x.scale,
x.zero_point,
x.mantissa_bit_width,
x.exponent_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
def __init__(self, quant_layer, quant_injector):
super().__init__(quant_layer, quant_injector)
self.cache_class = _CachedIOFloat

def create_quant_tensor(self, *qt_args, x=None):
if x is None:
out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
else:
# If fused activation quant proxy is not enabled, return the input
out = x
if not self.training and self.cache_inference_quant_act and isinstance(out,
FloatQuantTensor):
cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only)
self._cached_act = cached_out
out = FloatQuantTensor(
qt_args,
x.scale,
x.zero_point,
x.mantissa_bit_width,
x.exponent_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
return out
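As a reference for the positional argument order used by `create_quant_tensor` above, a small hand-built FloatQuantTensor. The numeric values are illustrative, not a real quantization result; the field order is taken from the constructor calls shown in this diff:

```python
import torch

from brevitas.quant_tensor import FloatQuantTensor

value = torch.zeros(2, 4)  # placeholder "quantized" values
fqt = FloatQuantTensor(
    value,              # value
    torch.tensor(1.0),  # scale
    torch.tensor(0.0),  # zero_point
    torch.tensor(4.0),  # exponent_bit_width
    torch.tensor(3.0),  # mantissa_bit_width
    torch.tensor(7.0),  # exponent_bias (e4m3-style bias, illustrative)
    True,               # saturating
    None,               # inf_values
    None,               # nan_values
    True,               # signed
    False)              # training
print(fqt.exponent_bit_width, fqt.mantissa_bit_width)
```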
52 changes: 18 additions & 34 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -1,11 +1,9 @@
from typing import Union
from typing import Any, List, Union

import torch
from torch import Tensor

from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase
from brevitas.quant_tensor import GroupwiseFloatQuantTensor, _unpack_quant_tensor
from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat
from brevitas.quant_tensor import GroupwiseFloatQuantTensor


class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):
@@ -23,33 +21,19 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]:
if self.is_quant_enabled:
if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
out = self._cached_weight.quant_tensor
if torch.compiler.is_compiling():
out = _unpack_quant_tensor(out)
else:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
if not torch.compiler.is_compiling():
out = GroupwiseFloatQuantTensor(
out,
scale,
zero_point,
self.group_size,
self.group_dim,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
self._cached_weight = _CachedIOGroupwiseFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
out = x
return out
def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseFloatQuantTensor]:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return GroupwiseFloatQuantTensor(
out,
scale,
zero_point,
self.group_size,
self.group_dim,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
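The groupwise proxy above exposes group_size and group_dim, and its removed forward reshaped the weight via view_impl before quantization. A plain-PyTorch sketch of that grouping idea, with assumed shapes (this is not Brevitas' view_impl):

```python
import torch

weight = torch.randn(8, 16)   # e.g. [out_channels, in_channels]
group_size, group_dim = 4, 1  # assumed values for illustration

# Split the in_channels dimension into groups of group_size; a groupwise
# quantizer would then compute one scale/zero-point per group.
grouped = weight.unflatten(group_dim, (weight.shape[group_dim] // group_size, group_size))
print(grouped.shape)  # torch.Size([8, 4, 4])
```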
(The remaining changed files are not shown here.)
