-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][arith] Fold arith.cmpi eq, %val, %one : i1
-> %val
and arith.cmpi ne, %val, %zero : i1 -> %val
#124436
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesFull diff: https://github.com/llvm/llvm-project/pull/124436.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7ca104691e6df6..75d59ba8c1a108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1865,6 +1865,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
+
+ // arith.cmpi ne, %val, %zero : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::ne)
+ return getLhs();
+ }
+
+ if (matchPattern(adaptor.getRhs(), m_One())) {
+ // arith.cmpi eq, %val, %one : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::eq)
+ return getLhs();
}
// Move constant to the right side.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 522711b08f289d..3a16ee3d4f8fde 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -160,6 +160,78 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
return %res1, %res2 : i32, i32
}
+// CHECK-LABEL: @cmpiI1eq
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eq(%arg0: i1) -> i1 {
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %arg0, %one : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %arg0, %one : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1ne
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1ne(%arg0: i1) -> i1 {
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %arg0, %zero : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %arg0, %zero : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1eqLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqLhs(%arg0: i1) -> i1 {
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %one, %arg0 : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %one, %arg0 : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1neLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neLhs(%arg0: i1) -> i1 {
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %zero, %arg0 : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %zero, %arg0 : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
|
@llvm/pr-subscribers-mlir-arith Author: Ivan Butygin (Hardcode84) ChangesFull diff: https://github.com/llvm/llvm-project/pull/124436.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7ca104691e6df6..75d59ba8c1a108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1865,6 +1865,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
+
+ // arith.cmpi ne, %val, %zero : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::ne)
+ return getLhs();
+ }
+
+ if (matchPattern(adaptor.getRhs(), m_One())) {
+ // arith.cmpi eq, %val, %one : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::eq)
+ return getLhs();
}
// Move constant to the right side.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 522711b08f289d..3a16ee3d4f8fde 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -160,6 +160,78 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
return %res1, %res2 : i32, i32
}
+// CHECK-LABEL: @cmpiI1eq
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eq(%arg0: i1) -> i1 {
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %arg0, %one : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %arg0, %one : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1ne
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1ne(%arg0: i1) -> i1 {
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %arg0, %zero : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %arg0, %zero : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1eqLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqLhs(%arg0: i1) -> i1 {
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %one, %arg0 : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %one, %arg0 : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1neLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neLhs(%arg0: i1) -> i1 {
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %zero, %arg0 : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %zero, %arg0 : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
|
1d95c2b
to
7c29918
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but it would be great if you could also attach an alive link
done |
https://alive2.llvm.org/ce/z/dNZMdC