Skip to content

Commit 3e34fe8

Browse files
committed
[mlir][arith] Fold arith.cmpi eq, %val, %one : i1 -> %val
1 parent 52bffdf commit 3e34fe8

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
18651865
getPredicate() == arith::CmpIPredicate::ne)
18661866
return extOp.getOperand();
18671867
}
1868+
1869+
// arith.cmpi ne, %val, %zero : i1 -> %val
1870+
if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1871+
getPredicate() == arith::CmpIPredicate::ne)
1872+
return getLhs();
1873+
}
1874+
1875+
if (matchPattern(adaptor.getRhs(), m_One())) {
1876+
// arith.cmpi eq, %val, %one : i1 -> %val
1877+
if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1878+
getPredicate() == arith::CmpIPredicate::eq)
1879+
return getLhs();
18681880
}
18691881

18701882
// Move constant to the right side.

mlir/test/Dialect/Arith/canonicalize.mlir

+72
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,78 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
160160
return %res1, %res2 : i32, i32
161161
}
162162

163+
// CHECK-LABEL: @cmpiI1eq
164+
// CHECK-SAME: (%[[ARG:.*]]: i1)
165+
// CHECK: return %[[ARG]]
166+
func.func @cmpiI1eq(%arg0: i1) -> i1 {
167+
%one = arith.constant 1 : i1
168+
%res = arith.cmpi eq, %arg0, %one : i1
169+
return %res : i1
170+
}
171+
172+
// CHECK-LABEL: @cmpiI1eqVec
173+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
174+
// CHECK: return %[[ARG]]
175+
func.func @cmpiI1eqVec(%arg0: vector<4xi1>) -> vector<4xi1> {
176+
%one = arith.constant dense<1> : vector<4xi1>
177+
%res = arith.cmpi eq, %arg0, %one : vector<4xi1>
178+
return %res : vector<4xi1>
179+
}
180+
181+
// CHECK-LABEL: @cmpiI1ne
182+
// CHECK-SAME: (%[[ARG:.*]]: i1)
183+
// CHECK: return %[[ARG]]
184+
func.func @cmpiI1ne(%arg0: i1) -> i1 {
185+
%zero = arith.constant 0 : i1
186+
%res = arith.cmpi ne, %arg0, %zero : i1
187+
return %res : i1
188+
}
189+
190+
// CHECK-LABEL: @cmpiI1neVec
191+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
192+
// CHECK: return %[[ARG]]
193+
func.func @cmpiI1neVec(%arg0: vector<4xi1>) -> vector<4xi1> {
194+
%zero = arith.constant dense<0> : vector<4xi1>
195+
%res = arith.cmpi ne, %arg0, %zero : vector<4xi1>
196+
return %res : vector<4xi1>
197+
}
198+
199+
// CHECK-LABEL: @cmpiI1eqLhs
200+
// CHECK-SAME: (%[[ARG:.*]]: i1)
201+
// CHECK: return %[[ARG]]
202+
func.func @cmpiI1eqLhs(%arg0: i1) -> i1 {
203+
%one = arith.constant 1 : i1
204+
%res = arith.cmpi eq, %one, %arg0 : i1
205+
return %res : i1
206+
}
207+
208+
// CHECK-LABEL: @cmpiI1eqVecLhs
209+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
210+
// CHECK: return %[[ARG]]
211+
func.func @cmpiI1eqVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
212+
%one = arith.constant dense<1> : vector<4xi1>
213+
%res = arith.cmpi eq, %one, %arg0 : vector<4xi1>
214+
return %res : vector<4xi1>
215+
}
216+
217+
// CHECK-LABEL: @cmpiI1neLhs
218+
// CHECK-SAME: (%[[ARG:.*]]: i1)
219+
// CHECK: return %[[ARG]]
220+
func.func @cmpiI1neLhs(%arg0: i1) -> i1 {
221+
%zero = arith.constant 0 : i1
222+
%res = arith.cmpi ne, %zero, %arg0 : i1
223+
return %res : i1
224+
}
225+
226+
// CHECK-LABEL: @cmpiI1neVecLhs
227+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
228+
// CHECK: return %[[ARG]]
229+
func.func @cmpiI1neVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
230+
%zero = arith.constant dense<0> : vector<4xi1>
231+
%res = arith.cmpi ne, %zero, %arg0 : vector<4xi1>
232+
return %res : vector<4xi1>
233+
}
234+
163235
// Test case: Folding of comparisons with equal operands.
164236
// CHECK-LABEL: @cmpi_equal_operands
165237
// CHECK-DAG: %[[T:.*]] = arith.constant true

0 commit comments

Comments
 (0)