From 89fca2f56b57650e77b8e400f9e579c065186ccd Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 24 Aug 2024 12:17:22 +0200 Subject: [PATCH] Fix (mx): input view during quantization (#1005) --- notebooks/Brevitas_TVMCon2021.ipynb | 8 ++++++-- notebooks/minifloat_mx_tutorial.ipynb | 6 +++--- src/brevitas/core/function_wrapper/shape.py | 20 +++++++++++++++++++ src/brevitas/core/quant/float.py | 4 +++- src/brevitas/core/quant/int_base.py | 6 ++++++ src/brevitas/core/scaling/runtime.py | 14 +++---------- src/brevitas/export/onnx/qonnx/function.py | 2 ++ .../proxy/groupwise_float_parameter_quant.py | 6 ------ .../proxy/groupwise_float_runtime_quant.py | 1 - .../proxy/groupwise_int_parameter_quant.py | 1 - src/brevitas/quant/base.py | 10 ++++++++-- src/brevitas/quant/solver/act.py | 15 +++++++++++++- src/brevitas/quant/solver/bias.py | 4 +++- src/brevitas/quant/solver/common.py | 2 -- src/brevitas/quant/solver/parameter.py | 12 +++++++++++ src/brevitas/quant/solver/weight.py | 4 +++- .../groupwise_float_quant_tensor.py | 9 ++++++--- .../groupwise_int_quant_tensor.py | 8 +++++--- tests/brevitas/core/test_float_quant.py | 9 +++++++++ tests/brevitas/core/test_int_quant.py | 6 ++++-- 20 files changed, 107 insertions(+), 40 deletions(-) diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index 20ce30701..7f5846e09 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -1659,11 +1659,13 @@ "from brevitas.core.bit_width import BitWidthConst\n", "from brevitas.core.quant import IntQuant, RescalingIntQuant\n", "from brevitas.core.zero_point import ZeroZeroPoint\n", + "from brevitas.core.function_wrapper.misc import Identity\n", "\n", "tensor_quant = RescalingIntQuant(\n", " int_quant=IntQuant(\n", " float_to_int_impl=RoundSte(),\n", " tensor_clamp_impl=TensorClamp(),\n", + " input_view_impl=Identity,\n", " signed=False,\n", " narrow_range=False),\n", " zero_point_impl=ZeroZeroPoint(),\n", @@ -1767,6 +1769,7 @@ "from brevitas.inject import value\n", "from brevitas.proxy import WeightQuantProxyFromInjector\n", "from brevitas.core.scaling import ParameterScaling\n", + "from brevitas.core.function_wrapper.misc import Identity\n", "\n", "class Int8ActPerTensorFloatParameterFromScratch(ExtendedInjector):\n", " \n", @@ -1784,11 +1787,12 @@ " int_scaling_impl = IntScaling\n", " scaling_impl = ParameterScaling\n", " restrict_scaling_impl = FloatRestrictValue\n", + " input_view_impl = Identity\n", " scaling_shape = ()\n", " bit_width = 8\n", " narrow_range = True\n", " signed = True\n", - " \n", + "\n", "quant_linear = QuantLinear(2, 4, weight_quant=Int8ActPerTensorFloatParameterFromScratch, bias=False)" ] }, @@ -1936,7 +1940,7 @@ "torch.manual_seed(0)\n", "\n", "from brevitas.export import export_qonnx\n", - "from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat, Int16Bias\n", + "from brevitas.quant import Int8ActPerTensorFloat, Int16Bias\n", "\n", "float_inp = torch.randn(1, 2, 5)\n", "\n", diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index e764fd05c..60f00fcd4 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -233,12 +233,12 @@ "import brevitas.nn as qnn\n", "import torch\n", "\n", - "class MXFloat8Weight(MXInt8Weight):\n", + "class MXInt8Weight(MXInt8Weight):\n", " # The group dimension for the weights it is automatically identified based on the layer type\n", " # If a new layer type is used, it can be manually specified\n", " bit_width = 8\n", "\n", - "class MXFloat8Act(MXInt8Act):\n", + "class MXInt8Act(MXInt8Act):\n", " # It is necessary to specify the group dimension for the activation quantization\n", " group_dim = 1\n", " bit_width = 8\n", @@ -246,7 +246,7 @@ "class MXModel(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", - " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXInt8Weight, input_quant=MXInt8Act)\n", " \n", " def forward(self, x):\n", " return self.conv(x)\n", diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index cdef81b3e..1bb77476e 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -171,6 +171,25 @@ def forward(self, x: torch.Tensor): return y +class DynamicOverSubChannelBlockView(brevitas.jit.ScriptModule): + __constants__ = ['group_size', 'group_dim'] + + def __init__(self, group_size, group_dim) -> None: + super(DynamicOverSubChannelBlockView, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + + @brevitas.jit.script_method + def forward(self, x): + tensor_shape = x.shape + tensor_shape_list = list(tensor_shape) + tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) + block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list.insert(block_dim, self.group_size) + x = x.view(tensor_shape_list) + return x + + class StatsInputViewShapeImpl(object): """ Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor. @@ -182,3 +201,4 @@ class StatsInputViewShapeImpl(object): OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView OVER_OUTPUT_FEATURES = OverOutputFeaturesView OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView + DYNAMIC_OVER_SUBCHANNEL_BLOCK = DynamicOverSubChannelBlockView diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 195d42a96..f4fd79f1a 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -24,6 +24,7 @@ def __init__( mantissa_bit_width: int, exponent_bias: int, float_clamp_impl: nn.Module, + input_view_impl: nn.Module, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, float_to_int_impl: nn.Module = RoundSte(), @@ -52,6 +53,7 @@ def __init__( if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) + self.input_view_impl = input_view_impl # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl @@ -71,7 +73,7 @@ def quantize(self, x: torch.Tensor): float_scaling_impl_value = self.float_scaling_impl( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) scale = scale / float_scaling_impl_value - + x = self.input_view_impl(x) scaled_x = x / scale internal_scale = float_internal_scale( scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index 7a7a0f828..338e5a433 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -51,6 +51,7 @@ def __init__( self, narrow_range: bool, signed: bool, + input_view_impl: Module, float_to_int_impl: Module = RoundSte(), tensor_clamp_impl: Module = TensorClamp(), quant_delay_steps: int = 0): @@ -60,9 +61,11 @@ def __init__( self.signed = signed self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) + self.input_view_impl = input_view_impl @brevitas.jit.script_method def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: + x = self.input_view_impl(x) y = x / scale y = y + zero_point min_int_val = self.min_int(bit_width) @@ -124,6 +127,7 @@ def __init__( self, narrow_range: bool, signed: bool, + input_view_impl: Module, float_to_int_impl: Module = RoundSte(), tensor_clamp_impl: Module = TensorClamp(), quant_delay_steps: int = 0): @@ -133,11 +137,13 @@ def __init__( self.signed = signed self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) + self.input_view_impl = input_view_impl @brevitas.jit.script_method def to_int( self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: + x = self.input_view_impl(x) y = x / pre_scale y = y + pre_zero_point min_int_val = self.min_int(bit_width) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 23707344f..e4333186d 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -166,6 +166,7 @@ def __init__( self, group_size: int, group_dim: int, + input_view_impl: torch.nn.Module, scaling_stats_impl: torch.nn.Module, scaling_min_val: Optional[float], restrict_scaling_impl: Optional[torch.nn.Module]) -> None: @@ -174,21 +175,12 @@ def __init__( self.group_dim = group_dim self.scaling_stats_impl = scaling_stats_impl self.scaling_min_val = scaling_min_val + self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) - @brevitas.jit.script_method - def group_scaling_reshape(self, stats_input): - tensor_shape = stats_input.shape - tensor_shape_list = list(tensor_shape) - tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) - block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 - tensor_shape_list.insert(block_dim, self.group_size) - stats_input = stats_input.view(tensor_shape_list) - return stats_input - @brevitas.jit.script_method def forward(self, stats_input) -> torch.Tensor: - stats_input_reshaped = self.group_scaling_reshape(stats_input) + stats_input_reshaped = self.input_view_impl(stats_input) out = self.scaling_stats_impl(stats_input_reshaped) # Scaling min val out = self.restrict_clamp_scaling(out) diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index d410ee31e..3e7faad0e 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -7,6 +7,7 @@ from brevitas.core.bit_width import BitWidthConst from brevitas.core.function_wrapper.clamp import TensorClamp +from brevitas.core.function_wrapper.misc import Identity from brevitas.core.quant import IntQuant from brevitas.core.quant import TruncIntQuant from brevitas.function import binary_sign @@ -51,6 +52,7 @@ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding quant = IntQuant( float_to_int_impl=float_to_int_impl(), tensor_clamp_impl=TensorClamp(), + input_view_impl=Identity(), #TODO: Update this when QONNX support Groupwise export narrow_range=narrow_range, signed=signed) y = quant(scale, zero_point, bit_width, x) diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index cd38d9906..d08033f8e 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -9,11 +9,6 @@ class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # TODO: Is this always generated? - self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl - @property def group_dim(self): return self.quant_injector.group_dim @@ -25,7 +20,6 @@ def group_size(self): def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant - x = self.view_impl(x) out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) return GroupwiseFloatQuantTensor( out, diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 4ab182d20..b2aad4729 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -24,7 +24,6 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloat y = x if isinstance(y, QuantTensor): y = y.value - if self.export_mode: y = self.fused_activation_quant_proxy.activation_impl(y) y = self.export_handler(y) diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 035ee9729..35892daeb 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -25,7 +25,6 @@ def group_size(self): def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant - x = self.view_impl(x) out, scale, zero_point, bit_width = impl(x) return GroupwiseIntQuantTensor( out, diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 18351a05b..7b6fe409e 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -7,6 +7,7 @@ from brevitas.core.bit_width import BitWidthConst from brevitas.core.bit_width import BitWidthStatefulConst +from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import OverOutputChannelView from brevitas.core.function_wrapper import RoundToZeroSte from brevitas.core.function_wrapper import TensorClamp @@ -53,6 +54,7 @@ from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector from brevitas.quant.solver.common import SolveStatsReduceDimFromEnum +from brevitas.quant.solver.parameter import SolveInputViewImpl from brevitas.quant.solver.parameter import SolveParameterScalingShape from brevitas.quant.solver.weight import SolveWeightScalingPerOutputChannelShapeFromModule from brevitas.quant.solver.weight import SolveWeightScalingStatsInputDimsFromModule @@ -293,6 +295,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM stats_reduce_dim = SCALING_STATS_REDUCE_DIM restrict_scaling_impl = FloatRestrictValue scaling_shape = SCALAR_SHAPE + scaling_per_output_type = ScalingPerOutputType.TENSOR + input_view_impl = Identity scaling_impl = ParameterFromStatsFromParameterScaling int_scaling_impl = IntScaling zero_point_impl = ZeroZeroPoint @@ -305,7 +309,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, - SolveParameterScalingShape): + SolveParameterScalingShape, + SolveInputViewImpl): """ Experimental narrow per-channel signed int weight quantizer fragment with decoupled Linf normalization and learned scaling. @@ -333,7 +338,8 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, - SolveParameterScalingShape): + SolveParameterScalingShape, + SolveInputViewImpl): """Experimental narrow per-channel weight normalization-based signed integer quantizer based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig. diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index 3149e75b9..345239089 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -5,6 +5,8 @@ from torch import nn from torch import Tensor +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from brevitas.core.quant import ClampedBinaryQuant from brevitas.core.quant import RescalingIntQuant from brevitas.core.quant import TernaryQuant @@ -128,6 +130,16 @@ def update_state_dict_impl(scaling_impl_type): return None +class SolveInputViewImpl(ExtendedInjector): + + @value + def input_view_impl(scaling_per_output): + if scaling_per_output == ScalingPerOutputType.GROUP: + return StatsInputViewShapeImpl.DYNAMIC_OVER_SUBCHANNEL_BLOCK + else: + return Identity + + class ActQuantSolver(SolveActTensorQuantFromEnum, SolveActScalingImplFromEnum, SolveIntScalingImplFromEnum, @@ -140,7 +152,8 @@ class ActQuantSolver(SolveActTensorQuantFromEnum, SolveActScalingShape, SolveScalingStatsInputViewShapeImplFromEnum, SolveActScalingPerOutputChannelShape, - SolveUpdateStateDictImplFromEnum): + SolveUpdateStateDictImplFromEnum, + SolveInputViewImpl): """ Translate enum directives to activation-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, diff --git a/src/brevitas/quant/solver/bias.py b/src/brevitas/quant/solver/bias.py index 33a55f55c..eb840541e 100644 --- a/src/brevitas/quant/solver/bias.py +++ b/src/brevitas/quant/solver/bias.py @@ -9,6 +9,7 @@ from brevitas.proxy import BiasQuantProxyFromInjector from brevitas.quant.solver.common import * from brevitas.quant.solver.parameter import * +from brevitas.quant.solver.parameter import SolveInputViewImpl __all__ = [ 'BiasQuantSolver', @@ -65,7 +66,8 @@ class BiasQuantSolver(SolveScalingStatsInputViewShapeImplFromEnum, SolveBiasScalingPerOutputChannelShapeFromModule, SolveBiasScalingStatsInputConcatDimFromModule, SolveBiasTensorQuantFromEnum, - SolveDtypeDeviceFromTrackedParameterList): + SolveDtypeDeviceFromTrackedParameterList, + SolveInputViewImpl): """ Translate enum directives to bias-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 61eccc90b..2847275e8 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -195,8 +195,6 @@ def scaling_per_output(scaling_per_output_type=None, scaling_per_output_channel= return ScalingPerOutputType.CHANNEL if scaling_per_output_channel else ScalingPerOutputType.TENSOR elif scaling_per_output_type is not None: return scaling_per_output_type - else: - raise RuntimeError("Specify scaling_per_output_type or scaling_per_output_channel") class SolveScalingStatsInputViewShapeImplFromEnum(ExtendedInjector): diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index d8c655efa..97137c567 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -12,6 +12,8 @@ from brevitas.core.bit_width import * from brevitas.core.function_wrapper import TensorClamp from brevitas.core.function_wrapper import TensorClampSte +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from brevitas.core.scaling import * from brevitas.core.scaling import ScalingImplType from brevitas.core.scaling import ScalingPerOutputType @@ -139,3 +141,13 @@ def expanded_scaling_shape(module, group_size=None): def group_dim(module, group_size=None): if group_size is not None: return 1 + + +class SolveInputViewImpl(ExtendedInjector): + + @value + def input_view_impl(scaling_per_output): + if scaling_per_output == ScalingPerOutputType.GROUP: + return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK + else: + return Identity diff --git a/src/brevitas/quant/solver/weight.py b/src/brevitas/quant/solver/weight.py index 097f65443..7f63fe17e 100644 --- a/src/brevitas/quant/solver/weight.py +++ b/src/brevitas/quant/solver/weight.py @@ -10,6 +10,7 @@ from brevitas.proxy import WeightQuantProxyFromInjector from brevitas.quant.solver.common import * from brevitas.quant.solver.parameter import * +from brevitas.quant.solver.parameter import SolveInputViewImpl __all__ = [ 'SolveWeightTensorQuantFromEnum', @@ -103,7 +104,8 @@ class WeightQuantSolver(SolveStatsReduceDimFromEnum, SolveParameterScalingShape, SolveWeightScalingPerOutputChannelShapeFromModule, SolveWeightTensorQuantFromEnum, - SolveDtypeDeviceFromTrackedParameterList): + SolveDtypeDeviceFromTrackedParameterList, + SolveInputViewImpl): """ Translate enum and shape directives to weight-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 7d73bf7de..fa91bdca1 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -90,13 +90,15 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def expand(self): curr_shape = self.value_.shape - new_value = self.value_.flatten(self.group_dim, self.group_dim + 1) + start_dim = self.group_dim if self.group_dim != -1 else -2 + new_value = self.value_.flatten(start_dim, start_dim + 1) + new_value = self.value_.flatten(start_dim, start_dim + 1) if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_scale = self.scale_ if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_zp = self.zero_point_ @@ -104,6 +106,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): + group_dim = group_dim if group_dim != -1 else -2 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 976e86130..082ec1234 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -59,13 +59,14 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def expand(self): curr_shape = self.value_.shape - new_value = self.value_.flatten(self.group_dim, self.group_dim + 1) + start_dim = self.group_dim if self.group_dim != -1 else -2 + new_value = self.value_.flatten(start_dim, start_dim + 1) if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_scale = self.scale_ if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_zp = self.zero_point_ @@ -73,6 +74,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): + group_dim = group_dim if group_dim != -1 else -2 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 2d4c829f0..16b8a4b5f 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -9,6 +9,7 @@ from brevitas.core.function_wrapper import FloatClamp from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp +from brevitas.core.function_wrapper.misc import Identity from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling @@ -32,6 +33,7 @@ def test_float_quant_defaults(minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), float_clamp_impl=None) else: # init FloatClamp @@ -48,6 +50,7 @@ def test_float_quant_defaults(minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=float_clamp) assert isinstance(float_quant.float_to_int_impl, RoundSte) @@ -73,6 +76,7 @@ def test_float_to_quant_float(inp, minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=None) else: @@ -90,6 +94,7 @@ def test_float_to_quant_float(inp, minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=float_clamp) expected_out, *_ = float_quant(inp) @@ -115,6 +120,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=None) @@ -132,6 +138,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) @@ -162,6 +169,7 @@ def test_inner_scale(inp, minifloat_format, scale): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=None) @@ -179,6 +187,7 @@ def test_inner_scale(inp, minifloat_format, scale): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) diff --git a/tests/brevitas/core/test_int_quant.py b/tests/brevitas/core/test_int_quant.py index 312795235..5e106dc4c 100644 --- a/tests/brevitas/core/test_int_quant.py +++ b/tests/brevitas/core/test_int_quant.py @@ -5,6 +5,7 @@ import mock import torch +from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant import * @@ -30,6 +31,7 @@ def test_int_quant_to_int_called_with( int_quant = IntQuant( narrow_range=narrow_range, signed=signed, + input_view_impl=Identity(), float_to_int_impl=float_to_int_impl, tensor_clamp_impl=tensor_clamp_impl) bit_width = torch.tensor(bit_width_init) @@ -39,7 +41,7 @@ def test_int_quant_to_int_called_with( output, min_val=int_quant.min_int(bit_width), max_val=int_quant.max_int(bit_width)) def test_int_quant_defaults(self, narrow_range, signed): - int_quant = IntQuant(narrow_range=narrow_range, signed=signed) + int_quant = IntQuant(narrow_range=narrow_range, signed=signed, input_view_impl=Identity()) assert isinstance(int_quant.float_to_int_impl, RoundSte) assert isinstance(int_quant.tensor_clamp_impl, TensorClamp) @@ -51,7 +53,7 @@ def test_int_quant_arange( zero_point_init, bit_width_init, arange_int_tensor): - int_quant = IntQuant(narrow_range=narrow_range, signed=signed) + int_quant = IntQuant(narrow_range=narrow_range, signed=signed, input_view_impl=Identity()) zero_point = torch.tensor(zero_point_init).float() bit_width = torch.tensor(bit_width_init).float() scale = torch.tensor(standalone_scaling_init).float()