diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 06e181f31..fc4e75cb9 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -124,6 +124,11 @@ def get_cached(self, attr):
 
 class WeightQuantProxyFromInjector(WeightQuantProxyFromInjectorBase):
 
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self._cached_weight = None
+        self.cache_inference_quant_weight = False
+
     @property
     def tracked_parameter_list(self):
         return [m.weight for m in self.tracked_module_list if m.weight is not None]
@@ -152,11 +157,20 @@ def bit_width(self):
 
     def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
-            impl = self.export_handler if self.export_mode else self.tensor_quant
-            out, scale, zero_point, bit_width = impl(x)
-            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            if self._cached_weight is not None:
+                out = self._cached_weight.quant_tensor
+            else:
+                impl = self.export_handler if self.export_mode else self.tensor_quant
+                out, scale, zero_point, bit_width = impl(x)
+                out = IntQuantTensor(
+                    out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
-            return x
+            out = x
+        if isinstance(
+                out, IntQuantTensor
+        ) and not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
+            self._cached_weight = _CachedIO(out.detach(), metadata_only=False)
+        return out
 
 
 class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):
@@ -232,32 +246,6 @@ def forward(
 
 class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
 
-    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
-        super().__init__(quant_layer, quant_injector)
-        self._cached_bias = None
-        self.cache_inference_quant_bias = False
-
-    @property
-    def tracked_parameter_list(self):
-        return [m.bias for m in self.tracked_module_list if m.bias is not None]
-
-    @property
-    def requires_input_scale(self) -> bool:
-        if self.is_quant_enabled:
-            return self.quant_injector.requires_input_scale
-        else:
-            return False
-
-    def get_cached(self, attr):
-        if self._cached_bias is None:
-            warn(
-                "No quant bias cache found, set cache_inference_quant_bias=True and run an "
-                "inference pass first")
-            return None
-        if self.training:
-            warn("Cached quant bias scale is being used in training mode.")
-        return getattr(self._cached_bias, attr)
-
     def scale(self):
         if not self.is_quant_enabled:
             return None
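
Note on the change (not part of the diff): WeightQuantProxyFromInjector gains the same caching machinery that BiasQuantProxyFromInjector previously carried for biases, while the bias-specific __init__, tracked_parameter_list, requires_input_scale, and get_cached are deleted here, presumably hoisted into a shared base class outside this hunk. With caching enabled, the first eval-mode forward pass quantizes the weight and stores the resulting IntQuantTensor in _cached_weight; later passes return the cached tensor instead of re-running tensor_quant. A minimal usage sketch, assuming a stock brevitas QuantLinear whose weight_quant attribute is this proxy (the layer sizes, inputs, and default-enabled weight quantizer are assumptions, not taken from the diff):

    import torch
    from brevitas.nn import QuantLinear

    # Hypothetical layer; any quant layer exposing weight_quant should behave the same.
    layer = QuantLinear(16, 32, bias=False)
    layer.eval()  # caching is gated on `not self.training` in forward()

    # Opt in to weight caching; the flag name comes from the diff, default False.
    layer.weight_quant.cache_inference_quant_weight = True

    x = torch.randn(4, 16)
    with torch.no_grad():
        _ = layer(x)  # first pass: tensor_quant runs and _cached_weight is populated
        _ = layer(x)  # later passes reuse the cached IntQuantTensor

Because the cache is never invalidated, reuse is only safe when weights are frozen, which is what the `not self.training` guard in forward() enforces.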