Skip to content

Commit

Permalink
Fix (mx): input view during quantization (#1005)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Aug 24, 2024
1 parent 6733ba2 commit 89fca2f
Show file tree
Hide file tree
Showing 20 changed files with 107 additions and 40 deletions.
8 changes: 6 additions & 2 deletions notebooks/Brevitas_TVMCon2021.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1659,11 +1659,13 @@
"from brevitas.core.bit_width import BitWidthConst\n",
"from brevitas.core.quant import IntQuant, RescalingIntQuant\n",
"from brevitas.core.zero_point import ZeroZeroPoint\n",
"from brevitas.core.function_wrapper.misc import Identity\n",
"\n",
"tensor_quant = RescalingIntQuant(\n",
" int_quant=IntQuant(\n",
" float_to_int_impl=RoundSte(),\n",
" tensor_clamp_impl=TensorClamp(),\n",
" input_view_impl=Identity,\n",
" signed=False,\n",
" narrow_range=False),\n",
" zero_point_impl=ZeroZeroPoint(),\n",
Expand Down Expand Up @@ -1767,6 +1769,7 @@
"from brevitas.inject import value\n",
"from brevitas.proxy import WeightQuantProxyFromInjector\n",
"from brevitas.core.scaling import ParameterScaling\n",
"from brevitas.core.function_wrapper.misc import Identity\n",
"\n",
"class Int8ActPerTensorFloatParameterFromScratch(ExtendedInjector):\n",
" \n",
Expand All @@ -1784,11 +1787,12 @@
" int_scaling_impl = IntScaling\n",
" scaling_impl = ParameterScaling\n",
" restrict_scaling_impl = FloatRestrictValue\n",
" input_view_impl = Identity\n",
" scaling_shape = ()\n",
" bit_width = 8\n",
" narrow_range = True\n",
" signed = True\n",
" \n",
"\n",
"quant_linear = QuantLinear(2, 4, weight_quant=Int8ActPerTensorFloatParameterFromScratch, bias=False)"
]
},
Expand Down Expand Up @@ -1936,7 +1940,7 @@
"torch.manual_seed(0)\n",
"\n",
"from brevitas.export import export_qonnx\n",
"from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat, Int16Bias\n",
"from brevitas.quant import Int8ActPerTensorFloat, Int16Bias\n",
"\n",
"float_inp = torch.randn(1, 2, 5)\n",
"\n",
Expand Down
6 changes: 3 additions & 3 deletions notebooks/minifloat_mx_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,20 +233,20 @@
"import brevitas.nn as qnn\n",
"import torch\n",
"\n",
"class MXFloat8Weight(MXInt8Weight):\n",
"class MXInt8Weight(MXInt8Weight):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" bit_width = 8\n",
"\n",
"class MXFloat8Act(MXInt8Act):\n",
"class MXInt8Act(MXInt8Act):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" bit_width = 8\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXInt8Weight, input_quant=MXInt8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
Expand Down
20 changes: 20 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,25 @@ def forward(self, x: torch.Tensor):
return y


class DynamicOverSubChannelBlockView(brevitas.jit.ScriptModule):
__constants__ = ['group_size', 'group_dim']

def __init__(self, group_size, group_dim) -> None:
super(DynamicOverSubChannelBlockView, self).__init__()
self.group_size = group_size
self.group_dim = group_dim

@brevitas.jit.script_method
def forward(self, x):
tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list.insert(block_dim, self.group_size)
x = x.view(tensor_shape_list)
return x


