From e31ed767189ed4c46d1ad12e874b4b33e0aff7fa Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 22 Mar 2024 18:30:59 +0000
Subject: [PATCH] biased mul

---
 src/brevitas/nn/quant_scale_bias.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py
index 2681d7407..a91b69525 100644
--- a/src/brevitas/nn/quant_scale_bias.py
+++ b/src/brevitas/nn/quant_scale_bias.py
@@ -86,20 +86,25 @@ def inner_forward_impl(
         quant_weight = quant_weight.view(self.runtime_shape)
         quant_bias = quant_bias.view(self.runtime_shape)
 
+        def biased_mul(input, weight, bias):
+            out = torch.mul(input, weight)
+            if bias is not None:
+                out += bias
+            return out
+
         # TODO: when implementing new types of QuantTensor, this should be revised
         if isinstance(input, QuantTensor):
             from brevitas.quant_tensor.torch_handler import quant_layer
+
             output_tensor = quant_layer(
-                torch.mul,
+                biased_mul,
                 input,
                 quant_weight,
-                bias=None,
+                bias=quant_bias,
                 external_acc_bit_width_fn=self.max_acc_bit_width)
         else:
-            output_tensor = torch.mul(input, quant_weight)
+            output_tensor = biased_mul(input, quant_weight, quant_bias)
 
-        if quant_bias is not None:
-            output_tensor += quant_bias
         return output_tensor
 
     def max_acc_bit_width(self, input_bit_width, weight_bit_width, weight_shape):
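
The patch replaces the post-hoc bias add (previously applied to the output of quant_layer) with a biased_mul closure passed into quant_layer, so the bias is applied inside the quantized-op path rather than after it. Below is a minimal standalone sketch (not part of the patch) of the closure's semantics, using plain torch tensors with hypothetical shapes in place of Brevitas QuantTensors:

    import torch

    def biased_mul(input, weight, bias):
        # Fused elementwise multiply with an optional bias add,
        # mirroring the closure introduced in the patch.
        out = torch.mul(input, weight)
        if bias is not None:
            out += bias
        return out

    x = torch.randn(2, 4)        # batch of activations (hypothetical shape)
    w = torch.ones(1, 4) * 0.5   # per-channel scale, broadcast over the batch
    b = torch.full((1, 4), 2.0)  # per-channel bias

    assert torch.equal(biased_mul(x, w, b), x * 0.5 + 2.0)
    assert torch.equal(biased_mul(x, w, None), x * 0.5)  # bias is optional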