diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py
index 2681d7407..a91b69525 100644
--- a/src/brevitas/nn/quant_scale_bias.py
+++ b/src/brevitas/nn/quant_scale_bias.py
@@ -86,20 +86,25 @@ def inner_forward_impl(
         quant_weight = quant_weight.view(self.runtime_shape)
         quant_bias = quant_bias.view(self.runtime_shape)
 
+        def biased_mul(input, weight, bias):
+            out = torch.mul(input, weight)
+            if bias is not None:
+                out += bias
+            return out
+
         # TODO: when implementing new types of QuantTensor, this should be revised
         if isinstance(input, QuantTensor):
             from brevitas.quant_tensor.torch_handler import quant_layer
+
             output_tensor = quant_layer(
-                torch.mul,
+                biased_mul,
                 input,
                 quant_weight,
-                bias=None,
+                bias=quant_bias,
                 external_acc_bit_width_fn=self.max_acc_bit_width)
         else:
-            output_tensor = torch.mul(input, quant_weight)
+            output_tensor = biased_mul(input, quant_weight, quant_bias)
 
-        if quant_bias is not None:
-            output_tensor += quant_bias
         return output_tensor
 
     def max_acc_bit_width(self, input_bit_width, weight_bit_width, weight_shape):