diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index ec533351a..5cf1e0372 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -335,29 +335,25 @@ def forward(
         out = x
         input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
-            if self._cached_bias is not None and not self.cache_inference_quant_bias_metadata_only:
-                out = self._cached_bias.value
+            impl = self.export_handler if self.export_mode else self.tensor_quant
+            if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
+                input_scale = self.scale()
+                if input_scale is None:
+                    raise RuntimeError("Input scale required")
+            elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled:
+                input_scale = input_scale.view(-1)
+
+            if self.requires_input_scale and self.is_quant_enabled:
+                out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
             else:
-                impl = self.export_handler if self.export_mode else self.tensor_quant
-                if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
-                    input_scale = self.scale()
-                    if input_scale is None:
-                        raise RuntimeError("Input scale required")
-                elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled:
-                    input_scale = input_scale.view(-1)
-
-                if self.requires_input_scale and self.is_quant_enabled:
-                    out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
-                else:
-                    out, out_scale, out_zp, out_bit_width = impl(x)
-                if not is_dynamo_compiling():
-                    out = IntQuantTensor(
-                        out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
-                    if not self.training and self.cache_inference_quant_bias and self._cached_bias is not None:
-                        cached_bias = _CachedIO(
-                            out.detach(),
-                            metadata_only=self.cache_inference_quant_bias_metadata_only)
-                        self._cached_bias = cached_bias
+                out, out_scale, out_zp, out_bit_width = impl(x)
+            if not is_dynamo_compiling():
+                out = IntQuantTensor(
+                    out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
+                if not self.training and self.cache_inference_quant_bias:
+                    cached_bias = _CachedIO(
+                        out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only)
+                    self._cached_bias = cached_bias
         else:
             out = x
         return out
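
For reviewers, the behavioral change at a glance: before this patch, an eval-mode call with a populated, non-metadata-only cache returned self._cached_bias.value without re-quantizing; after it, the bias is quantized on every call and the cache is refreshed afterwards. Below is a minimal, self-contained sketch of the new control flow. It assumes requires_input_scale is True, omits the export_handler / is_dynamo_compiling() / IntQuantTensor plumbing, and uses hypothetical stand-ins (BiasQuantSketch, _CachedBias, a toy 8-bit quantizer) rather than the real brevitas.proxy.parameter_quant API.

    from dataclasses import dataclass
    from typing import Optional, Tuple

    import torch
    from torch import Tensor


    @dataclass
    class _CachedBias:
        # Hypothetical stand-in for brevitas' _CachedIO wrapper.
        value: Tensor
        metadata_only: bool


    class BiasQuantSketch:
        # Hypothetical stand-in for the bias quant proxy, with
        # requires_input_scale assumed True throughout.

        def __init__(self) -> None:
            self.is_quant_enabled = True
            self.training = False
            self.cache_inference_quant_bias = True
            self.cache_inference_quant_bias_metadata_only = False
            self._cached_bias: Optional[_CachedBias] = None

        def tensor_quant(self, x: Tensor,
                         scale: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
            # Toy signed 8-bit quantizer standing in for the injected tensor_quant.
            zero_point = torch.zeros(())
            bit_width = torch.tensor(8.)
            out = torch.clamp(torch.round(x / scale), -128, 127) * scale
            return out, scale, zero_point, bit_width

        def forward(self, x: Tensor, input_scale: Optional[Tensor] = None) -> Tensor:
            if not self.is_quant_enabled:
                return x
            if input_scale is None:
                # The real proxy first falls back to self.scale() before raising.
                raise RuntimeError("Input scale required")
            out, scale, zero_point, bit_width = self.tensor_quant(x, input_scale.view(-1))
            # Key change: quantization always runs; the cache is refreshed on every
            # eval-mode call instead of short-circuiting to self._cached_bias.value.
            if not self.training and self.cache_inference_quant_bias:
                self._cached_bias = _CachedBias(
                    out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only)
            return out


    bias = torch.randn(4)
    proxy = BiasQuantSketch()
    q1 = proxy.forward(bias, input_scale=torch.tensor(0.1))
    q2 = proxy.forward(bias, input_scale=torch.tensor(0.1))  # re-quantizes, refreshes cache
    assert torch.equal(q1, q2)

Under these assumptions the second forward() call produces the same tensor but recomputes it, whereas the pre-patch code would have returned the cached value directly; the patch also drops the `self._cached_bias is not None` term from the caching condition, so the cache is written on the first eval-mode call as well.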