    return %3 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// max_pool2d with ceil_mode=false: output is floor((56-3)/2)+1 = 27, and the
// lowering emits a stablehlo.reduce_window with zero padding.
// CHECK-LABEL: func.func @torch.aten.max_pool2d$ceiloff(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,27,27],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,256,56,56],f32> -> tensor<1x256x56x56xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %false = torch.constant.bool false
// CHECK: %int0 = torch.constant.int 0
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
// CHECK{LITERAL}: <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) : (tensor<1x256x56x56xf32>, tensor<f32>) -> tensor<1x256x27x27xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x256x27x27xf32> -> !torch.vtensor<[1,256,27,27],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,256,27,27],f32>
func.func @torch.aten.max_pool2d$ceiloff(%arg0: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,27,27],f32> {
  %int3 = torch.constant.int 3
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %false = torch.constant.bool false
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,256,56,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,256,27,27],f32>
  return %4 : !torch.vtensor<[1,256,27,27],f32>
}

104
// -----

// max_pool2d with ceil_mode=true: output is ceil((56-3)/2)+1 = 28, realized by
// padding the two spatial dims with [1, 1] in the stablehlo.reduce_window.
// CHECK-LABEL: func.func @torch.aten.max_pool2d$ceilon(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,28,28],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,256,56,56],f32> -> tensor<1x256x56x56xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %true = torch.constant.bool true
// CHECK: %int0 = torch.constant.int 0
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) : (tensor<1x256x56x56xf32>, tensor<f32>) -> tensor<1x256x28x28xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x256x28x28xf32> -> !torch.vtensor<[1,256,28,28],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,256,28,28],f32>
func.func @torch.aten.max_pool2d$ceilon(%arg0: !torch.vtensor<[1,256,56,56],f32>) -> !torch.vtensor<[1,256,28,28],f32> {
  %int3 = torch.constant.int 3
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %true = torch.constant.bool true
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %true : !torch.vtensor<[1,256,56,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,256,28,28],f32>
  return %4 : !torch.vtensor<[1,256,28,28],f32>
}

// -----