Skip to content

Commit bd3af49

Browse files
authored
[Integrate] Cherry-pick llvm/llvm-project@41f6566 (#22470)
Still carrying two reverts as #22366. New cherry-picks: - [MLIR] Revamp RegionBranchOpInterface (llvm/llvm-project@41f6566) - Corresponding torch-mlir fix (llvm/torch-mlir#4358) Also, IREE fixes on `RegionBranchPoint` and `RegionSuccessor`. ci-extra: test_torch, windows_x64_msvc Signed-off-by: Yu-Zhewen <[email protected]>
1 parent b35735f commit bd3af49

File tree

5 files changed

+25
-20
lines changed

5 files changed

+25
-20
lines changed

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3085,8 +3085,8 @@ std::pair<unsigned, unsigned> AsyncExecuteOp::getTiedResultsIndexAndLength() {
30853085
}
30863086

30873087
OperandRange
3088-
AsyncExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3089-
assert(point.getRegionOrNull() == &getBody() && "invalid region index");
3088+
AsyncExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3089+
assert(successor.getSuccessor() == &getBody() && "invalid region index");
30903090
return getResourceOperands();
30913091
}
30923092

@@ -3096,7 +3096,7 @@ void AsyncExecuteOp::getSuccessorRegions(
30963096
// return the correct RegionSuccessor purely based on the index being None or
30973097
// 0.
30983098
if (!point.isParent()) {
3099-
regions.push_back(RegionSuccessor(getResults()));
3099+
regions.push_back(RegionSuccessor(getOperation(), getResults()));
31003100
} else {
31013101
regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
31023102
}
@@ -3240,8 +3240,8 @@ LogicalResult AsyncConcurrentOp::verify() {
32403240
}
32413241

32423242
OperandRange
3243-
AsyncConcurrentOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3244-
assert(point == &getBody() && "invalid region index");
3243+
AsyncConcurrentOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3244+
assert(successor.getSuccessor() == &getBody() && "invalid region index");
32453245
return getResourceOperands();
32463246
}
32473247

@@ -3251,7 +3251,7 @@ void AsyncConcurrentOp::getSuccessorRegions(
32513251
// return the correct RegionSuccessor purely based on the index being None or
32523252
// 0.
32533253
if (!point.isParent()) {
3254-
regions.push_back(RegionSuccessor(getResults()));
3254+
regions.push_back(RegionSuccessor(getOperation(), getResults()));
32553255
} else {
32563256
regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
32573257
}
@@ -4049,8 +4049,9 @@ LogicalResult CmdExecuteOp::verify() {
40494049
return success();
40504050
}
40514051

4052-
OperandRange CmdExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
4053-
assert(point == &getBody() && "invalid region index");
4052+
OperandRange
4053+
CmdExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) {
4054+
assert(successor.getSuccessor() == &getBody() && "invalid region index");
40544055
return getResourceOperands();
40554056
}
40564057

@@ -4060,7 +4061,8 @@ void CmdExecuteOp::getSuccessorRegions(
40604061
// return the correct RegionSuccessor purely based on the index being None or
40614062
// 0.
40624063
if (!point.isParent()) {
4063-
regions.push_back(RegionSuccessor({}));
4064+
regions.push_back(
4065+
RegionSuccessor(getOperation(), Operation::result_range(nullptr, 0)));
40644066
} else {
40654067
regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
40664068
}
@@ -4130,7 +4132,8 @@ void CmdSerialOp::getSuccessorRegions(
41304132
// return the correct RegionSuccessor purely based on the index being None or
41314133
// 0.
41324134
if (!point.isParent()) {
4133-
regions.push_back(RegionSuccessor({}));
4135+
regions.push_back(
4136+
RegionSuccessor(getOperation(), Operation::result_range(nullptr, 0)));
41344137
} else {
41354138
regions.push_back(RegionSuccessor(&getBody(), {}));
41364139
}
@@ -4155,7 +4158,8 @@ void CmdConcurrentOp::getSuccessorRegions(
41554158
// return the correct RegionSuccessor purely based on the index being None or
41564159
// 0.
41574160
if (!point.isParent()) {
4158-
regions.push_back(RegionSuccessor({}));
4161+
regions.push_back(
4162+
RegionSuccessor(getOperation(), Operation::result_range(nullptr, 0)));
41594163
} else {
41604164
regions.push_back(RegionSuccessor(&getBody(), {}));
41614165
}
@@ -4426,7 +4430,7 @@ LogicalResult DispatchWorkgroupSizeOp::verify() {
44264430
//===----------------------------------------------------------------------===//
44274431

44284432
MutableOperandRange
4429-
YieldOp::getMutableSuccessorOperands(RegionBranchPoint point) {
4433+
YieldOp::getMutableSuccessorOperands(RegionSuccessor successor) {
44304434
return getResourceOperandsMutable();
44314435
}
44324436

compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,7 @@ static IREE::Stream::AffinityAttr findLocalValueAffinity(Value value) {
11611161
auto terminatorOp =
11621162
cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
11631163
value = terminatorOp.getSuccessorOperands(
1164-
RegionBranchPoint::parent())[resultIndex];
1164+
RegionSuccessor(definingOp, definingOp->getResults()))[resultIndex];
11651165
} else if (auto tiedOp =
11661166
dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
11671167
// If the producer is tied then try to get the operand.

compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,8 @@ TraversalResult Explorer::walkReturnOperands(Operation *parentOp,
599599
return walkReturnOps(parentOp, [&](Operation *returnOp) {
600600
if (auto terminatorOp =
601601
dyn_cast<RegionBranchTerminatorOpInterface>(returnOp)) {
602-
return fn(terminatorOp.getSuccessorOperands(RegionBranchPoint::parent()));
602+
return fn(terminatorOp.getSuccessorOperands(
603+
RegionSuccessor(parentOp, parentOp->getResults())));
603604
} else {
604605
return fn(returnOp->getOperands());
605606
}
@@ -992,8 +993,9 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn,
992993
// Move within/out-of a region.
993994
auto traverseRegionBranchOp = [&](RegionBranchTerminatorOpInterface branchOp,
994995
unsigned operandIdx) {
995-
auto successorOperands =
996-
branchOp.getSuccessorOperands(RegionBranchPoint::parent());
996+
Operation *parentOp = branchOp.getOperation()->getParentOp();
997+
auto successorOperands = branchOp.getSuccessorOperands(
998+
RegionSuccessor(parentOp, parentOp->getResults()));
997999
unsigned beginIdx = successorOperands.getBeginOperandIndex();
9981000
if (operandIdx < beginIdx ||
9991001
operandIdx >= beginIdx + successorOperands.size()) {
@@ -1003,8 +1005,7 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn,
10031005
<< operandIdx << "\n");
10041006
return TraversalResult::COMPLETE;
10051007
}
1006-
auto result = branchOp.getSuccessorOperands(
1007-
RegionBranchPoint::parent())[operandIdx - beginIdx];
1008+
auto result = successorOperands[operandIdx - beginIdx];
10081009
LLVM_DEBUG({
10091010
llvm::dbgs() << " + queuing region result ";
10101011
result.printAsOperand(llvm::dbgs(), asmState);

third_party/llvm-project

Submodule llvm-project updated 38 files

third_party/torch-mlir

0 commit comments

Comments
 (0)