From 417dfcb6f68bd64fab1d8f6b3ae0097298c414c7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Aug 2024 07:47:36 +0100 Subject: [PATCH] Fix last 2 tests --- src/brevitas/quant/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index abcf6d4f0..7b6fe409e 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -7,6 +7,7 @@ from brevitas.core.bit_width import BitWidthConst from brevitas.core.bit_width import BitWidthStatefulConst +from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import OverOutputChannelView from brevitas.core.function_wrapper import RoundToZeroSte from brevitas.core.function_wrapper import TensorClamp @@ -294,6 +295,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM stats_reduce_dim = SCALING_STATS_REDUCE_DIM restrict_scaling_impl = FloatRestrictValue scaling_shape = SCALAR_SHAPE + scaling_per_output_type = ScalingPerOutputType.TENSOR + input_view_impl = Identity scaling_impl = ParameterFromStatsFromParameterScaling int_scaling_impl = IntScaling zero_point_impl = ZeroZeroPoint @@ -306,7 +309,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, - SolveParameterScalingShape): + SolveParameterScalingShape, + SolveInputViewImpl): """ Experimental narrow per-channel signed int weight quantizer fragment with decoupled Linf normalization and learned scaling.