Skip to content

Commit 766f50c

Browse files
authored
Enable AIRFuseChannel to fuse channels temporally, if there's no existing for loop (Xilinx#597)
* Fix up bug where channel fusion by for loop only checks puts but not gets
* Fix up CI tests which were wrong in the first place
* Enable temporal channel fusion without a for loop in place
* Clang format
* Test
1 parent e8f5f8c commit 766f50c

File tree

4 files changed

+304
-75
lines changed

4 files changed

+304
-75
lines changed

mlir/lib/Transform/AIRDependencyScheduleOpt.cpp

Lines changed: 156 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -3009,23 +3009,49 @@ class AIRFuseChannels
30093009
init_options();
30103010
if (channelOps.empty())
30113011
return;
3012+
// Rename symbols
3013+
// TODO: make this greedy
3014+
auto renameSymbols =
3015+
[](std::vector<air::ChannelOp> &channelOps,
3016+
std::map<air::ChannelOp, air::ChannelOp> chan_merge_map) {
3017+
for (unsigned i = 0; i < channelOps.size(); i++) {
3018+
for (auto chanKey : channelOps) {
3019+
if (!chan_merge_map.count(chanKey))
3020+
continue;
3021+
auto error = mlir::SymbolTable::replaceAllSymbolUses(
3022+
chanKey.getOperation(),
3023+
mlir::SymbolTable::getSymbolName(chan_merge_map[chanKey]),
3024+
chanKey->getParentOfType<ModuleOp>());
3025+
// FIXME: what if this fails?
3026+
(void)error;
3027+
}
3028+
}
3029+
};
30123030
std::map<air::ChannelOp, air::ChannelOp> chan_merge_map;
30133031
for (unsigned i = 0; i < channelOps.size() - 1; i++) {
30143032
for (unsigned j = i + 1; j < channelOps.size(); j++) {
30153033
std::tuple<bool, std::string> checkScfForMergeableRes =
3016-
checkIfScfForMergeable(channelOps[i], channelOps[j]);
3034+
checkIfTemporalMergeable(channelOps[i], channelOps[j]);
30173035
if (!std::get<0>(checkScfForMergeableRes))
30183036
continue;
3019-
// Fuse air.channels by scf.for loop unpeeling, i.e. recovering any
3020-
// missing zeroth iterations in scf.for loops.
30213037
air::ChannelOp chanA = channelOps[i];
30223038
air::ChannelOp chanB = channelOps[j];
3023-
sortChannelsByLoopNests(chanA, chanB);
3024-
mergeChannelOpsByScfFor(chanA, chanB,
3025-
std::get<1>(checkScfForMergeableRes));
3026-
chan_merge_map[chanB] = chanA;
3039+
if (std::get<1>(checkScfForMergeableRes) == "LB" ||
3040+
std::get<1>(checkScfForMergeableRes) == "UB") {
3041+
// Fuse air.channels by scf.for loop unpeeling, i.e. recovering any
3042+
// missing zeroth/last iterations in scf.for loops.
3043+
sortChannelsByLoopNests(chanA, chanB);
3044+
mergeChannelOpsTemporally(chanA, chanB,
3045+
std::get<1>(checkScfForMergeableRes));
3046+
chan_merge_map[chanB] = chanA;
3047+
} else if (std::get<1>(checkScfForMergeableRes) == "NFL") {
3048+
// Fuse air.channels temporally, if there isn't any for loop to fuse
3049+
// into.
3050+
chan_merge_map[chanB] = chanA;
3051+
}
30273052
}
30283053
}
3054+
renameSymbols(channelOps, chan_merge_map);
30293055
if (!targetMemorySpaces.empty()) {
30303056
for (unsigned i = 0; i < channelOps.size() - 1; i++) {
30313057
for (unsigned j = i + 1; j < channelOps.size(); j++) {
@@ -3037,20 +3063,7 @@ class AIRFuseChannels
30373063
}
30383064
}
30393065
}
3040-
// Rename symbols
3041-
// TODO: make this greedy
3042-
for (unsigned i = 0; i < channelOps.size(); i++) {
3043-
for (auto chanKey : channelOps) {
3044-
if (!chan_merge_map.count(chanKey))
3045-
continue;
3046-
auto error = mlir::SymbolTable::replaceAllSymbolUses(
3047-
chanKey.getOperation(),
3048-
mlir::SymbolTable::getSymbolName(chan_merge_map[chanKey]),
3049-
chanKey->getParentOfType<ModuleOp>());
3050-
// FIXME: what if this fails?
3051-
(void)error;
3052-
}
3053-
}
3066+
renameSymbols(channelOps, chan_merge_map);
30543067
}
30553068

