From 9fdd9a293d93457e4704d7a43f2d32b1ce63eaea Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sun, 25 Aug 2024 15:15:09 +0100
Subject: [PATCH] fix bias caching

---
 src/brevitas/proxy/parameter_quant.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 6a3e024e2..ec533351a 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -154,7 +154,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self._cached_bias = None
         self.cache_inference_quant_bias = False
         self.cache_inference_quant_bias_metadata_only = False
-        self.requires_input_scale = self.quant_injector.requires_input_scale and self.is_quant_enabled
+        self.requires_input_scale = self.quant_injector.requires_input_scale
 
     @property
     def tracked_parameter_list(self):
@@ -280,7 +280,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
     def scale(self):
         if not self.is_quant_enabled:
             return None
-        if self.requires_input_scale:
+        if self.requires_input_scale and self.is_quant_enabled:
             cache = self.get_cached('scale')
             return cache
         zhs = self._zero_hw_sentinel()
@@ -313,7 +313,7 @@ def compute_bias_scale(
             self,
             input: Optional[Union[Tensor, IntQuantTensor]],
             weight: Optional[Union[Tensor, IntQuantTensor]]) -> Optional[Tensor]:
-        if not self.requires_input_scale:
+        if not self.requires_input_scale and self.is_quant_enabled:
             return None
         if not isinstance(input, IntQuantTensor) or not isinstance(weight, IntQuantTensor):
             return None
@@ -339,14 +339,14 @@ def forward(
                 out = self._cached_bias.value
             else:
                 impl = self.export_handler if self.export_mode else self.tensor_quant
-                if self.requires_input_scale and input_scale is None:
+                if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
                     input_scale = self.scale()
                     if input_scale is None:
                         raise RuntimeError("Input scale required")
-                elif self.requires_input_scale and input_scale is not None:
+                elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled:
                     input_scale = input_scale.view(-1)
 
-                if self.requires_input_scale:
+                if self.requires_input_scale and self.is_quant_enabled:
                     out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
                 else:
                     out, out_scale, out_zp, out_bit_width = impl(x)