Skip to content

Commit

Permalink
fix bias caching
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 25, 2024
1 parent 4fac083 commit 9fdd9a2
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
self._cached_bias = None
self.cache_inference_quant_bias = False
self.cache_inference_quant_bias_metadata_only = False
self.requires_input_scale = self.quant_injector.requires_input_scale and self.is_quant_enabled
self.requires_input_scale = self.quant_injector.requires_input_scale

@property
def tracked_parameter_list(self):
Expand Down Expand Up @@ -280,7 +280,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
def scale(self):
if not self.is_quant_enabled:
return None
if self.requires_input_scale:
if self.requires_input_scale and self.is_quant_enabled:
cache = self.get_cached('scale')
return cache
zhs = self._zero_hw_sentinel()
Expand Down Expand Up @@ -313,7 +313,7 @@ def compute_bias_scale(
self,
input: Optional[Union[Tensor, IntQuantTensor]],
weight: Optional[Union[Tensor, IntQuantTensor]]) -> Optional[Tensor]:
if not self.requires_input_scale:
if not self.requires_input_scale and self.is_quant_enabled:
return None
if not isinstance(input, IntQuantTensor) or not isinstance(weight, IntQuantTensor):
return None
Expand All @@ -339,14 +339,14 @@ def forward(
out = self._cached_bias.value
else:
impl = self.export_handler if self.export_mode else self.tensor_quant
if self.requires_input_scale and input_scale is None:
if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
input_scale = self.scale()
if input_scale is None:
raise RuntimeError("Input scale required")
elif self.requires_input_scale and input_scale is not None:
elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled:
input_scale = input_scale.view(-1)

if self.requires_input_scale:
if self.requires_input_scale and self.is_quant_enabled:
out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
else:
out, out_scale, out_zp, out_bit_width = impl(x)
Expand Down

0 comments on commit 9fdd9a2

Please sign in to comment.