From da29d17f98296eea7261516e287a9cace390cda2 Mon Sep 17 00:00:00 2001
From: jfrery
Date: Mon, 3 Jun 2024 15:34:39 +0200
Subject: [PATCH] chore: fix batch norm per channel with clipping

---
 src/concrete/ml/quantization/quantized_ops.py | 43 +++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/src/concrete/ml/quantization/quantized_ops.py b/src/concrete/ml/quantization/quantized_ops.py
index a275e04a5..c28f19eb0 100644
--- a/src/concrete/ml/quantization/quantized_ops.py
+++ b/src/concrete/ml/quantization/quantized_ops.py
@@ -1612,6 +1612,49 @@ class QuantizedBatchNormalization(QuantizedOp):
 
     _impl_for_op_named: str = "BatchNormalization"
 
+    def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
+        """Create corresponding QuantizedArray for the output of the activation function.
+
+        Args:
+            *inputs (numpy.ndarray): Calibration sample inputs.
+
+        Returns:
+            numpy.ndarray: the output values for the provided calibration samples.
+        """
+
+        # Here we need the actual values of the constants, we need to pass through
+        # the numpy.ndarrays in the computation graph
+        prepared_inputs = self._prepare_inputs_with_constants(
+            *inputs, calibrate=True, quantize_actual_values=False
+        )
+
+        raw_result = self.call_impl(*prepared_inputs, **self.attrs)
+
+        if not isinstance(raw_result, RawOpOutput):
+            # Check if batch normalization is applied per channel
+            scale = self.constant_inputs[self._params_name_to_input_idx["scale"]].values
+            bias = self.constant_inputs[self._params_name_to_input_idx["bias"]].values
+            if scale.size > 1 or bias.size > 1:
+                # Per channel batchnorm struggles with low bit-width quantization.
+                # Batchnorm parameters (scale/bias/mean/var) can vary significantly
+                # between channels. Our tensor-based quantization, however, only supports
+                # a global scale and offset. Errors in quantized values can thus be severe.
+                # To mitigate this, we use percentiles to clip extreme values.
+
+                lower_bound = numpy.percentile(raw_result, 0.1)
+                upper_bound = numpy.percentile(raw_result, 99.9)
+
+                raw_result = numpy.clip(raw_result, lower_bound, upper_bound)
+
+            quantized_samples = QuantizedArray(self.n_bits, raw_result)
+
+            self.output_quant_params = quantized_samples.quantizer.quant_params
+            self.output_quant_stats = quantized_samples.quantizer.quant_stats
+
+            raw_result = quantized_samples.values
+
+        return raw_result
+
 
 class QuantizedFlatten(QuantizedOp):
     """Quantized flatten for encrypted inputs."""