@@ -659,6 +659,48 @@ func.func @test_multinomial_dtype_double_samplenum_4(%arg0: !torch.vtensor<[3,5]
659
659
660
660
// -----
661
661
662
+ // CHECK-LABEL: func.func @test_maxpool_1d_indices_default
663
+ func.func @test_maxpool_1d_indices_default (%arg0: !torch.vtensor <[1 ,3 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,31 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
664
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
665
+ // CHECK: %[[VAL_0:.*]] = torch.constant.int 2
666
+ // CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
667
+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
668
+ // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
669
+ // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
670
+ // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
671
+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
672
+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
673
+ // CHECK: %[[VAL_8:.*]] = torch.constant.bool false
674
+ // CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31],f32>, !torch.vtensor<[93],ui64>
675
+ // CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,31],f32>
676
+ // CHECK: }
677
+ %0:2 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.kernel_shape = [2 : si64 ]} : (!torch.vtensor <[1 ,3 ,32 ],f32 >) -> (!torch.vtensor <[1 ,3 ,31 ],f32 >, !torch.vtensor <[93 ], ui64 >)
678
+ return %0#0 : !torch.vtensor <[1 ,3 ,31 ],f32 >
679
+ }
680
+
681
+ // -----
682
+
683
+ // CHECK-LABEL: func.func @test_maxpool_1d_indices_ceil_pad_stride(
684
+ func.func @test_maxpool_1d_indices_ceil_pad_stride (%arg0: !torch.vtensor <[1 ,3 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,16 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
685
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,16],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
686
+ // CHECK: %[[VAL_0:.*]] = torch.constant.int 5
687
+ // CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
688
+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
689
+ // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
690
+ // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
691
+ // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
692
+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
693
+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
694
+ // CHECK: %[[VAL_8:.*]] = torch.constant.bool true
695
+ // CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],ui64>
696
+ // CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,16],f32>
697
+ // CHECK: }
698
+ %0:2 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.ceil_mode = 1 : si64 , torch.onnx.kernel_shape = [5 : si64 ], torch.onnx.pads = [2 : si64 , 2 : si64 ], torch.onnx.strides = [2 : si64 ]} : (!torch.vtensor <[1 ,3 ,32 ],f32 >) -> (!torch.vtensor <[1 ,3 ,16 ],f32 >, !torch.vtensor <[48 ], ui64 >)
699
+ return %0#0 : !torch.vtensor <[1 ,3 ,16 ],f32 >
700
+ }
701
+
702
+ // -----
703
+
662
704
// CHECK-LABEL: func.func @test_maxpool_2d_default
663
705
func.func @test_maxpool_2d_default (%arg0: !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,31 ,31 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
664
706
// CHECK: %[[I2:.*]] = torch.constant.int 2
0 commit comments