
Commit 69d7eff

fix(ops): Fix inconsistent padding calculation in PyTorch backend ops (#20774)
* Fix "same" padding torch issue * format * fix type * add condition for channels first and last * fix(ops): Fix inconsistent padding calculation in PyTorch backend ops Was able to still reproduce the error, the PyTorch backend had inconsistent behavior between static shape inference and dynamic execution for pooling operations, particularly with 'same' padding and non-unit strides, figured that the root cause was by incorrect padding calculation logic that didn't properly handle asymmetric padding cases. Key changes: - Rewrote _compute_padding_length() to handle stride-based padding - Fixed padding calculation to properly support asymmetric padding cases - Standardize channels_first/channels_last conversion in pooling ops - Cleaned up padding application in _apply_same_padding() - Added proper handling of data_format throughout pooling pipeline This fixes the issue where MaxPooling2D with 'same' padding would produce different shapes between compute_output_shape() and actual execution (e.g. (1,5,2,2) vs (1,5,2,1)). Rebased on top of Sachin's September 2024 PR to incorporate latest keras:master changes. --------- Co-authored-by: sachin prasad <[email protected]>
1 parent dca9e61 commit 69d7eff
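For context, the mismatch described in the commit message can be observed by comparing a pooling layer's static shape inference against its actual output. A minimal sketch follows; the pool settings and input shape are illustrative assumptions, not the exact reproduction from the report:

    # Minimal sketch: static shape inference vs. dynamic execution for
    # MaxPooling2D with "same" padding and non-unit strides on the torch backend.
    import numpy as np
    import keras

    layer = keras.layers.MaxPooling2D(
        pool_size=(2, 3),
        strides=(3, 3),
        padding="same",
        data_format="channels_first",
    )
    x = np.random.rand(1, 5, 4, 4).astype("float32")  # illustrative shape

    static_shape = layer.compute_output_shape(x.shape)
    actual_shape = keras.ops.convert_to_numpy(layer(x)).shape

    # Before this fix the torch backend could disagree with compute_output_shape();
    # after it, both report the same shape.
    print(static_shape, actual_shape)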

File tree

3 files changed: +108, -67 lines changed

keras/src/backend/torch/nn.py

Lines changed: 95 additions & 67 deletions
@@ -2,7 +2,6 @@
 import torch.nn.functional as tnn
 
 from keras.src import backend
-from keras.src import tree
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_torch,
 )
@@ -204,17 +203,27 @@ def sparsemax(logits, axis=-1):
 def _compute_padding_length(
     input_length, kernel_length, stride, dilation_rate=1
 ):
-    """Compute padding length along one dimension."""
-    total_padding_length = (
-        dilation_rate * (kernel_length - 1) - (input_length - 1) % stride
-    )
-    left_padding = total_padding_length // 2
-    right_padding = (total_padding_length + 1) // 2
+    """Compute padding length along one dimension with support
+    for asymmetric padding."""
+    effective_k_size = (kernel_length - 1) * dilation_rate + 1
+    if stride == 1:
+        # total padding is kernel_size - 1
+        total_padding = effective_k_size - 1
+    else:
+        # calc. needed padding for case with stride involved
+        output_size = (input_length + stride - 1) // stride
+        total_padding = max(
+            0, (output_size - 1) * stride + effective_k_size - input_length
+        )
+
+    # divide padding evenly, with extra pixel going at the end if needed
+    left_padding = total_padding // 2
+    right_padding = total_padding - left_padding
     return (left_padding, right_padding)
 
 
 def _apply_same_padding(
-    inputs, kernel_size, strides, operation_type, dilation_rate=1
+    inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1
 ):
     """Apply same padding to the input tensor.
 
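A note on the rewritten helper: for positive integers, (input_length + stride - 1) // stride is ceil(input_length / stride), so the new total reduces to max(0, dilation_rate * (kernel_length - 1) - (input_length - 1) % stride), i.e. the old expression clamped at zero. The old formula could go negative (for example input_length=6, kernel_length=2, stride=3 gives -1), which floor division then turned into a negative left pad; the even/odd split between left and right is unchanged. Below is a standalone sketch of the new arithmetic for experimenting outside the backend (the helper name is illustrative, and the stride == 1 branch from the diff collapses into the same expression, so it is omitted):

    # Standalone sketch mirroring the new _compute_padding_length() arithmetic.
    def same_padding_1d(input_length, kernel_length, stride, dilation_rate=1):
        effective_k_size = (kernel_length - 1) * dilation_rate + 1
        output_size = (input_length + stride - 1) // stride  # ceil(input / stride)
        total = max(0, (output_size - 1) * stride + effective_k_size - input_length)
        return total // 2, total - total // 2  # extra pixel goes on the right

    print(same_padding_1d(6, 2, 3))   # (0, 0); the old formula yielded -1 here
    print(same_padding_1d(10, 2, 3))  # (0, 1); asymmetric "same" padding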
@@ -231,50 +240,49 @@ def _apply_same_padding(
     """
     spatial_shape = inputs.shape[2:]
     num_spatial_dims = len(spatial_shape)
-    padding = ()
+    padding = []
+
+    if operation_type != "pooling":
+        dilation_rate = standardize_tuple(
+            dilation_rate, num_spatial_dims, "dilation_rate"
+        )
 
     for i in range(num_spatial_dims):
-        if operation_type == "pooling":
-            padding_size = _compute_padding_length(
-                spatial_shape[i], kernel_size[i], strides[i]
-            )
-            mode = "replicate"
-        else:
-            dilation_rate = standardize_tuple(
-                dilation_rate, num_spatial_dims, "dilation_rate"
-            )
-            padding_size = _compute_padding_length(
-                spatial_shape[i], kernel_size[i], strides[i], dilation_rate[i]
-            )
-            mode = "constant"
-        padding = (padding_size,) + padding
+        dil = 1 if operation_type == "pooling" else dilation_rate[i]
+        pad = _compute_padding_length(
+            spatial_shape[i], kernel_size[i], strides[i], dil
+        )
+        padding.append(pad)
 
-    if all([left == right for left, right in padding]):
+    # convert padding to torch format
+    if all(left == right for left, right in padding):
         return inputs, [left for left, _ in padding]
 
-    flattened_padding = tuple(
-        value for left_and_right in padding for value in left_and_right
-    )
-    return tnn.pad(inputs, pad=flattened_padding, mode=mode), 0
+    # else, need to pad manually
+    flattened_padding = []
+    for pad in reversed(padding):
+        flattened_padding.extend(pad)
+
+    mode = "replicate" if operation_type == "pooling" else "constant"
+    return tnn.pad(inputs, pad=tuple(flattened_padding), mode=mode), 0
 
 
 def _transpose_spatial_inputs(inputs):
-    num_spatial_dims = inputs.ndim - 2
+    """Transpose inputs from channels_last to channels_first format."""
     # Torch pooling does not support `channels_last` format, so
     # we need to transpose to `channels_first` format.
-    if num_spatial_dims == 1:
-        inputs = torch.permute(inputs, (0, 2, 1))
-    elif num_spatial_dims == 2:
-        inputs = torch.permute(inputs, (0, 3, 1, 2))
-    elif num_spatial_dims == 3:
-        inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
-    else:
-        raise ValueError(
-            "Inputs must have ndim=3, 4 or 5, "
-            "corresponding to 1D, 2D and 3D inputs. "
-            f"Received input shape: {inputs.shape}."
-        )
-    return inputs
+    ndim = inputs.ndim - 2
+    if ndim == 1:  # 1D case
+        return torch.permute(inputs, (0, 2, 1))
+    elif ndim == 2:  # 2D case
+        return torch.permute(inputs, (0, 3, 1, 2))
+    elif ndim == 3:  # 3D case
+        return torch.permute(inputs, (0, 4, 1, 2, 3))
+    raise ValueError(
+        "Inputs must have ndim=3, 4 or 5, "
+        "corresponding to 1D, 2D and 3D inputs. "
+        f"Received input shape: {inputs.shape}."
+    )
 
 
 def _transpose_spatial_outputs(outputs):
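One detail worth noting in the rewritten _apply_same_padding(): the per-dimension pads are flattened in reverse order because torch.nn.functional.pad expects pad values starting from the last dimension. This is standard torch behavior, not something introduced by this commit; a small illustration:

    # torch.nn.functional.pad order: (left_last, right_last, left_prev, right_prev, ...)
    import torch
    import torch.nn.functional as F

    x = torch.zeros(1, 1, 4, 5)          # (N, C, H, W)
    padded = F.pad(x, pad=(0, 1, 1, 1))  # W gets (0, 1), then H gets (1, 1)
    print(padded.shape)                  # torch.Size([1, 1, 6, 6])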
@@ -309,6 +317,7 @@ def max_pool(
     padding="valid",
     data_format=None,
 ):
+    """Fixed max pooling implementation."""
     inputs = convert_to_tensor(inputs)
     num_spatial_dims = inputs.ndim - 2
     pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
@@ -325,7 +334,7 @@ def max_pool(
         # Torch does not natively support `"same"` padding, we need to manually
         # apply the right amount of padding to `inputs`.
         inputs, padding = _apply_same_padding(
-            inputs, pool_size, strides, operation_type="pooling"
+            inputs, pool_size, strides, data_format, "pooling"
         )
     else:
         padding = 0
@@ -370,26 +379,36 @@ def average_pool(
     padding="valid",
     data_format=None,
 ):
+    """Fixed average pooling with correct padding calculation."""
     inputs = convert_to_tensor(inputs)
     num_spatial_dims = inputs.ndim - 2
     pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
-    if strides is None:
-        strides = pool_size
-    else:
-        strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    strides = (
+        pool_size
+        if strides is None
+        else standardize_tuple(strides, num_spatial_dims, "strides")
+    )
 
     data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+
     if data_format == "channels_last":
         inputs = _transpose_spatial_inputs(inputs)
+
     if padding == "same":
         # Torch does not natively support `"same"` padding, we need to manually
         # apply the right amount of padding to `inputs`.
         inputs, padding = _apply_same_padding(
-            inputs, pool_size, strides, operation_type="pooling"
+            inputs,
+            pool_size,
+            strides,
+            "channels_first",  # we're in channels_first here
+            "pooling",
         )
     else:
         padding = 0
 
+    # apply pooling
     if num_spatial_dims == 1:
         outputs = tnn.avg_pool1d(
             inputs,
@@ -420,8 +439,10 @@
             "corresponding to 1D, 2D and 3D inputs. "
             f"Received input shape: {inputs.shape}."
         )
-    if data_format == "channels_last":
+
+    if orig_format == "channels_last":
         outputs = _transpose_spatial_outputs(outputs)
+
     return outputs
 
 
@@ -433,6 +454,7 @@ def conv(
     data_format=None,
     dilation_rate=1,
 ):
+    """Convolution with fixed group handling."""
     inputs = convert_to_tensor(inputs)
     kernel = convert_to_tensor(kernel)
     num_spatial_dims = inputs.ndim - 2
@@ -441,53 +463,59 @@
     data_format = backend.standardize_data_format(data_format)
     if data_format == "channels_last":
         inputs = _transpose_spatial_inputs(inputs)
-    # Transpose kernel from keras format to torch format.
+
     kernel = _transpose_conv_kernel(kernel)
-    if padding == "same" and any(d != 1 for d in tree.flatten(strides)):
-        # Torch does not support this case in conv2d().
-        # Manually pad the tensor.
+
+    # calc. groups snippet
+    in_channels = inputs.shape[1]
+    kernel_in_channels = kernel.shape[1]
+    if in_channels % kernel_in_channels != 0:
+        raise ValueError(
+            f"Input channels ({in_channels}) must be divisible by "
+            f"kernel input channels ({kernel_in_channels})"
+        )
+    groups = in_channels // kernel_in_channels
+
+    # handle padding
+    if padding == "same":
         inputs, padding = _apply_same_padding(
             inputs,
             kernel.shape[2:],
             strides,
-            operation_type="conv",
-            dilation_rate=dilation_rate,
-        )
-    channels = inputs.shape[1]
-    kernel_in_channels = kernel.shape[1]
-    if channels % kernel_in_channels > 0:
-        raise ValueError(
-            "The number of input channels must be evenly divisible by "
-            f"kernel.shape[1]. Received: inputs.shape={inputs.shape}, "
-            f"kernel.shape={kernel.shape}"
+            data_format,
+            "conv",
+            dilation_rate,
         )
-    groups = channels // kernel_in_channels
+    else:
+        padding = 0
+
+    # apply convolution
     if num_spatial_dims == 1:
         outputs = tnn.conv1d(
             inputs,
             kernel,
             stride=strides,
+            padding=padding,
             dilation=dilation_rate,
             groups=groups,
-            padding=padding,
         )
     elif num_spatial_dims == 2:
         outputs = tnn.conv2d(
             inputs,
             kernel,
             stride=strides,
+            padding=padding,
             dilation=dilation_rate,
             groups=groups,
-            padding=padding,
         )
     elif num_spatial_dims == 3:
         outputs = tnn.conv3d(
             inputs,
             kernel,
             stride=strides,
+            padding=padding,
             dilation=dilation_rate,
             groups=groups,
-            padding=padding,
         )
     else:
         raise ValueError(
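The group handling above derives groups directly from the channel counts before applying the convolution. A small sketch of that arithmetic (the helper name is made up for illustration):

    # Torch kernel layout is (out_channels, in_channels // groups, *spatial), so:
    def infer_groups(in_channels, kernel_in_channels):
        if in_channels % kernel_in_channels != 0:
            raise ValueError("in_channels must be divisible by kernel.shape[1]")
        return in_channels // kernel_in_channels

    print(infer_groups(8, 8))  # 1 -> regular convolution
    print(infer_groups(8, 2))  # 4 -> grouped convolution
    print(infer_groups(8, 1))  # 8 -> depthwise-style convolution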

keras/src/layers/pooling/average_pooling_test.py

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ def test_average_pooling1d(
         (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)),
         ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)),
         ((2, 3), (2, 2), "same", "channels_last", (3, 5, 5, 4), (3, 3, 3, 4)),
+        ((2, 3), (3, 3), "same", "channels_first", (3, 5, 5, 4), (3, 5, 2, 2)),
     )
     def test_average_pooling2d(
         self,

keras/src/ops/nn_test.py

Lines changed: 12 additions & 0 deletions
@@ -1597,6 +1597,18 @@ def test_average_pool_same_padding(self):
             knn.average_pool(x, 2, (2, 1), padding="same"),
             np_avgpool2d(x, 2, (2, 1), padding="same", data_format=data_format),
         )
+        # Test 2D average pooling with different pool size.
+        if data_format == "channels_last":
+            input_shape = (2, 10, 9, 3)
+        else:
+            input_shape = (2, 3, 10, 9)
+        x = np.arange(540, dtype=float).reshape(input_shape)
+        self.assertAllClose(
+            knn.average_pool(x, (2, 3), (3, 3), padding="same"),
+            np_avgpool2d(
+                x, (2, 3), (3, 3), padding="same", data_format=data_format
+            ),
+        )
 
     @parameterized.product(
         strides=(1, 2, 3),
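As a sanity check on the new test case: with "same" padding each spatial dimension pools to ceil(input / stride), so the (2, 10, 9, 3) channels_last input with strides (3, 3) should come out as (2, 4, 3, 3). A sketch using the public op (keras.ops.average_pool, the public counterpart of the knn.average_pool call in the test):

    import math
    import numpy as np
    import keras

    x = np.arange(540, dtype="float32").reshape(2, 10, 9, 3)
    y = keras.ops.average_pool(
        x, (2, 3), (3, 3), padding="same", data_format="channels_last"
    )
    # ceil(10 / 3) = 4 and ceil(9 / 3) = 3 along the spatial dimensions.
    assert tuple(y.shape) == (2, math.ceil(10 / 3), math.ceil(9 / 3), 3)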
