Commit 958d2c8

Use ceilMode from pytorch MaxPool in shlo to get correct output shape (#4167)
PyTorch passes a `ceilMode` argument to the MaxPool op. Currently, the torch-to-stablehlo conversion does not use it when determining the StableHLO output shape. This change applies the same `ceilMode` logic as PyTorch when calculating the output shape; during conversion, we make up the size difference using extra padding.

Changes:
- Use the `ceilMode` param the way torch uses it.
- Added tests for both the floor and ceil cases.
1 parent 66aa960 commit 958d2c8
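
For concreteness, here is a minimal standalone sketch (not part of the commit) of the PyTorch output-size rule described above, evaluated with the shapes the new tests use (56×56 input, 3×3 kernel, stride 2, padding 0, dilation 1). The helper name `poolOutputSize` is hypothetical:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring PyTorch's pooling output-size formula
// (see aten/src/ATen/native/Pool.h): adding adj = stride - 1 before the
// division rounds it up when ceil_mode is set.
int64_t poolOutputSize(int64_t input, int64_t kernel, int64_t stride,
                       int64_t pad, int64_t dilation, bool ceilMode) {
  const int64_t numerator = input + 2 * pad - dilation * (kernel - 1) - 1;
  const int64_t adj = ceilMode ? stride - 1 : 0;
  int64_t output = (numerator + adj) / stride + 1;
  // Edge-case correction: the last pooling window must start inside the
  // input (or its left padding); otherwise drop it.
  if (ceilMode && (output - 1) * stride >= input + pad)
    --output;
  return output;
}

int main() {
  // Shapes from the new tests: 56x56 input, 3x3 kernel, stride 2, pad 0.
  const int64_t floorOut = poolOutputSize(56, 3, 2, 0, 1, /*ceilMode=*/false);
  const int64_t ceilOut = poolOutputSize(56, 3, 2, 0, 1, /*ceilMode=*/true);
  assert(floorOut == 27 && ceilOut == 28);
  // The conversion bridges the difference with extra symmetric padding:
  // (28 - 27) * stride = 2 per spatial dim, split 1 front + 1 back.
  printf("floor: %lld, ceil: %lld\n", (long long)floorOut, (long long)ceilOut);
  return 0;
}
```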

2 files changed: 116 additions & 16 deletions


lib/Conversion/TorchToStablehlo/Pooling.cpp

Lines changed: 44 additions & 16 deletions
```diff
@@ -478,24 +478,52 @@ class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
     Value initVal =
         createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
 
-    if (Dim == 1) {
-      stablehloPadding[stablehloPadding.size() - 2] = padding[0];
-      stablehloPadding[stablehloPadding.size() - 1] = padding[0];
-    } else if (Dim == 2) {
-      stablehloPadding[stablehloPadding.size() - 4] = padding[0];
-      stablehloPadding[stablehloPadding.size() - 3] = padding[0];
-      stablehloPadding[stablehloPadding.size() - 2] = padding[1];
-      stablehloPadding[stablehloPadding.size() - 1] = padding[1];
-    } else if (Dim == 3) {
-      stablehloPadding[stablehloPadding.size() - 6] = padding[0];
-      stablehloPadding[stablehloPadding.size() - 5] = padding[0];
-      stablehloPadding[stablehloPadding.size() - 4] = padding[1];
-      stablehloPadding[stablehloPadding.size() - 3] = padding[1];
-      stablehloPadding[stablehloPadding.size() - 2] = padding[2];
-      stablehloPadding[stablehloPadding.size() - 1] = padding[2];
-    } else {
+    if (Dim < 1 || Dim > 3) {
       assert(false && "Unsupported pooling dimension");
     }
+
+    const size_t spatialIdxStart = inputRank - Dim;
+
+    for (int i = 0; i < Dim; i++) {
+      const size_t frontPadIdx = (spatialIdxStart + i) * 2;
+      const size_t backPadIdx = (spatialIdxStart + i) * 2 + 1;
+
+      // torch padding is symmetric
+      stablehloPadding[frontPadIdx] = padding[i];
+      stablehloPadding[backPadIdx] = padding[i];
+
+      if (ceilMode) {
+        // Match PyTorch output shape with extra padding. See
+        // https://github.com/pytorch/pytorch/blob/c5de6ff079e3e5b453d6ff5190c90f02db458928/aten/src/ATen/native/Pool.h#L79
+        // PyTorch output size formula:
+        // 1. Calculate base output size:
+        //    output = (input + 2*pad - dilation*(kernel-1) - 1+adj) / stride + 1
+        //    where adj = (stride-1) if ceil_mode else 0
+        // 2. Apply edge case correction:
+        //    if ((output-1) * stride >= input + pad_l) --output;
+
+        const int64_t inputSize = inputTy.getDimSize(spatialIdxStart + i);
+        const int64_t numerator = (inputSize + 2 * padding[i] -
+                                   dilation[i] * (kernelSize[i] - 1) - 1);
+        const int64_t floor_output_size = (numerator) / stride[i] + 1;
+        const int64_t adj = (stride[i] - 1);
+        int64_t ceil_output_size = std::ceil((numerator + adj) / stride[i]) + 1;
+
+        // Ensure last pooling starts inside input
+        if ((ceil_output_size - 1) * stride[i] >= inputSize + padding[i]) {
+          ceil_output_size--;
+        }
+
+        // Add extra padding to make output size same as torch
+        if (ceil_output_size > floor_output_size) {
+          const int64_t sizeDiff = ceil_output_size - floor_output_size;
+          const int64_t extraPadding = sizeDiff * stride[i];
+          stablehloPadding[frontPadIdx] += extraPadding / 2;
+          stablehloPadding[backPadIdx] += extraPadding - extraPadding / 2;
+        }
+      }
+    }
+
     auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
     auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
     DenseI64ArrayAttr baseDilations;
```
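
The loop above also replaces the old per-rank if/else chain with index arithmetic over `stablehloPadding`, which holds one (front, back) pair per input dimension with the spatial dimensions last. A small standalone sketch (assuming a rank-4 NCHW input and `Dim = 2`; variable names taken from the diff) confirms the computed indices land on the same slots the old code addressed relative to `size()`:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Rank-4 NCHW input, 2-D pooling: stablehloPadding holds rank * 2 entries,
  // one (front, back) pair per dimension, spatial dimensions last.
  const size_t inputRank = 4;
  const int Dim = 2;
  const size_t spatialIdxStart = inputRank - Dim; // H is dim 2, W is dim 3

  for (int i = 0; i < Dim; i++) {
    const size_t frontPadIdx = (spatialIdxStart + i) * 2;
    const size_t backPadIdx = (spatialIdxStart + i) * 2 + 1;
    // i = 0 (H): 4 and 5; i = 1 (W): 6 and 7 -- the same slots the old code
    // addressed as size() - 4, size() - 3, size() - 2, size() - 1.
    printf("spatial dim %zu -> front %zu, back %zu\n", spatialIdxStart + i,
           frontPadIdx, backPadIdx);
  }
  return 0;
}
```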

test/Conversion/TorchToStablehlo/pooling.mlir

Lines changed: 72 additions & 0 deletions
```diff
@@ -65,6 +65,78 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
   return %3 : !torch.vtensor<[?,?,?,?],f32>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.max_pool2d$ceiloff(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,27,27],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,256,56,56],f32> -> tensor<1x256x56x56xf32>
+// CHECK: %int3 = torch.constant.int 3
+// CHECK: %int2 = torch.constant.int 2
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %false = torch.constant.bool false
+// CHECK: %int0 = torch.constant.int 0
+// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
+// CHECK{LITERAL}: <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
+// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
+// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
+// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
+// CHECK: }) : (tensor<1x256x56x56xf32>, tensor<f32>) -> tensor<1x256x27x27xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x256x27x27xf32> -> !torch.vtensor<[1,256,27,27],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,256,27,27],f32>
+func.func @torch.aten.max_pool2d$ceiloff(%arg0: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,27,27],f32> {
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,256,56,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,256,27,27],f32>
+  return %4 : !torch.vtensor<[1,256,27,27],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.max_pool2d$ceilon(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,28,28],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,256,56,56],f32> -> tensor<1x256x56x56xf32>
+// CHECK: %int3 = torch.constant.int 3
+// CHECK: %int2 = torch.constant.int 2
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %true = torch.constant.bool true
+// CHECK: %int0 = torch.constant.int 0
+// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
+// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
+// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
+// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
+// CHECK: }) : (tensor<1x256x56x56xf32>, tensor<f32>) -> tensor<1x256x28x28xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x256x28x28xf32> -> !torch.vtensor<[1,256,28,28],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,256,28,28],f32>
+func.func @torch.aten.max_pool2d$ceilon(%arg0: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,28,28],f32> {
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %true : !torch.vtensor<[1,256,56,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,256,28,28],f32>
+  return %4 : !torch.vtensor<[1,256,28,28],f32>
+}
+
 
 // -----
 
```
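
The two tests above exercise exactly this arithmetic: with ceil mode off, the 56×56 input pools to 27×27 under all-zero padding (`dense<0>`); with ceil mode on, each spatial dimension receives (28 − 27) × stride = 2 extra padding, split 1 front and 1 back (`dense<[[0, 0], [0, 0], [1, 1], [1, 1]]>`), so the `stablehlo.reduce_window` produces the 28×28 output shape PyTorch reports.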
