diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index de931db47d40..5051e5ed5cf7 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -141,7 +141,7 @@ def _compute_padding_length( def _apply_same_padding( - inputs, kernel_size, strides, operation_type, dilation_rate=1 + inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1 ): """Apply same padding to the input tensor. @@ -174,7 +174,10 @@ def _apply_same_padding( spatial_shape[i], kernel_size[i], strides[i], dilation_rate[i] ) mode = "constant" - padding = (padding_size,) + padding + if data_format == "channels_last": + padding = (padding_size,) + padding + else: + padding = padding + (padding_size,) if all([left == right for left, right in padding]): return inputs, [left for left, _ in padding] @@ -252,7 +255,7 @@ def max_pool( # Torch does not natively support `"same"` padding, we need to manually # apply the right amount of padding to `inputs`. inputs, padding = _apply_same_padding( - inputs, pool_size, strides, operation_type="pooling" + inputs, pool_size, strides, data_format, operation_type="pooling" ) else: padding = 0 @@ -312,7 +315,7 @@ def average_pool( # Torch does not natively support `"same"` padding, we need to manually # apply the right amount of padding to `inputs`. inputs, padding = _apply_same_padding( - inputs, pool_size, strides, operation_type="pooling" + inputs, pool_size, strides, data_format, operation_type="pooling" ) else: padding = 0 @@ -377,6 +380,7 @@ def conv( inputs, kernel.shape[2:], strides, + data_format, operation_type="conv", dilation_rate=dilation_rate, ) diff --git a/keras/src/layers/pooling/average_pooling_test.py b/keras/src/layers/pooling/average_pooling_test.py index 3e56cfdadf29..02bbdd301989 100644 --- a/keras/src/layers/pooling/average_pooling_test.py +++ b/keras/src/layers/pooling/average_pooling_test.py @@ -174,6 +174,7 @@ def test_average_pooling1d( (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)), ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)), ((2, 3), (2, 2), "same", "channels_last", (3, 5, 5, 4), (3, 3, 3, 4)), + ((2, 3), (3, 3), "same", "channels_first", (3, 5, 5, 4), (3, 5, 2, 2)), ) def test_average_pooling2d( self, diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 4d4262e830f0..0eededaf0bda 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -1381,6 +1381,18 @@ def test_average_pool_same_padding(self): knn.average_pool(x, 2, (2, 1), padding="same"), np_avgpool2d(x, 2, (2, 1), padding="same", data_format=data_format), ) + # Test 2D average pooling with different pool size. + if data_format == "channels_last": + input_shape = (2, 10, 9, 3) + else: + input_shape = (2, 3, 10, 9) + x = np.arange(540, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.average_pool(x, (2, 3), (3, 3), padding="same"), + np_avgpool2d( + x, (2, 3), (3, 3), padding="same", data_format=data_format + ), + ) @parameterized.product( strides=(1, 2, 3),