From 5b94089baa50fc6ee690bb7722c0b732630ca218 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:38:44 +0000 Subject: [PATCH] updated convtranspose method based on PR suggestion --- src/brevitas/nn/quant_convtranspose.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index 0a031b517..ab7e1cafe 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -118,8 +118,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = (self.kernel_size[0] // self.stride[0]) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -215,9 +215,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = (self.kernel_size[0] // + self.stride[0]) * (self.kernel_size[1] // self.stride[1]) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -313,9 +313,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1) - overlapping_sums *= max(round(self.kernel_size[2] / self.stride[2]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = (self.kernel_size[0] // self.stride[0]) * ( + self.kernel_size[1] // self.stride[1]) * (self.kernel_size[2] // self.stride[2]) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width