Skip to content

Commit

Permalink
Fix (nn/conv): Fixed conversion of convolutions when `padding_mode='s…
Browse files Browse the repository at this point in the history
…ame'`
  • Loading branch information
nickfraser committed Aug 29, 2024
1 parent 21537ef commit 3f8a4fc
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/brevitas/nn/quant_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
if padding_mode == 'zeros' and padding == 'same' and stride > 1:
if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
padding = 0
is_same_padded_strided = True
else:
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
if padding_mode == 'zeros' and padding == 'same' and stride > 1:
if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
padding = 0
is_same_padded_strided = True
else:
Expand Down Expand Up @@ -220,7 +220,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
if padding_mode == 'zeros' and padding == 'same' and stride > 1:
if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
padding = 0
is_same_padded_strided = True
else:
Expand Down

0 comments on commit 3f8a4fc

Please sign in to comment.