Skip to content

Commit c096d7b

Browse files
authored
Fix (nn/conv): Fixed conversion of convolutions when padding_mode='same' (#1017)
1 parent efc29fc commit c096d7b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/brevitas/nn/quant_conv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
dtype: Optional[torch.dtype] = None,
4747
**kwargs) -> None:
4848
# avoid an init error in the super class by setting padding to 0
49-
if padding_mode == 'zeros' and padding == 'same' and stride > 1:
49+
if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
5050
padding = 0
5151
is_same_padded_strided = True
5252
else:
@@ -132,7 +132,7 @@ def __init__(
132132
dtype: Optional[torch.dtype] = None,
133133
**kwargs) -> None:
134134
# avoid an init error in the super class by setting padding to 0
135-
if padding_mode == 'zeros' and padding == 'same' and stride > 1:
135+
if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
136136
padding = 0
137137
is_same_padded_strided = True
138138
else:
@@ -220,7 +220,7 @@ def __init__(
220220
dtype: Optional[torch.dtype] = None,
221221
**kwargs) -> None:
222222
# avoid an init error in the super class by setting padding to 0
223-
if padding_mode == 'zeros' and padding == 'same' and stride > 1:
223+
if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
224224
padding = 0
225225
is_same_padded_strided = True
226226
else:

0 commit comments

Comments
 (0)