diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 7e2abd1ab..5c6c5afea 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -46,7 +46,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: @@ -132,7 +132,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: @@ -220,7 +220,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: