Skip to content

Commit 553d21d

Browse files
authored
[ONNX] Fix onnx MaxPool1D with indices lowering to torch IR (#4213)
This PR takes care of #4212. Signed-off-by: Zahid Wakeel <[email protected]>
1 parent 8cc313f commit 553d21d

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,9 +1276,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
12761276
if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1))
12771277
return failure();
12781278

1279-
if (rank == 3)
1280-
return rewriter.notifyMatchFailure(
1281-
binder.op, "Unimplemented: AtenMaxPool1dWithIndicesOp");
1279+
if (rank == 3) {
1280+
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dWithIndicesOp>(
1281+
binder.op, resultTypeOut, resultTypeIndices, operand,
1282+
kernelSizeList, stridesList, paddingList, dilationsList,
1283+
cstCeilMode);
1284+
return success();
1285+
}
12821286

12831287
if (rank == 4) {
12841288
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>(

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,48 @@ func.func @test_multinomial_dtype_double_samplenum_4(%arg0: !torch.vtensor<[3,5]
659659

660660
// -----
661661

662+
// CHECK-LABEL: func.func @test_maxpool_1d_indices_default
663+
func.func @test_maxpool_1d_indices_default(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
664+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
665+
// CHECK: %[[VAL_0:.*]] = torch.constant.int 2
666+
// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
667+
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
668+
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
669+
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
670+
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
671+
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
672+
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
673+
// CHECK: %[[VAL_8:.*]] = torch.constant.bool false
674+
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31],f32>, !torch.vtensor<[93],ui64>
675+
// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,31],f32>
676+
// CHECK: }
677+
%0:2 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,31],f32>, !torch.vtensor<[93], ui64>)
678+
return %0#0 : !torch.vtensor<[1,3,31],f32>
679+
}
680+
681+
// -----
682+
683+
// CHECK-LABEL: func.func @test_maxpool_1d_indices_ceil_pad_stride(
684+
func.func @test_maxpool_1d_indices_ceil_pad_stride(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,16],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
685+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,16],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
686+
// CHECK: %[[VAL_0:.*]] = torch.constant.int 5
687+
// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
688+
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
689+
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
690+
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
691+
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
692+
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
693+
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
694+
// CHECK: %[[VAL_8:.*]] = torch.constant.bool true
695+
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],ui64>
696+
// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,16],f32>
697+
// CHECK: }
698+
%0:2 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [5 : si64], torch.onnx.pads = [2 : si64, 2: si64], torch.onnx.strides = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48], ui64>)
699+
return %0#0 : !torch.vtensor<[1,3,16],f32>
700+
}
701+
702+
// -----
703+
662704
// CHECK-LABEL: func.func @test_maxpool_2d_default
663705
func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
664706
// CHECK: %[[I2:.*]] = torch.constant.int 2

0 commit comments

Comments
 (0)