From 9048ecb21909eda5e435e2310275118051a1fccc Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 8 Oct 2024 14:29:24 +0100
Subject: [PATCH] Feat (quant): decoupled PerChannel/PerTensor quantization (#1025)

---------

Co-authored-by: Ian Colbert
---
 src/brevitas/quant/base.py                    | 93 +++++++++++++++----
 tests/brevitas/export/quant_module_fixture.py | 25 ++++-
 tests/brevitas/export/test_qonnx_export.py    | 37 +++++++-
 tests/brevitas/nn/nn_quantizers_fixture.py    | 20 +++-
 tests/brevitas/nn/test_a2q.py                 | 13 ++-
 5 files changed, 166 insertions(+), 22 deletions(-)

diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py
index e1d118239..a77a10283 100644
--- a/src/brevitas/quant/base.py
+++ b/src/brevitas/quant/base.py
@@ -337,7 +337,62 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
     scaling_per_output_type = ScalingPerOutputType.CHANNEL
 
 
-class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
+class PerChannelL2Norm(ExtendedInjector):
+    stats_reduce_dim = SCALING_STATS_REDUCE_DIM
+    normalize_stats_impl = L2Norm
+
+
+class PerChannelL1Norm(ExtendedInjector):
+    stats_reduce_dim = SCALING_STATS_REDUCE_DIM
+    normalize_stats_impl = L1Norm
+
+
+class PerChannelPreNorm(ExtendedInjector):
+    pre_scaling_impl = ParameterPreScalingWeightNorm
+    scaling_stats_input_view_shape_impl = OverOutputChannelView
+    scaling_impl = (this << 1).scaling_impl
+    normalize_stats_impl = (this << 1).normalize_stats_impl
+    tracked_parameter_list = (this << 1).tracked_parameter_list
+    pre_scaling_shape = (this << 1).pre_scaling_shape
+    permute_dims = (this << 1).permute_dims
+
+
+class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm):
+
+    pre_scaling_impl = AccumulatorAwareParameterPreScaling
+    accumulator_bit_width = (this << 1).accumulator_bit_width
+    accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl
+
+
+class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm):
+
+    pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
+    pre_zero_point_impl = PreZeroCenterZeroPoint
+    pre_zero_point_shape = this.pre_scaling_shape  # TODO: decouple zero_point from scaling
+    pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
+    stats_reduce_dim = SCALING_STATS_REDUCE_DIM
+    scaling_shape = (this << 1).scaling_shape
+
+
+class SolvePostScaleGranularity(ExtendedInjector):
+
+    @value
+    def scaling_stats_input_view_shape_impl(scaling_per_output_type):
+        if scaling_per_output_type == ScalingPerOutputType.TENSOR:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+        elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+
+    @value
+    def stats_reduce_dim(scaling_per_output_type):
+        if scaling_per_output_type == ScalingPerOutputType.TENSOR:
+            return None
+        elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
+            return SCALING_STATS_REDUCE_DIM
+
+
+class WeightNormPerChannelFloatDecoupled(SolvePostScaleGranularity,
+                                         SolveStatsReduceDimFromEnum,
                                          SolveWeightScalingStatsInputDimsFromModule,
                                          SolveWeightScalingPerOutputChannelShapeFromModule,
                                          SolveParameterScalingShape,
@@ -361,6 +416,8 @@ def scaling_init(scaling_init_impl, bit_width):
         scales = scaling_init_impl.parameter_list_stats() / (pow(2., bit_width - 1.) - 1.)
         return scales
 
+    per_channel_pre_norm = PerChannelPreNorm
+
     proxy_class = DecoupledWeightQuantProxyFromInjector
     tensor_quant = DecoupledRescalingIntQuant
     decoupled_int_quant = DecoupledIntQuant
@@ -369,22 +426,23 @@ def scaling_init(scaling_init_impl, bit_width):
     scaling_init_impl = StatsFromParameterScaling
     restrict_scaling_impl = LogFloatRestrictValue
     scaling_stats_impl = AbsMax
-    pre_scaling_impl = ParameterPreScalingWeightNorm
     restrict_pre_scaling_impl = LogFloatRestrictValue
-    normalize_stats_impl = L2Norm
+    normalize_stats_impl = PerChannelL2Norm.normalize_stats_impl
     scaling_per_output_type = ScalingPerOutputType.CHANNEL
-    pre_scaling_shape = this.scaling_shape  # TODO: decouple pre_scaling_shape from scaling_shape
+    pre_scaling_shape = this.scaling_per_output_channel_shape
     int_scaling_impl = SingleArgStatelessBuffer(1.)
     zero_point_impl = ZeroZeroPoint
     pre_zero_point_impl = ZeroZeroPoint
     bit_width_impl = BitWidthConst
     narrow_range = True
     signed = True
-    scaling_stats_input_view_shape_impl = OverOutputChannelView
-    stats_reduce_dim = SCALING_STATS_REDUCE_DIM
     scaling_min_val = 1e-10
     pre_scaling_min_val = 1e-10
 
+    @value
+    def pre_scaling_impl():
+        return this.per_channel_pre_norm.pre_scaling_impl
+
 
 class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
     """Experimental accumulator-aware weight quantizer based on `Quantized Neural Networks
@@ -403,16 +461,16 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
     details on the arithmetic, see `AccumulatorAwareParameterPreScalingWeightNorm`. For further
     details on accumulator-aware quantization (A2Q) technique, see the referenced paper."""
 
-    @value
-    def accumulator_bit_width_impl(accumulator_bit_width):
-        return BitWidthStatefulConst(accumulator_bit_width)
-
     proxy_class = DecoupledWeightQuantWithInputProxyFromInjector
     tensor_quant = DecoupledRescalingIntQuantWithInput
-    pre_scaling_impl = AccumulatorAwareParameterPreScaling
-    accumulator_bit_width = 32  # default maximum accumulator width is 32 bits
-    normalize_stats_impl = L1Norm  # required to align with derivations in paper
+    per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm
+    normalize_stats_impl = PerChannelL1Norm.normalize_stats_impl  # required to align with derivations in paper
     float_to_int_impl = RoundToZeroSte  # required to ensure no upwards rounding violates constraints
+    accumulator_bit_width = 32  # default maximum accumulator width is 32 bits
+
+    @value
+    def accumulator_bit_width_impl(accumulator_bit_width):
+        return BitWidthStatefulConst(accumulator_bit_width)
 
 
 class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
@@ -423,10 +481,11 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
     (1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
     (2) a more relaxed l1-norm bound that is derived in the referenced paper
     """
-    pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
-    pre_zero_point_impl = PreZeroCenterZeroPoint
-    pre_zero_point_shape = this.scaling_shape  # TODO: decouple zero_point from scaling
-    pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
+    per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm
+
+    @value
+    def pre_zero_point_impl():
+        return this.per_channel_pre_norm.pre_zero_point_impl
 
 
 class MSESymmetricScaleSubInjector(ExtendedInjector):
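The net effect of the base.py changes is that WeightNormPerChannelFloatDecoupled and its accumulator-aware subclasses now resolve their post-scale granularity from scaling_per_output_type, while the weight-norm pre-scaling stays per channel. A minimal sketch of how a per-tensor variant could be derived and attached to a layer, mirroring the subclass defined in the test fixtures further down (layer sizes here are illustrative, not taken from the patch):

    from brevitas.inject.enum import ScalingPerOutputType
    from brevitas.nn import QuantConv2d
    from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant

    class Int8AccumulatorAwareWeightQuantPerTensorFloat(Int8AccumulatorAwareWeightQuant):
        # one post-scale factor for the whole tensor; the per-channel pre-norm
        # (the accumulator-aware l1-norm constraint) is left untouched
        scaling_per_output_type = ScalingPerOutputType.TENSOR

    # illustrative layer sizes
    conv = QuantConv2d(8, 16, kernel_size=3, weight_quant=Int8AccumulatorAwareWeightQuantPerTensorFloat)
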
diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py
index 31524729f..06d62da9c 100644
--- a/tests/brevitas/export/quant_module_fixture.py
+++ b/tests/brevitas/export/quant_module_fixture.py
@@ -7,6 +7,7 @@
 import torch
 from torch import nn
 
+from brevitas.inject.enum import ScalingPerOutputType
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantConv3d
@@ -20,6 +21,7 @@
 from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
 from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
 from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
+from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
 from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
@@ -39,6 +41,17 @@
 KERNEL_SIZE = 3
 TOLERANCE = 1
 
+
+class Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat(
+        Int8AccumulatorAwareZeroCenterWeightQuant):
+    scaling_per_output_type = ScalingPerOutputType.TENSOR
+
+
+A2Q_QUANTIZERS = {
+    'a2q_per_channel_float': (Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat),
+    'a2q_plus_per_tensor_float':
+        (Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat, Int8ActPerTensorFloat)}
+
 QUANTIZERS = {
     'asymmetric_per_tensor_float':
         (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat),
@@ -46,14 +59,15 @@
     'asymmetric_per_channel_float':
         (ShiftedUint8WeightPerChannelFloat, ShiftedUint8ActPerTensorFloat),
     'symmetric_per_channel_float': (Int8WeightPerChannelFloat, Int8ActPerTensorFloat),
-    'a2q': (Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat),
     'symmetric_per_tensor_fixed_point': (Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint),
     'symmetric_per_channel_fixed_point':
-        (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)}
+        (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint),
+    **A2Q_QUANTIZERS}
 
 BIAS_QUANTIZERS = {
     'bias_external_scale': (Int32Bias,),
     'bias_internal_scale': (Int8BiasPerTensorFloatInternalScaling,)}
+
 QUANT_WBIOL_IMPL = [
     QuantLinear,
     QuantConv1d,
@@ -62,6 +76,7 @@
     QuantConvTranspose1d,
     QuantConvTranspose2d,
     QuantConvTranspose3d,]
+
 BIT_WIDTHS = [4, 8, 10]  # below 8, equal 8, above 8
 BIAS_BIT_WIDTHS = [8, 16, 32]
 
@@ -102,6 +117,12 @@ def weight_act_quantizers(quantizers):
     return quantizers
 
 
+@fixture
+@parametrize('quantizers', A2Q_QUANTIZERS.items(), ids=list(A2Q_QUANTIZERS.keys()))
+def a2q_weight_act_quantizers(quantizers):
+    return quantizers
+
+
 @fixture
 @parametrize('quantizer', BIAS_QUANTIZERS.items(), ids=list(BIAS_QUANTIZERS.keys()))
 def bias_quantizer(quantizer):
diff --git a/tests/brevitas/export/test_qonnx_export.py b/tests/brevitas/export/test_qonnx_export.py
index 70329db66..59a4d6214 100644
--- a/tests/brevitas/export/test_qonnx_export.py
+++ b/tests/brevitas/export/test_qonnx_export.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import os
+
 import torch
 
 from brevitas.export import enable_debug
@@ -9,7 +11,6 @@
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
-from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
 from brevitas.quant.scaled_int import Int4WeightPerTensorFloatDecoupled
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
@@ -17,6 +18,8 @@
 from brevitas_examples import imagenet_classification
 from tests.marker import jit_disabled_for_export
 
+from .quant_module_fixture import *
+
 OUT_CH = 50
 IN_CH = 40
 TOLERANCE = 1.1
@@ -48,6 +51,7 @@ def forward(self, x):
     model(inp)  # collect scale factors
     model.eval()
     export_qonnx(model, inp, export_path='generic_quant_linear.onnx')
+    os.remove('generic_quant_linear.onnx')
 
 
 @jit_disabled_for_export()
@@ -79,6 +83,37 @@ def forward(self, x):
     export_qonnx(model, inp, export_path='generic_decoupled_quant_linear.onnx')
 
 
+@jit_disabled_for_export()
+def test_a2q_quant_linear_export(a2q_weight_act_quantizers):
+    IN_SIZE = (2, IN_CH)
+
+    _, (weight_quant, io_quant) = a2q_weight_act_quantizers
+
+    class Model(torch.nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.linear = QuantLinear(
+                out_features=OUT_CH,
+                in_features=IN_CH,
+                bias=True,
+                input_quant=io_quant,
+                output_quant=io_quant,
+                weight_quant=weight_quant,
+                bias_quant=Int16Bias,
+                return_quant_tensor=False)
+            self.linear.weight.data.uniform_(-0.1, 0.1)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    inp = torch.randn(IN_SIZE)
+    model = Model()
+    model(inp)  # collect scale factors
+    model.eval()
+    export_qonnx(model, inp, export_path='a2q_quant_linear.onnx')
+
+
 @jit_disabled_for_export()
 def test_generic_quant_conv_export():
     IN_SIZE = (2, IN_CH, IN_CH, IN_CH)
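A small asymmetry worth noting in the export tests above: test_generic_quant_linear_export now deletes its exported file, while the new test_a2q_quant_linear_export leaves a2q_quant_linear.onnx on disk. A possible follow-up cleanup in the same style, sketched here rather than taken from the diff:

    export_qonnx(model, inp, export_path='a2q_quant_linear.onnx')
    os.remove('a2q_quant_linear.onnx')  # cleanup mirroring the generic linear test above
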
diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py
index a6b1c05af..40893d7d7 100644
--- a/tests/brevitas/nn/nn_quantizers_fixture.py
+++ b/tests/brevitas/nn/nn_quantizers_fixture.py
@@ -11,6 +11,7 @@
 
 from brevitas import torch_version
 import brevitas.config as config
+from brevitas.inject.enum import ScalingPerOutputType
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantConv3d
@@ -48,6 +49,20 @@
 EMBED_DIM = 9
 NUM_HEADS = 3
 
+
+class Int8WeightNormL2PerChannelPerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint):
+    scaling_per_output_type = ScalingPerOutputType.TENSOR
+
+
+class Int8AccumulatorAwareWeightQuantPerTensorFloat(Int8AccumulatorAwareWeightQuant):
+    scaling_per_output_type = ScalingPerOutputType.TENSOR
+
+
+class Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat(
+        Int8AccumulatorAwareZeroCenterWeightQuant):
+    scaling_per_output_type = ScalingPerOutputType.TENSOR
+
+
 LSTM_WEIGHT_QUANTIZER = {
     'None': None,
     'quant_sym': Int8WeightPerTensorFloat,
@@ -55,13 +70,16 @@
 A2Q_WBIOL_WEIGHT_QUANTIZER = {
     'quant_a2q': Int8AccumulatorAwareWeightQuant,
-    'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant}
+    'quant_a2q_per_tensor': Int8AccumulatorAwareWeightQuantPerTensorFloat,
+    'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant,
+    'quant_a2q_plus_per_tensor': Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat}
 
 WBIOL_WEIGHT_QUANTIZER = {
     'None': None,
     'quant_sym': Int8WeightPerTensorFloat,
     'quant_asym': ShiftedUint8WeightPerTensorFloat,
     'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
+    'quant_decoupled_per_tensor': Int8WeightNormL2PerChannelPerTensorFixedPoint,
     'quant_mx':
         MXInt8Weight,
     'quant_float': Fp8e4m3WeightPerTensorFloat,
     **A2Q_WBIOL_WEIGHT_QUANTIZER}
diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py
index fa8dd701a..eb4fac1ed 100644
--- a/tests/brevitas/nn/test_a2q.py
+++ b/tests/brevitas/nn/test_a2q.py
@@ -63,7 +63,11 @@ def calc_a2q_plus_acc_bit_width(
     return min_bit_width
 
 
-calc_fnc = {"quant_a2q": calc_a2q_acc_bit_width, "quant_a2q_plus": calc_a2q_plus_acc_bit_width}
+calc_fnc = {
+    "quant_a2q": calc_a2q_acc_bit_width,
+    "quant_a2q_per_tensor": calc_a2q_acc_bit_width,
+    "quant_a2q_plus": calc_a2q_plus_acc_bit_width,
+    "quant_a2q_plus_per_tensor": calc_a2q_plus_acc_bit_width}
 
 
 @pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q)
@@ -94,6 +98,13 @@ def test_quant_wbiol_a2q(model_input, current_cases):
 
     # the tensor quantizer requires a QuantTensor with specified bit-width and sign
     quant_weight = model.conv.quant_weight(quant_input)
+
+    # test that the scaling factor is per-tensor or per-channel
+    if kwargs['weight_quant'].endswith('per_tensor'):
+        assert quant_weight.scale.numel() == 1
+    else:
+        assert quant_weight.scale.numel() == model.conv.out_channels
+
     quant_weight = quant_weight.int().float()
     if kwargs['model_type'] == 'QuantLinear':  # shape = (out_features, in_features)
         quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=1)
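The assertion added to test_a2q.py doubles as a quick recipe for checking which granularity a quantizer resolved to: the number of elements in the weight scale is either 1 (per-tensor) or the number of output channels (per-channel). A minimal sketch, assuming the default per-channel Int8AccumulatorAwareWeightQuant on a small QuantLinear; the feature sizes are illustrative:

    import torch
    from brevitas.nn import QuantIdentity
    from brevitas.nn import QuantLinear
    from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant

    act = QuantIdentity(return_quant_tensor=True)
    linear = QuantLinear(16, 8, bias=False, weight_quant=Int8AccumulatorAwareWeightQuant)  # illustrative sizes
    quant_input = act(torch.randn(2, 16))
    # A2Q weight quantizers derive their norm bound from the input, so quant_weight
    # takes the quantized input, as in test_quant_wbiol_a2q above
    scale = linear.quant_weight(quant_input).scale
    assert scale.numel() == linear.out_features  # per-channel post-scale by default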