From 368d0a37c7c537f2e02a6a26d5403af6ad91bd85 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 19 Mar 2024 17:22:58 +0000
Subject: [PATCH] Biased mul fn

---
 src/brevitas/nn/quant_scale_bias.py        | 15 ++++++++++-----
 src/brevitas/quant_tensor/torch_handler.py |  2 +-
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py
index 2681d7407..a91b69525 100644
--- a/src/brevitas/nn/quant_scale_bias.py
+++ b/src/brevitas/nn/quant_scale_bias.py
@@ -86,20 +86,25 @@ def inner_forward_impl(
         quant_weight = quant_weight.view(self.runtime_shape)
         quant_bias = quant_bias.view(self.runtime_shape)
 
+        def biased_mul(input, weight, bias):
+            out = torch.mul(input, weight)
+            if bias is not None:
+                out += bias
+            return out
+
         # TODO: when implementing new types of QuantTensor, this should be revised
         if isinstance(input, QuantTensor):
             from brevitas.quant_tensor.torch_handler import quant_layer
+
             output_tensor = quant_layer(
-                torch.mul,
+                biased_mul,
                 input,
                 quant_weight,
-                bias=None,
+                bias=quant_bias,
                 external_acc_bit_width_fn=self.max_acc_bit_width)
         else:
-            output_tensor = torch.mul(input, quant_weight)
+            output_tensor = biased_mul(input, quant_weight, quant_bias)
 
-        if quant_bias is not None:
-            output_tensor += quant_bias
         return output_tensor
 
     def max_acc_bit_width(self, input_bit_width, weight_bit_width, weight_shape):
diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py
index 2b0408948..7c1e8aaf4 100644
--- a/src/brevitas/quant_tensor/torch_handler.py
+++ b/src/brevitas/quant_tensor/torch_handler.py
@@ -246,7 +246,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape):
     rescaled_value = x * reduce_size  # remove avg scaling
     quant_input = quant_input.set(value=rescaled_value)
 
-    quant_input = quant_input.set(bit_width=max_acc_bit_width(x.bit_width, reduce_size))
+    quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size))
     return quant_input
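
Note: the substance of the first hunk is that the bias add now travels through
quant_layer together with the multiply, rather than being applied to the
already-wrapped output afterwards. Below is a minimal standalone sketch of the
biased_mul helper with plain torch.Tensor inputs, so the closure's semantics
can be sanity-checked in isolation; the QuantTensor/quant_layer dispatch path
is not reproduced, and the shapes are illustrative stand-ins for
self.runtime_shape, not values taken from the patch.

import torch

def biased_mul(input, weight, bias):
    # Elementwise scale; the optional bias is folded into the same callable,
    # so a functional handler sees mul + add as a single fused unit.
    out = torch.mul(input, weight)
    if bias is not None:
        out += bias
    return out

# Hypothetical shapes: a (1, C, 1, 1) weight/bias broadcasting over NCHW input.
x = torch.randn(2, 3, 4, 4)
w = torch.randn(1, 3, 1, 1)
b = torch.randn(1, 3, 1, 1)
assert torch.allclose(biased_mul(x, w, b), x * w + b)
assert torch.allclose(biased_mul(x, w, None), x * w)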