class StatsInputViewShapeImpl(object):
"""
Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor.
Expand All @@ -182,3 +201,4 @@ class StatsInputViewShapeImpl(object):
OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView
OVER_OUTPUT_FEATURES = OverOutputFeaturesView
OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView
DYNAMIC_OVER_SUBCHANNEL_BLOCK = DynamicOverSubChannelBlockView
4 changes: 3 additions & 1 deletion src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
mantissa_bit_width: int,
exponent_bias: int,
float_clamp_impl: nn.Module,
input_view_impl: nn.Module,
scaling_impl: Optional[nn.Module] = None,
float_scaling_impl: Optional[nn.Module] = None,
float_to_int_impl: nn.Module = RoundSte(),
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
if scaling_impl is None:
scaling_impl = ConstScaling(1., device=device, dtype=dtype)

self.input_view_impl = input_view_impl
# Zero-point is currently hardcoded to 0
self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
self.float_scaling_impl = float_scaling_impl
Expand All @@ -71,7 +73,7 @@ def quantize(self, x: torch.Tensor):
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
scale = scale / float_scaling_impl_value

x = self.input_view_impl(x)
scaled_x = x / scale
internal_scale = float_internal_scale(
scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)
Expand Down
6 changes: 6 additions & 0 deletions src/brevitas/core/quant/int_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self,
narrow_range: bool,
signed: bool,
input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
Expand All @@ -60,9 +61,11 @@ def __init__(
self.signed = signed
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor:
x = self.input_view_impl(x)
y = x / scale
y = y + zero_point
min_int_val = self.min_int(bit_width)
Expand Down Expand Up @@ -124,6 +127,7 @@ def __init__(
self,
narrow_range: bool,
signed: bool,
input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
Expand All @@ -133,11 +137,13 @@ def __init__(
self.signed = signed
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
def to_int(
self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor,
x: Tensor) -> Tensor:
x = self.input_view_impl(x)
y = x / pre_scale
y = y + pre_zero_point
min_int_val = self.min_int(bit_width)
Expand Down
14 changes: 3 additions & 11 deletions src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
self,
group_size: int,
group_dim: int,
input_view_impl: torch.nn.Module,
scaling_stats_impl: torch.nn.Module,
scaling_min_val: Optional[float],
restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
Expand All @@ -174,21 +175,12 @@ def __init__(
self.group_dim = group_dim
self.scaling_stats_impl = scaling_stats_impl
self.scaling_min_val = scaling_min_val
self.input_view_impl = input_view_impl
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

@brevitas.jit.script_method
def group_scaling_reshape(self, stats_input):
tensor_shape = stats_input.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list.insert(block_dim, self.group_size)
stats_input = stats_input.view(tensor_shape_list)
return stats_input

@brevitas.jit.script_method
def forward(self, stats_input) -> torch.Tensor:
stats_input_reshaped = self.group_scaling_reshape(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped)
# Scaling min val
out = self.restrict_clamp_scaling(out)
Expand Down
2 changes: 2 additions & 0 deletions src/brevitas/export/onnx/qonnx/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from brevitas.core.bit_width import BitWidthConst
from brevitas.core.function_wrapper.clamp import TensorClamp
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.quant import IntQuant
from brevitas.core.quant import TruncIntQuant
from brevitas.function import binary_sign
Expand Down Expand Up @@ -51,6 +52,7 @@ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding
quant = IntQuant(
float_to_int_impl=float_to_int_impl(),
tensor_clamp_impl=TensorClamp(),
input_view_impl=Identity(), #TODO: Update this when QONNX support Groupwise export
narrow_range=narrow_range,
signed=signed)
y = quant(scale, zero_point, bit_width, x)
Expand Down
6 changes: 0 additions & 6 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@

class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Is this always generated?
self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl

@property
def group_dim(self):
return self.quant_injector.group_dim
Expand All @@ -25,7 +20,6 @@ def group_size(self):
def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return GroupwiseFloatQuantTensor(
out,
Expand Down
1 change: 0 additions & 1 deletion src/brevitas/proxy/groupwise_float_runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloat
y = x
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
Expand Down
1 change: 0 additions & 1 deletion src/brevitas/proxy/groupwise_int_parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def group_size(self):
def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, bit_width = impl(x)
return GroupwiseIntQuantTensor(
out,
Expand Down
10 changes: 8 additions & 2 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from brevitas.core.bit_width import BitWidthConst
from brevitas.core.bit_width import BitWidthStatefulConst
from brevitas.core.function_wrapper import Identity
from brevitas.core.function_wrapper import OverOutputChannelView
from brevitas.core.function_wrapper import RoundToZeroSte
from brevitas.core.function_wrapper import TensorClamp
Expand Down Expand Up @@ -53,6 +54,7 @@
from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
from brevitas.quant.solver.common import SolveStatsReduceDimFromEnum
from brevitas.quant.solver.parameter import SolveInputViewImpl
from brevitas.quant.solver.parameter import SolveParameterScalingShape
from brevitas.quant.solver.weight import SolveWeightScalingPerOutputChannelShapeFromModule
from brevitas.quant.solver.weight import SolveWeightScalingStatsInputDimsFromModule
Expand Down Expand Up @@ -293,6 +295,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
restrict_scaling_impl = FloatRestrictValue
scaling_shape = SCALAR_SHAPE
scaling_per_output_type = ScalingPerOutputType.TENSOR
input_view_impl = Identity
scaling_impl = ParameterFromStatsFromParameterScaling
int_scaling_impl = IntScaling
zero_point_impl = ZeroZeroPoint
Expand All @@ -305,7 +309,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM
class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveParameterScalingShape):
SolveParameterScalingShape,
SolveInputViewImpl):
"""
Experimental narrow per-channel signed int weight quantizer fragment with decoupled Linf
normalization and learned scaling.
Expand Down Expand Up @@ -333,7 +338,8 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveParameterScalingShape):
SolveParameterScalingShape,
SolveInputViewImpl):
"""Experimental narrow per-channel weight normalization-based signed integer quantizer
based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed
Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig.
Expand Down
15 changes: 14 additions & 1 deletion src/brevitas/quant/solver/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torch import nn
from torch import Tensor

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
from brevitas.core.quant import ClampedBinaryQuant
from brevitas.core.quant import RescalingIntQuant
from brevitas.core.quant import TernaryQuant
Expand Down Expand Up @@ -128,6 +130,16 @@ def update_state_dict_impl(scaling_impl_type):
return None


class SolveInputViewImpl(ExtendedInjector):

@value
def input_view_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.DYNAMIC_OVER_SUBCHANNEL_BLOCK
else:
return Identity


class ActQuantSolver(SolveActTensorQuantFromEnum,
SolveActScalingImplFromEnum,
SolveIntScalingImplFromEnum,
Expand All @@ -140,7 +152,8 @@ class ActQuantSolver(SolveActTensorQuantFromEnum,
SolveActScalingShape,
SolveScalingStatsInputViewShapeImplFromEnum,
SolveActScalingPerOutputChannelShape,
SolveUpdateStateDictImplFromEnum):
SolveUpdateStateDictImplFromEnum,
SolveInputViewImpl):
"""
Translate enum directives to activation-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
Expand Down
4 changes: 3 additions & 1 deletion src/brevitas/quant/solver/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from brevitas.proxy import BiasQuantProxyFromInjector
from brevitas.quant.solver.common import *
from brevitas.quant.solver.parameter import *
from brevitas.quant.solver.parameter import SolveInputViewImpl

