fix caching bias
Giuseppe5 committed Aug 25, 2024
1 parent 9fdd9a2 commit b67534b
Showing 1 changed file with 18 additions and 22 deletions.
src/brevitas/proxy/parameter_quant.py: 40 changes (18 additions, 22 deletions)
```diff
@@ -335,29 +335,25 @@ def forward(
         out = x
         input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
-            if self._cached_bias is not None and not self.cache_inference_quant_bias_metadata_only:
-                out = self._cached_bias.value
+            impl = self.export_handler if self.export_mode else self.tensor_quant
+            if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
+                input_scale = self.scale()
+                if input_scale is None:
+                    raise RuntimeError("Input scale required")
+            elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled:
+                input_scale = input_scale.view(-1)
+
+            if self.requires_input_scale and self.is_quant_enabled:
+                out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
             else:
-                impl = self.export_handler if self.export_mode else self.tensor_quant
-                if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
-                    input_scale = self.scale()
-                    if input_scale is None:
-                        raise RuntimeError("Input scale required")
-                elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled:
-                    input_scale = input_scale.view(-1)
-
-                if self.requires_input_scale and self.is_quant_enabled:
-                    out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
-                else:
-                    out, out_scale, out_zp, out_bit_width = impl(x)
-                if not is_dynamo_compiling():
-                    out = IntQuantTensor(
-                        out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
-                    if not self.training and self.cache_inference_quant_bias and self._cached_bias is not None:
-                        cached_bias = _CachedIO(
-                            out.detach(),
-                            metadata_only=self.cache_inference_quant_bias_metadata_only)
-                        self._cached_bias = cached_bias
+                out, out_scale, out_zp, out_bit_width = impl(x)
+            if not is_dynamo_compiling():
+                out = IntQuantTensor(
+                    out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
+                if not self.training and self.cache_inference_quant_bias:
+                    cached_bias = _CachedIO(
+                        out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only)
+                    self._cached_bias = cached_bias
         else:
             out = x
         return out
```
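As far as the diff shows, the bug was twofold. The old read path returned `self._cached_bias.value` whenever a cached entry existed and metadata-only caching was off, skipping requantization entirely; the old write path guarded the cache update with `self._cached_bias is not None`, so a cache that starts out empty could never be populated in the first place. The new code always requantizes and then caches whenever `cache_inference_quant_bias` is set at inference time. The sketch below isolates that before/after logic; the class, its attributes, and `_quantize` are hypothetical stand-ins for illustration, not the real Brevitas proxy API.

```python
# Minimal sketch of the caching logic before and after this commit.
# Names mirror the diff, but this is not the actual Brevitas class.


class BiasCachingSketch:

    def __init__(self, cache_inference_quant_bias=True, metadata_only=False):
        self.training = False  # eval mode, where caching applies
        self.cache_inference_quant_bias = cache_inference_quant_bias
        self.cache_inference_quant_bias_metadata_only = metadata_only
        self._cached_bias = None

    def _quantize(self, x):
        return x  # placeholder for tensor_quant / export_handler

    def forward_before(self, x):
        # Old read path: short-circuit to the cached value when present.
        if self._cached_bias is not None and not self.cache_inference_quant_bias_metadata_only:
            return self._cached_bias
        out = self._quantize(x)
        # Old write path: requires _cached_bias to already be non-None,
        # so an initially empty cache is never filled and the read path
        # above is effectively dead code.
        if not self.training and self.cache_inference_quant_bias and self._cached_bias is not None:
            self._cached_bias = out
        return out

    def forward_after(self, x):
        # New behavior: always requantize, then cache whenever caching
        # is enabled at inference time.
        out = self._quantize(x)
        if not self.training and self.cache_inference_quant_bias:
            self._cached_bias = out
        return out


if __name__ == "__main__":
    sketch = BiasCachingSketch()
    sketch.forward_before(1.0)
    assert sketch._cached_bias is None  # old logic: cache never populated
    sketch.forward_after(1.0)
    assert sketch._cached_bias is not None  # fixed logic: cache filled
```

With `forward_before`, `_cached_bias` stays `None` across calls, so enabling `cache_inference_quant_bias` had no visible effect; with `forward_after`, the first inference call populates the cache. The early-return branch and the extra `is not None` guard are exactly what the 22 deleted lines remove.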
