From 3f8a4fc47a403e060516a170910b230d74158c18 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 29 Aug 2024 18:58:00 +0100 Subject: [PATCH] Fix (nn/conv): Fixed conversion of convolutions when `padding_mode='same'` --- src/brevitas/nn/quant_conv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 7e2abd1ab..5c6c5afea 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -46,7 +46,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: @@ -132,7 +132,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: @@ -220,7 +220,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: