diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 604a43c00..1f9cc62b6 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -125,8 +125,6 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: self._cached_weight = self.cache_class( out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) - else: - out = out[0] else: # quantization disabled out = x return out