__all__ = [
'BiasQuantSolver',
Expand Down Expand Up @@ -65,7 +66,8 @@ class BiasQuantSolver(SolveScalingStatsInputViewShapeImplFromEnum,
SolveBiasScalingPerOutputChannelShapeFromModule,
SolveBiasScalingStatsInputConcatDimFromModule,
SolveBiasTensorQuantFromEnum,
SolveDtypeDeviceFromTrackedParameterList):
SolveDtypeDeviceFromTrackedParameterList,
SolveInputViewImpl):
"""
Translate enum directives to bias-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
Expand Down
2 changes: 0 additions & 2 deletions src/brevitas/quant/solver/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,6 @@ def scaling_per_output(scaling_per_output_type=None, scaling_per_output_channel=
return ScalingPerOutputType.CHANNEL if scaling_per_output_channel else ScalingPerOutputType.TENSOR
elif scaling_per_output_type is not None:
return scaling_per_output_type
else:
raise RuntimeError("Specify scaling_per_output_type or scaling_per_output_channel")


class SolveScalingStatsInputViewShapeImplFromEnum(ExtendedInjector):
Expand Down
12 changes: 12 additions & 0 deletions src/brevitas/quant/solver/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from brevitas.core.bit_width import *
from brevitas.core.function_wrapper import TensorClamp
from brevitas.core.function_wrapper import TensorClampSte
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
from brevitas.core.scaling import *
from brevitas.core.scaling import ScalingImplType
from brevitas.core.scaling import ScalingPerOutputType
Expand Down Expand Up @@ -139,3 +141,13 @@ def expanded_scaling_shape(module, group_size=None):
def group_dim(module, group_size=None):
if group_size is not None:
return 1


class SolveInputViewImpl(ExtendedInjector):

@value
def input_view_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
else:
return Identity
4 changes: 3 additions & 1 deletion src/brevitas/quant/solver/weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.quant.solver.common import *
from brevitas.quant.solver.parameter import *
from brevitas.quant.solver.parameter import SolveInputViewImpl

__all__ = [
'SolveWeightTensorQuantFromEnum',
Expand Down Expand Up @@ -103,7 +104,8 @@ class WeightQuantSolver(SolveStatsReduceDimFromEnum,
SolveParameterScalingShape,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveWeightTensorQuantFromEnum,
SolveDtypeDeviceFromTrackedParameterList):
SolveDtypeDeviceFromTrackedParameterList,
SolveInputViewImpl):
"""
Translate enum and shape directives to weight-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
Expand Down
Loading

0 comments on commit 89fca2f

Please sign in to comment.