Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 22, 2024
1 parent 486db7f commit 92ac53b
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def _avg_scaling(self):

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

if self.export_mode:
return self.export_handler(_unpack_quant_tensor(x))

if isinstance(x, QuantTensor):
x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
if self.is_trunc_quant_enabled:
Expand All @@ -67,7 +69,9 @@ def forward(self, input: Union[Tensor, QuantTensor]):
x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
x = self.trunc_quant(x)
else:
assert not self.is_trunc_quant_enabled
x = super(TruncAvgPool2d, self).forward(x)

return self.pack_output(x)

def max_acc_bit_width(self, input_bit_width):
Expand Down Expand Up @@ -130,14 +134,17 @@ def compute_kernel_size_stride(self, input_shape, output_shape):

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(x))
self._set_global_is_quant_layer(False)
return out

if self.cache_kernel_size_stride:
self._cached_kernel_size = k_size
self._cached_kernel_stride = stride

if isinstance(x, QuantTensor):
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
Expand All @@ -148,7 +155,9 @@ def forward(self, input: Union[Tensor, QuantTensor]):
y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
y = self.trunc_quant(y)
else:
assert not self.is_trunc_quant_enabled
y = super(TruncAdaptiveAvgPool2d, self).forward(x)

return self.pack_output(y)

def max_acc_bit_width(self, input_bit_width, reduce_size):
Expand Down

0 comments on commit 92ac53b

Please sign in to comment.