chore: fix batch norm per channel with clipping
jfrery committed Jun 4, 2024
1 parent 1b5ce84 commit da29d17
Showing 1 changed file with 43 additions and 0 deletions.
src/concrete/ml/quantization/quantized_ops.py: 43 additions & 0 deletions
@@ -1612,6 +1612,49 @@ class QuantizedBatchNormalization(QuantizedOp):

    _impl_for_op_named: str = "BatchNormalization"

    def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
        """Create the corresponding QuantizedArray for the output of this operator.

        Args:
            *inputs (numpy.ndarray): Calibration sample inputs.

        Returns:
            numpy.ndarray: The output values for the provided calibration samples.
        """

        # Here we need the actual values of the constants, so we pass through
        # the numpy.ndarrays in the computation graph
        prepared_inputs = self._prepare_inputs_with_constants(
            *inputs, calibrate=True, quantize_actual_values=False
        )

        raw_result = self.call_impl(*prepared_inputs, **self.attrs)

        if not isinstance(raw_result, RawOpOutput):
            # Check if batch normalization is applied per channel
            scale = self.constant_inputs[self._params_name_to_input_idx["scale"]].values
            bias = self.constant_inputs[self._params_name_to_input_idx["bias"]].values
            if scale.size > 1 or bias.size > 1:
                # Per-channel batch norm struggles with low bit-width quantization.
                # Batch norm parameters (scale/bias/mean/var) can vary significantly
                # between channels, while our tensor-based quantization only supports
                # a global scale and offset, so errors in quantized values can be
                # severe. To mitigate this, we use percentiles to clip extreme values.
                lower_bound = numpy.percentile(raw_result, 0.1)
                upper_bound = numpy.percentile(raw_result, 99.9)

                raw_result = numpy.clip(raw_result, lower_bound, upper_bound)

            quantized_samples = QuantizedArray(self.n_bits, raw_result)

            self.output_quant_params = quantized_samples.quantizer.quant_params
            self.output_quant_stats = quantized_samples.quantizer.quant_stats

            raw_result = quantized_samples.values

        return raw_result


class QuantizedFlatten(QuantizedOp):
    """Quantized flatten for encrypted inputs."""