30563069
void runOnOperation() override {
@@ -3183,8 +3196,8 @@ class AIRFuseChannels
31833196
}
31843197

31853198
std::tuple<bool, std::string>
3186-
checkIfScfForMergeableImpl(std::vector<Operation *> a_loop_nest,
3187-
std::vector<Operation *> b_loop_nest) {
3199+
checkIfTemporalMergeableByScfForImpl(std::vector<Block *> a_loop_nest,
3200+
std::vector<Block *> b_loop_nest) {
31883201
std::tuple<bool, std::string> notMergeable = {false, ""};
31893202
std::tuple<bool, std::string> mergeableToLB = {true, "LB"};
31903203
std::tuple<bool, std::string> mergeableToUB = {true, "UB"};
@@ -3194,20 +3207,22 @@ class AIRFuseChannels
31943207
std::max(a_loop_nest.size(), b_loop_nest.size());
31953208
// Skip over the unequal scf.for loop, and check equality for the rest of
31963209
// the loops first.
3197-
unsigned outermostScfFor = -1;
3210+
int outermostScfFor = -1;
31983211
for (unsigned i = 0; i < max_loop_nest_count; i++) {
31993212
unsigned scfForCount = 0;
3200-
if ((i < a_loop_nest.size()) && isa<scf::ForOp>(a_loop_nest[i]))
3213+
if ((i < a_loop_nest.size()) &&
3214+
isa<scf::ForOp>(a_loop_nest[i]->getParentOp()))
32013215
scfForCount++;
3202-
if ((i < b_loop_nest.size()) && isa<scf::ForOp>(b_loop_nest[i]))
3216+
if ((i < b_loop_nest.size()) &&
3217+
isa<scf::ForOp>(b_loop_nest[i]->getParentOp()))
32033218
scfForCount++;
32043219
if (scfForCount == 1)
3205-
outermostScfFor = i;
3220+
outermostScfFor = (int)i;
32063221
}
32073222
if (outermostScfFor < 0)
32083223
return notMergeable;
32093224
SmallVector<unsigned> controlLoopIndices;
3210-
for (unsigned i = 0; i < max_loop_nest_count; i++)
3225+
for (int i = 0; i < (int)max_loop_nest_count; i++)
32113226
if (i != outermostScfFor)
32123227
controlLoopIndices.push_back(i);
32133228
// TODO: Assuming b_loop_nest is before a_loop_nest. Always true? TODO:
@@ -3220,7 +3235,8 @@ class AIRFuseChannels
32203235
}
32213236
// Check if the skipped scf.for loop has LB >= 1. This is a sign of
32223237
// peeling, indicating opportunity of merge by unpeeling into LB.
3223-
auto outerMostScfFor = dyn_cast<scf::ForOp>(a_loop_nest[outermostScfFor]);
3238+
auto outerMostScfFor =
3239+
dyn_cast<scf::ForOp>(a_loop_nest[outermostScfFor]->getParentOp());
32243240
assert(outerMostScfFor);
32253241
if (auto constLB = getConstantIntValue(outerMostScfFor.getLowerBound()))
32263242
if (*constLB < 1)
@@ -3232,12 +3248,31 @@ class AIRFuseChannels
32323248
return notMergeable;
32333249
}
32343250
// Merge by unpeeling into UB.
3235-
auto outerMostScfFor = dyn_cast<scf::ForOp>(b_loop_nest[outermostScfFor]);
3251+
auto outerMostScfFor =
3252+
dyn_cast<scf::ForOp>(b_loop_nest[outermostScfFor]->getParentOp());
32363253
assert(outerMostScfFor);
32373254
return mergeableToUB;
32383255
}
32393256
return mergeableToLB;
32403257
}
3258+
std::tuple<bool, std::string>
3259+
checkIfTemporalMergeableImpl(std::vector<Block *> a_loop_nest,
3260+
std::vector<Block *> b_loop_nest) {
3261+
std::tuple<bool, std::string> notMergeable = {false, ""};
3262+
std::tuple<bool, std::string> mergeableToLB = {true, "LB"};
3263+
std::tuple<bool, std::string> mergeableToUB = {true, "UB"};
3264+
std::tuple<bool, std::string> mergeableNoForLoop = {true, "NFL"};
3265+
if (std::abs((int)a_loop_nest.size() - (int)b_loop_nest.size()) == 1)
3266+
return checkIfTemporalMergeableByScfForImpl(a_loop_nest, b_loop_nest);
3267+
else if (a_loop_nest.size() != b_loop_nest.size())
3268+
return notMergeable;
3269+
3270+
for (unsigned i = 0; i < a_loop_nest.size(); i++) {
3271+
if (!areEquivalentControlLoops(a_loop_nest[i], b_loop_nest[i]))
3272+
return notMergeable;
3273+
}
3274+
return mergeableNoForLoop;
3275+
}
32413276
Value getHierOperandFromHierBlockArgument(BlockArgument arg) {
32423277
if (!arg)
32433278
return nullptr;
@@ -3284,9 +3319,10 @@ class AIRFuseChannels
32843319
}
32853320
// Check if two air.channels are mergeable in time, by fusing into a shared
32863321
// scf.for loop. Returns a tuple of bool of whether mergeable, and string of
3287-
// fusing into for loop lower bound (LB) or upper bound (UB).
3288-
std::tuple<bool, std::string> checkIfScfForMergeable(air::ChannelOp chan_a,
3289-
air::ChannelOp chan_b) {
3322+
// fusing into for loop lower bound (LB) or upper bound (UB), or fuse with no
3323+
// for loop (NFL).
3324+
std::tuple<bool, std::string>
3325+
checkIfTemporalMergeable(air::ChannelOp chan_a, air::ChannelOp chan_b) {
32903326
std::vector<air::ChannelPutOp> a_puts =
32913327
getChannelPutOpThroughSymbol(chan_a);
32923328
std::vector<air::ChannelPutOp> b_puts =
@@ -3297,6 +3333,8 @@ class AIRFuseChannels
32973333
getChannelGetOpThroughSymbol(chan_b);
32983334
std::tuple<bool, std::string> notMergeable = {false, ""};
32993335
std::tuple<bool, std::string> mergeableToLB = {true, "LB"};
3336+
std::tuple<bool, std::string> mergeableToUB = {true, "UB"};
3337+
std::tuple<bool, std::string> mergeableNoForLoop = {true, "NFL"};
33003338
if (a_puts.size() != b_puts.size())
33013339
return notMergeable;
33023340
if (a_puts.size() != 1)
@@ -3356,35 +3394,73 @@ class AIRFuseChannels
33563394
(!areTheSameSSAValueLists(aSizes, bSizes)) ||
33573395
(!areTheSameSSAValueLists(aStrides, bStrides)))
33583396
return notMergeable;
3397+
std::vector<std::tuple<bool, std::string>> putResults;
33593398
for (unsigned i = 0; i < a_puts.size(); i++) {
33603399
auto a_put_loop_nest = getParentLoopNest(a_puts[i].getOperation());
33613400
auto b_put_loop_nest = getParentLoopNest(b_puts[i].getOperation());
3362-
return checkIfScfForMergeableImpl(a_put_loop_nest, b_put_loop_nest);
3401+
putResults.push_back(
3402+
checkIfTemporalMergeableImpl(a_put_loop_nest, b_put_loop_nest));
33633403
}
3404+
std::vector<std::tuple<bool, std::string>> getResults;
33643405
for (unsigned i = 0; i < a_gets.size(); i++) {
33653406
auto a_get_loop_nest = getParentLoopNest(a_gets[i].getOperation());
33663407
auto b_get_loop_nest = getParentLoopNest(b_gets[i].getOperation());
3367-
return checkIfScfForMergeableImpl(a_get_loop_nest, b_get_loop_nest);
3368-
}
3369-
return mergeableToLB;
3408+
getResults.push_back(
3409+
checkIfTemporalMergeableImpl(a_get_loop_nest, b_get_loop_nest));
3410+
}
3411+
bool overallUBMergeable = true;
3412+
bool overallLBMergeable = true;
3413+
bool overallNFLMergeable = true;
3414+
for (auto putRes : putResults) {
3415+
if (!std::get<0>(putRes))
3416+
return notMergeable;
3417+
overallUBMergeable &= (std::get<1>(putRes) == "UB");
3418+
overallLBMergeable &= (std::get<1>(putRes) == "LB");
3419+
overallNFLMergeable &= (std::get<1>(putRes) == "NFL");
3420+
}
3421+
for (auto getRes : getResults) {
3422+
if (!std::get<0>(getRes))
3423+
return notMergeable;
3424+
overallUBMergeable &= (std::get<1>(getRes) == "UB");
3425+
overallLBMergeable &= (std::get<1>(getRes) == "LB");
3426+
overallNFLMergeable &= (std::get<1>(getRes) == "NFL");
3427+
}
3428+
if (overallNFLMergeable)
3429+
return mergeableNoForLoop;
3430+
else if (overallLBMergeable)
3431+
return mergeableToLB;
3432+
else if (overallUBMergeable)
3433+
return mergeableToUB;
3434+
return notMergeable;
33703435
}
3371-
std::vector<Operation *> getParentLoopNest(Operation *op) {
3372-
std::vector<Operation *> parent_loop_nest;
3436+
std::vector<Block *> getParentLoopNest(Operation *op) {
3437+
std::vector<Block *> parent_loop_nest;
33733438
Operation *parent = op;
33743439
while (parent) {
3375-
if (isa<scf::ForOp>(parent))
3376-
parent_loop_nest.push_back(parent);
3377-
else if (isa<scf::ParallelOp>(parent))
3378-
parent_loop_nest.push_back(parent);
3379-
else if (isa<air::HierarchyInterface>(parent))
3380-
parent_loop_nest.push_back(parent);
3381-
else if (isa<affine::AffineIfOp>(parent))
3382-
parent_loop_nest.push_back(parent);
3440+
if (auto forOp = dyn_cast<scf::ForOp>(parent))
3441+
parent_loop_nest.push_back(forOp.getBody());
3442+
else if (auto parOp = dyn_cast<scf::ParallelOp>(parent))
3443+
parent_loop_nest.push_back(parOp.getBody());
3444+
else if (auto hierOp = dyn_cast<air::HierarchyInterface>(parent))
3445+
parent_loop_nest.push_back(&hierOp->getRegion(0).front());
3446+
else if (auto aifOp = dyn_cast<affine::AffineIfOp>(parent)) {
3447+
if (aifOp.getThenBlock()->findAncestorOpInBlock(*op))
3448+
parent_loop_nest.push_back(aifOp.getThenBlock());
3449+
else if (aifOp.hasElse() &&
3450+
aifOp.getElseBlock()->findAncestorOpInBlock(*op))
3451+
parent_loop_nest.push_back(aifOp.getElseBlock());
3452+
}
33833453
parent = parent->getParentOp();
33843454
}
33853455
return parent_loop_nest;
33863456
}
3387-
bool areEquivalentControlLoops(Operation *a, Operation *b) {
3457+
bool areEquivalentControlLoops(Block *aBlock, Block *bBlock) {
3458+
Operation *a = aBlock->getParentOp();
3459+
Operation *b = bBlock->getParentOp();
3460+
if (!a)
3461+
return false;
3462+
if (!b)
3463+
return false;
33883464
if (isa<scf::ForOp>(a) && isa<scf::ForOp>(b)) {
33893465
auto a_for = dyn_cast<scf::ForOp>(a);
33903466
auto b_for = dyn_cast<scf::ForOp>(b);
@@ -3437,11 +3513,27 @@ class AIRFuseChannels
34373513
}
34383514
return true;
34393515
} else if (isa<air::HierarchyInterface>(a) &&
3440-
isa<air::HierarchyInterface>(a)) {
3516+
isa<air::HierarchyInterface>(b)) {
34413517
if (a == b)
34423518
return true;
3443-
} else if (isa<affine::AffineIfOp>(a) || isa<affine::AffineIfOp>(b))
3444-
return false;
3519+
auto aHierSym =
3520+
a->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
3521+
auto bHierSym =
3522+
b->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
3523+
if (aHierSym && bHierSym && aHierSym.str() == bHierSym.str())
3524+
return true;
3525+
} else if (isa<affine::AffineIfOp>(a) || isa<affine::AffineIfOp>(b)) {
3526+
if (a == b)
3527+
return false; // Sharing the same affine.if means spatially parallel
3528+
// ops. Cannot merge by for loop (i.e. in time).
3529+
auto aIf = dyn_cast<affine::AffineIfOp>(a);
3530+
auto bIf = dyn_cast<affine::AffineIfOp>(b);
3531+
if (aBlock == aIf.getThenBlock() && bBlock == bIf.getThenBlock())
3532+
return true;
3533+
if (aIf.hasElse() && bIf.hasElse() && aBlock == aIf.getElseBlock() &&
3534+
bBlock == bIf.getElseBlock())
3535+
return true;
3536+
}
34453537
return false;
34463538
}
34473539
void mergeChannelOps(air::ChannelInterface a, air::ChannelInterface b) {
@@ -3459,8 +3551,9 @@ class AIRFuseChannels
34593551
async_a.addAsyncDependency(
34603552
dyn_cast<air::AsyncOpInterface>(new_b).getAsyncToken());
34613553
}
3462-
void mergeChannelOpsByScfFor(air::ChannelInterface a, air::ChannelInterface b,
3463-
std::string mergeByLBOrUB) {
3554+
void mergeChannelOpsTemporally(air::ChannelInterface a,
3555+
air::ChannelInterface b,
3556+
std::string mergeByLBOrUB) {
34643557
scf::ForOp parentForOp = a->getParentOfType<scf::ForOp>();
34653558
while (parentForOp && parentForOp->getParentOfType<scf::ForOp>()) {
34663559
parentForOp = parentForOp->getParentOfType<scf::ForOp>();
@@ -3498,8 +3591,8 @@ class AIRFuseChannels
34983591
mergeChannelOps(a_gets[i], b_gets[i]);
34993592
}
35003593
}
3501-
void mergeChannelOpsByScfFor(air::ChannelOp chan_a, air::ChannelOp chan_b,
3502-
std::string mergeByLBOrUB) {
3594+
void mergeChannelOpsTemporally(air::ChannelOp chan_a, air::ChannelOp chan_b,
3595+
std::string mergeByLBOrUB) {
35033596
std::vector<air::ChannelPutOp> a_puts =
35043597
getChannelPutOpThroughSymbol(chan_a);
35053598
std::vector<air::ChannelPutOp> b_puts =
@@ -3509,10 +3602,10 @@ class AIRFuseChannels
35093602
std::vector<air::ChannelGetOp> b_gets =
35103603
getChannelGetOpThroughSymbol(chan_b);
35113604
if (!b_puts[0]->getParentOfType<air::HerdOp>()) {
3512-
mergeChannelOpsByScfFor(a_puts[0], b_puts[0], mergeByLBOrUB);
3605+
mergeChannelOpsTemporally(a_puts[0], b_puts[0], mergeByLBOrUB);
35133606
}
35143607
if (!b_gets[0]->getParentOfType<air::HerdOp>()) {
3515-
mergeChannelOpsByScfFor(a_gets[0], b_gets[0], mergeByLBOrUB);
3608+
mergeChannelOpsTemporally(a_gets[0], b_gets[0], mergeByLBOrUB);
35163609
}
35173610
}
35183611
Operation *cloneOpAndOperands(OpBuilder builder, IRMapping remap,
@@ -3550,16 +3643,18 @@ class AIRFuseChannels
35503643
if (a_loop_nest.size() != b_loop_nest.size())
35513644
return;
35523645
for (unsigned i = 0; i < a_loop_nest.size(); i++) {
3553-
if (auto a_for = dyn_cast<scf::ForOp>(a_loop_nest[i])) {
3554-
if (auto b_for = dyn_cast<scf::ForOp>(b_loop_nest[i])) {
3646+
if (auto a_for = dyn_cast<scf::ForOp>(a_loop_nest[i]->getParentOp())) {
3647+
if (auto b_for = dyn_cast<scf::ForOp>(b_loop_nest[i]->getParentOp())) {
35553648
for (unsigned j = 0; j < a_for.getBody()->getNumArguments(); j++) {
35563649
remap.map(b_for.getBody()->getArgument(j),
35573650
a_for.getBody()->getArgument(j));
35583651
}
35593652
}
35603653
}
3561-
if (auto a_par = dyn_cast<scf::ParallelOp>(a_loop_nest[i])) {
3562-
if (auto b_par = dyn_cast<scf::ParallelOp>(b_loop_nest[i])) {
3654+
if (auto a_par =
3655+
dyn_cast<scf::ParallelOp>(a_loop_nest[i]->getParentOp())) {
3656+
if (auto b_par =
3657+
dyn_cast<scf::ParallelOp>(b_loop_nest[i]->getParentOp())) {
35633658
for (unsigned j = 0; j < a_par.getBody()->getNumArguments(); j++) {
35643659
remap.map(b_par.getBody()->getArgument(j),
35653660
a_par.getBody()->getArgument(j));

mlir/test/Conversion/AIRToAIE/async_gemm_w_pingpong_to_locks_aie2.mlir

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@
1717
// CHECK: %[[VAL_7:.*]] = aie.tile(6, 4)
1818
// CHECK-COUNT-8: aie.lock(%[[VAL_3]], {{.*}})
1919
// CHECK-COUNT-2: aie.lock(%[[VAL_2]], {{.*}})
20-
// CHECK-COUNT-8: aie.lock(%[[VAL_4]], {{.*}})
21-
// CHECK-COUNT-8: aie.lock(%[[VAL_5]], {{.*}})
22-
// CHECK-COUNT-8: aie.lock(%[[VAL_6]], {{.*}})
23-
// CHECK-COUNT-8: aie.lock(%[[VAL_7]], {{.*}})
20+
// CHECK-COUNT-6: aie.lock(%[[VAL_4]], {{.*}})
21+
// CHECK-COUNT-6: aie.lock(%[[VAL_5]], {{.*}})
22+
// CHECK-COUNT-6: aie.lock(%[[VAL_6]], {{.*}})
23+
// CHECK-COUNT-6: aie.lock(%[[VAL_7]], {{.*}})
2424
// CHECK: aie.buffer(%[[VAL_2]]) {{{.*}}} : memref<64x64xi32, 1>
2525
// CHECK: aie.buffer(%[[VAL_3]]) {{{.*}}} : memref<64x128xi32, 1>
2626
// CHECK: aie.buffer(%[[VAL_3]]) {{{.*}}} : memref<128x64xi32, 1>

0 commit comments

Comments
 (0)