diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 5198444b1..4917b859a 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -229,7 +229,7 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
-                return self.stats_scaling_impl(stats)
+                return self.stats_scaling_impl(stats, threshold)
             stats = self.restrict_inplace_preprocess(stats)
             threshold = self.restrict_inplace_preprocess(threshold)
             inplace_tensor_mul(self.value.detach(), stats)
diff --git a/tests/brevitas_examples/test_quantize_model.py b/tests/brevitas_examples/test_quantize_model.py
index 8ab34d5db..ea537310a 100644
--- a/tests/brevitas_examples/test_quantize_model.py
+++ b/tests/brevitas_examples/test_quantize_model.py
@@ -654,6 +654,69 @@ def get_qmse(
     assert torch.isclose(diff_mse, orig_mse) or (diff_mse > orig_mse)
 
 
+@pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"])
+@jit_disabled_for_local_loss()
+def test_layerwise_stats_vs_mse(simple_model, quant_granularity):
+    """
+    We test layerwise quantization with the `mse` parameter method for weight
+    and activation quantization.
+
+    We test:
+    - Reconstruction error of MSE should be smaller than or equal to stats
+    """
+    weight_bit_width = 8
+    act_bit_width = 8
+    bias_bit_width = 32
+    quant_model_mse = quantize_model(
+        model=deepcopy(simple_model),
+        backend='layerwise',
+        weight_bit_width=weight_bit_width,
+        act_bit_width=act_bit_width,
+        bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
+        weight_quant_granularity=quant_granularity,
+        act_quant_type='asym',
+        act_quant_percentile=99.9,  # Unused
+        scale_factor_type='float_scale',
+        quant_format='int',
+        weight_param_method='mse',
+        act_param_method='mse')
+
+    quant_model_stats = quantize_model(
+        model=deepcopy(simple_model),
+        backend='layerwise',
+        weight_bit_width=weight_bit_width,
+        act_bit_width=act_bit_width,
+        bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
+        weight_quant_granularity=quant_granularity,
+        act_quant_type='asym',
+        act_quant_percentile=99.9,  # Unused
+        scale_factor_type='float_scale',
+        quant_format='int',
+        weight_param_method='stats',
+        act_param_method='mse')
+
+    # We create an input with values linearly scaled between 0 and 1.
+    input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
+    input = input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()
+    with torch.no_grad():
+        with calibration_mode(quant_model_mse):
+            quant_model_mse(input)
+    quant_model_mse.eval()
+    with torch.no_grad():
+        with calibration_mode(quant_model_stats):
+            quant_model_stats(input)
+    quant_model_stats.eval()
+    weight = simple_model.layers.get_submodule('0').weight
+    first_conv_layer_mse = quant_model_mse.layers.get_submodule('0')
+    first_conv_layer_stats = quant_model_stats.layers.get_submodule('0')
+
+    l2_stats = ((weight - first_conv_layer_stats.quant_weight().value) ** 2).sum()
+    l2_mse = ((weight - first_conv_layer_mse.quant_weight().value) ** 2).sum()
+
+    # Reconstruction error of MSE should be smaller than or equal to stats
+    assert l2_mse - l2_stats <= torch.tensor(1e-5)
+
+
 @pytest.mark.parametrize("weight_bit_width", [2, 5, 8, 16])
 @pytest.mark.parametrize("act_bit_width", [2, 5, 8])
 @pytest.mark.parametrize("bias_bit_width", [16, 32])
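
For reviewers, a minimal standalone sketch of the behavior the first hunk fixes. The helper `stats_scaling` below is illustrative only, not the Brevitas API: it stands in for `stats_scaling_impl`, which on the non-local code path divides the collected statistic by `threshold`. Before this change, the `local_loss_mode` early return dropped the `threshold` argument, so the scale was effectively computed with a threshold of 1; the fix forwards `threshold` so both paths agree.

```python
# Minimal sketch, assuming a stand-in for stats_scaling_impl (names here are
# illustrative, not the Brevitas implementation).
from typing import Optional

import torch
from torch import Tensor


def stats_scaling(stats: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
    if threshold is None:
        # Omitting the argument silently falls back to a threshold of 1,
        # which is the pre-fix behavior in local_loss_mode.
        threshold = torch.ones(1)
    # The real impl also restricts/clamps the scale; the division is the
    # part that the dropped argument affected.
    return stats / threshold


stats = torch.tensor([4.0])
threshold = torch.tensor([2.0])

old_scale = stats_scaling(stats)             # old call site: threshold ignored
new_scale = stats_scaling(stats, threshold)  # fixed call site: threshold applied
assert old_scale.item() == 4.0 and new_scale.item() == 2.0
```

The new test exercises exactly this path: `jit_disabled_for_local_loss()` forces the `local_loss_mode` branch during the MSE parameter search, then checks that MSE-calibrated weights reconstruct at least as well (in L2) as stats-calibrated ones.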