Skip to content

Commit

Permalink
updated convtranspose method based on PR suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
costigt-dev committed Mar 6, 2024
1 parent 4c76e4f commit 5b94089
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/brevitas/nn/quant_convtranspose.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1)
max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
patch_size = (self.kernel_size[0] // self.stride[0])
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width

Expand Down Expand Up @@ -215,9 +215,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1)
overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1)
max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
patch_size = (self.kernel_size[0] //
self.stride[0]) * (self.kernel_size[1] // self.stride[1])
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width

Expand Down Expand Up @@ -313,9 +313,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1)
overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1)
overlapping_sums *= max(round(self.kernel_size[2] / self.stride[2]), 1)
max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
patch_size = (self.kernel_size[0] // self.stride[0]) * (
self.kernel_size[1] // self.stride[1]) * (self.kernel_size[2] // self.stride[2])
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width

0 comments on commit 5b94089

Please sign in to comment.