Commit

Compile stuff
Giuseppe5 committed Aug 22, 2024
1 parent 6733ba2 commit f56332f
Showing 13 changed files with 309 additions and 250 deletions.
40 changes: 40 additions & 0 deletions src/brevitas/graph/calibrate.py
@@ -105,6 +105,46 @@ def __exit__(self, type, value, traceback):
self.model, is_training=self.previous_training_state, quantization_enabled=True)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
if hasattr(m, 'cache_inference_quant_bias'):
if not hasattr(m, "cache_inference_quant_bias_backup"):
m.cache_inference_quant_bias_backup = m.cache_inference_quant_bias
m.cache_inference_quant_bias = enabled
m.cache_inference_quant_bias_metadata_only = metadata_only


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
if hasattr(m, 'cache_inference_quant_act'):
if not hasattr(m, "cache_inference_quant_act_backup"):
m.cache_inference_quant_act_backup = m.cache_inference_quant_act
m.cache_inference_quant_act = enabled
m.cache_inference_quant_act_metadata_only = metadata_only

def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
if hasattr(m, 'cache_inference_quant_weight'):
if not hasattr(m, "cache_inference_quant_weight_backup"):
m.cache_inference_quant_weight_backup = m.cache_inference_quant_weight
m.cache_inference_quant_weight = enabled
m.cache_inference_quant_weight_metadata_only = metadata_only
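
These helpers duck-type on the cache flags: any module exposing a cache_inference_quant_* attribute has it overridden, with the previous value stashed in a matching *_backup attribute so it can be restored later. A generic sketch of the same save-and-restore pattern (editor's illustration only; the helper names below are hypothetical and not part of this commit):

import torch.nn as nn

def _override_flag(m: nn.Module, name: str, enabled: bool) -> None:
    # Save the original value once, then override it.
    if hasattr(m, name):
        backup = name + '_backup'
        if not hasattr(m, backup):
            setattr(m, backup, getattr(m, name))
        setattr(m, name, enabled)

def _restore_flag(m: nn.Module, name: str) -> None:
    # Put the original value back and drop the backup marker.
    backup = name + '_backup'
    if hasattr(m, backup):
        setattr(m, name, getattr(m, backup))
        delattr(m, backup)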


class inference_mode:
def __init__(self, model, cache_quant_weight=False, enabled=True):
self.model = model
self.enabled=enabled
self.cache_quant_weight = cache_quant_weight

def __enter__(self):
if self.enabled:
self.model.apply(lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
if self.cache_quant_weight:
self.model.apply(lambda m: _override_weight_caching_mode(m, enabled=True))


def __exit__(self, type, value, traceback):
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
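
A minimal usage sketch of the new inference_mode context manager (editor's illustration, not part of this commit; the QuantLinear layer, bit width, and the torch.compile call are assumptions chosen to match the commit's torch.compile theme):

import torch
import brevitas.nn as qnn
from brevitas.graph.calibrate import inference_mode  # class added by this commit

# Hypothetical quantized layer; any Brevitas module exposing the
# cache_inference_quant_* flags is handled the same way.
model = qnn.QuantLinear(16, 8, bias=True, weight_bit_width=8).eval()
x = torch.randn(4, 16)

with inference_mode(model, cache_quant_weight=True):
    model(x)  # eager pass populates the bias/act/weight quant caches
    compiled = torch.compile(model)
    y = compiled(x)  # compiled pass can reuse the cached quant metadata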


class bias_correction_mode:

6 changes: 1 addition & 5 deletions src/brevitas/nn/mixin/base.py
@@ -69,8 +69,6 @@ def __init__(self, return_quant_tensor: bool):
def channelwise_separable(self) -> bool:
pass

def _set_global_is_quant_layer(self, value):
config._IS_INSIDE_QUANT_LAYER = value

def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
quant_tensor_classes = [
@@ -81,23 +79,21 @@ def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
return None

def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(True)
# Hack to recognize a QuantTensor that has decayed to a tuple
# when used as input to tracing (e.g. during ONNX export)
if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and
all([isinstance(t, Tensor) for t in inp])):
qt_class = self.get_quant_tensor_class(inp)
if qt_class is not None:
inp = qt_class(*inp)
if not torch._C._get_tracing_state():
if not torch._C._get_tracing_state() and not torch.compiler.is_compiling():
if isinstance(inp, QuantTensor):
inp = inp.set(value=inp.value.rename(None))
else:
inp = inp.rename(None)
return inp

def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(False)
if self.return_quant_tensor:
assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised'
return quant_output
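
The unpack_input change above adds torch.compiler.is_compiling() to the guard so the named-tensor rename(None) workaround is skipped while Dynamo is tracing, presumably because named-tensor ops do not belong in compiled graphs. A minimal sketch of the same guard idiom in isolation (editor's illustration; strip_tensor_names is a hypothetical helper, not Brevitas code):

import torch

def strip_tensor_names(t: torch.Tensor) -> torch.Tensor:
    # rename(None) is an eager-only workaround for named tensors: skip it
    # both under jit tracing and under torch.compile/Dynamo compilation.
    if torch._C._get_tracing_state() is None and not torch.compiler.is_compiling():
        return t.rename(None)
    return t
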
60 changes: 25 additions & 35 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -1,5 +1,5 @@
from typing import Optional, Union
from warnings import warn
from brevitas.quant_tensor import _unpack_quant_tensor

import torch
from torch import Tensor
@@ -84,46 +84,36 @@ def is_fnuz(self):
) is None and self.exponent_bias() == 16
return is_fnuz_e4m3 or is_fnuz_e5m2

def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
else: # quantization disabled
return x


class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
out = self._cached_weight.quant_tensor
if torch.compiler.is_compiling():
out = _unpack_quant_tensor(out)
else:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
if not torch.compiler.is_compiling():
out = FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
self._cached_weight = _CachedIOFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
return x
out = x
return out
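
With this rewrite the weight proxy returns a FloatQuantTensor in eager mode but a plain Tensor while torch.compile is tracing (cached weights are unpacked, freshly quantized ones are left unwrapped). A hedged sketch of how downstream code might account for both return types (editor's illustration; the import path for FloatQuantTensor is assumed from the surrounding files):

from brevitas.quant_tensor import FloatQuantTensor

def weight_scale_or_none(quant_weight):
    # Eager mode: a FloatQuantTensor carrying scale/zero-point metadata.
    # Under torch.compile: a bare Tensor, so no metadata is attached.
    if isinstance(quant_weight, FloatQuantTensor):
        return quant_weight.scale
    return None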


class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
53 changes: 28 additions & 25 deletions src/brevitas/proxy/float_runtime_quant.py
@@ -74,33 +74,36 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTensor]:
y = self.fused_activation_quant_proxy.activation_impl(y)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
# We exclude the last two values (inf_values and nan_values)
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, FloatQuantTensor):
out = FloatQuantTensor(
y,
x.scale,
x.zero_point,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
if torch.compiler.is_compiling():
y = y[0]
else:
# If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
# We exclude the last two values (inf_values and nan_values)
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, FloatQuantTensor):
out = FloatQuantTensor(
y,
x.scale,
x.zero_point,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
else:
# If fused activation quant proxy is not enabled, return the input
out = x
48 changes: 29 additions & 19 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -4,7 +4,8 @@
from torch import Tensor

from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase
from brevitas.quant_tensor import GroupwiseFloatQuantTensor
from brevitas.quant_tensor import GroupwiseFloatQuantTensor, _unpack_quant_tensor
from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat


class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):
@@ -24,22 +25,31 @@ def group_size(self):

def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return GroupwiseFloatQuantTensor(
out,
scale,
zero_point,
self.group_size,
self.group_dim,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
out = self._cached_weight.quant_tensor
if torch.compiler.is_compiling():
out = _unpack_quant_tensor(out)
else:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
if not torch.compiler.is_compiling():
out = GroupwiseFloatQuantTensor(
out,
scale,
zero_point,
self.group_size,
self.group_dim,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
self._cached_weight = _CachedIOGroupwiseFloat(out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
return x
out = x
return out
74 changes: 39 additions & 35 deletions src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -6,6 +6,7 @@
from brevitas.quant_tensor import GroupwiseFloatQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat
import torch


class GroupwiseActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase):
@@ -35,46 +36,49 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]:
# If y is an empty GroupwiseFloatQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
# We exclude the last two values (inf_values and nan_values)
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y
out = GroupwiseFloatQuantTensor(
value,
scale,
zero_point,
self.group_size,
self.group_dim,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
signed=self.is_signed,
training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, GroupwiseFloatQuantTensor):
if torch.compiler.is_compiling():
y = y[0]
else:
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y
out = GroupwiseFloatQuantTensor(
y,
x.scale,
x.zero_point,
value,
scale,
zero_point,
self.group_size,
self.group_dim,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
signed=self.is_signed,
training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, GroupwiseFloatQuantTensor):
out = GroupwiseFloatQuantTensor(
y,
x.scale,
x.zero_point,
self.group_size,
self.group_dim,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
else:
# If fused activation quant proxy is not enabled, return the input
out = x
(Diffs for the remaining 7 changed files are not shown here.)