@@ -3009,23 +3009,49 @@ class AIRFuseChannels
3009
3009
init_options ();
3010
3010
if (channelOps.empty ())
3011
3011
return ;
3012
+ // Rename symbols
3013
+ // TODO: make this greedy
3014
+ auto renameSymbols =
3015
+ [](std::vector<air::ChannelOp> &channelOps,
3016
+ std::map<air::ChannelOp, air::ChannelOp> chan_merge_map) {
3017
+ for (unsigned i = 0 ; i < channelOps.size (); i++) {
3018
+ for (auto chanKey : channelOps) {
3019
+ if (!chan_merge_map.count (chanKey))
3020
+ continue ;
3021
+ auto error = mlir::SymbolTable::replaceAllSymbolUses (
3022
+ chanKey.getOperation (),
3023
+ mlir::SymbolTable::getSymbolName (chan_merge_map[chanKey]),
3024
+ chanKey->getParentOfType <ModuleOp>());
3025
+ // FIXME: what if this fails?
3026
+ (void )error;
3027
+ }
3028
+ }
3029
+ };
3012
3030
std::map<air::ChannelOp, air::ChannelOp> chan_merge_map;
3013
3031
for (unsigned i = 0 ; i < channelOps.size () - 1 ; i++) {
3014
3032
for (unsigned j = i + 1 ; j < channelOps.size (); j++) {
3015
3033
std::tuple<bool , std::string> checkScfForMergeableRes =
3016
- checkIfScfForMergeable (channelOps[i], channelOps[j]);
3034
+ checkIfTemporalMergeable (channelOps[i], channelOps[j]);
3017
3035
if (!std::get<0 >(checkScfForMergeableRes))
3018
3036
continue ;
3019
- // Fuse air.channels by scf.for loop unpeeling, i.e. recovering any
3020
- // missing zeroth iterations in scf.for loops.
3021
3037
air::ChannelOp chanA = channelOps[i];
3022
3038
air::ChannelOp chanB = channelOps[j];
3023
- sortChannelsByLoopNests (chanA, chanB);
3024
- mergeChannelOpsByScfFor (chanA, chanB,
3025
- std::get<1 >(checkScfForMergeableRes));
3026
- chan_merge_map[chanB] = chanA;
3039
+ if (std::get<1 >(checkScfForMergeableRes) == " LB" ||
3040
+ std::get<1 >(checkScfForMergeableRes) == " UB" ) {
3041
+ // Fuse air.channels by scf.for loop unpeeling, i.e. recovering any
3042
+ // missing zeroth/last iterations in scf.for loops.
3043
+ sortChannelsByLoopNests (chanA, chanB);
3044
+ mergeChannelOpsTemporally (chanA, chanB,
3045
+ std::get<1 >(checkScfForMergeableRes));
3046
+ chan_merge_map[chanB] = chanA;
3047
+ } else if (std::get<1 >(checkScfForMergeableRes) == " NFL" ) {
3048
+ // Fuse air.channels temporally, if there isn't any for loop to fuse
3049
+ // into.
3050
+ chan_merge_map[chanB] = chanA;
3051
+ }
3027
3052
}
3028
3053
}
3054
+ renameSymbols (channelOps, chan_merge_map);
3029
3055
if (!targetMemorySpaces.empty ()) {
3030
3056
for (unsigned i = 0 ; i < channelOps.size () - 1 ; i++) {
3031
3057
for (unsigned j = i + 1 ; j < channelOps.size (); j++) {
@@ -3037,20 +3063,7 @@ class AIRFuseChannels
3037
3063
}
3038
3064
}
3039
3065
}
3040
- // Rename symbols
3041
- // TODO: make this greedy
3042
- for (unsigned i = 0 ; i < channelOps.size (); i++) {
3043
- for (auto chanKey : channelOps) {
3044
- if (!chan_merge_map.count (chanKey))
3045
- continue ;
3046
- auto error = mlir::SymbolTable::replaceAllSymbolUses (
3047
- chanKey.getOperation (),
3048
- mlir::SymbolTable::getSymbolName (chan_merge_map[chanKey]),
3049
- chanKey->getParentOfType <ModuleOp>());
3050
- // FIXME: what if this fails?
3051
- (void )error;
3052
- }
3053
- }
3066
+ renameSymbols (channelOps, chan_merge_map);
3054
3067
}
3055
3068
3056
3069
void runOnOperation () override {
@@ -3183,8 +3196,8 @@ class AIRFuseChannels
3183
3196
}
3184
3197
3185
3198
std::tuple<bool , std::string>
3186
- checkIfScfForMergeableImpl (std::vector<Operation *> a_loop_nest,
3187
- std::vector<Operation *> b_loop_nest) {
3199
+ checkIfTemporalMergeableByScfForImpl (std::vector<Block *> a_loop_nest,
3200
+ std::vector<Block *> b_loop_nest) {
3188
3201
std::tuple<bool , std::string> notMergeable = {false , " " };
3189
3202
std::tuple<bool , std::string> mergeableToLB = {true , " LB" };
3190
3203
std::tuple<bool , std::string> mergeableToUB = {true , " UB" };
@@ -3194,20 +3207,22 @@ class AIRFuseChannels
3194
3207
std::max (a_loop_nest.size (), b_loop_nest.size ());
3195
3208
// Skip over the unequal scf.for loop, and check equality for the rest of
3196
3209
// the loops first.
3197
- unsigned outermostScfFor = -1 ;
3210
+ int outermostScfFor = -1 ;
3198
3211
for (unsigned i = 0 ; i < max_loop_nest_count; i++) {
3199
3212
unsigned scfForCount = 0 ;
3200
- if ((i < a_loop_nest.size ()) && isa<scf::ForOp>(a_loop_nest[i]))
3213
+ if ((i < a_loop_nest.size ()) &&
3214
+ isa<scf::ForOp>(a_loop_nest[i]->getParentOp ()))
3201
3215
scfForCount++;
3202
- if ((i < b_loop_nest.size ()) && isa<scf::ForOp>(b_loop_nest[i]))
3216
+ if ((i < b_loop_nest.size ()) &&
3217
+ isa<scf::ForOp>(b_loop_nest[i]->getParentOp ()))
3203
3218
scfForCount++;
3204
3219
if (scfForCount == 1 )
3205
- outermostScfFor = i;
3220
+ outermostScfFor = ( int ) i;
3206
3221
}
3207
3222
if (outermostScfFor < 0 )
3208
3223
return notMergeable;
3209
3224
SmallVector<unsigned > controlLoopIndices;
3210
- for (unsigned i = 0 ; i < max_loop_nest_count; i++)
3225
+ for (int i = 0 ; i < ( int ) max_loop_nest_count; i++)
3211
3226
if (i != outermostScfFor)
3212
3227
controlLoopIndices.push_back (i);
3213
3228
// TODO: Assuming b_loop_nest is before a_loop_nest. Always true? TODO:
@@ -3220,7 +3235,8 @@ class AIRFuseChannels
3220
3235
}
3221
3236
// Check if the skipped scf.for loop has LB >= 1. This is a sign of
3222
3237
// peeling, indicating opportunity of merge by unpeeling into LB.
3223
- auto outerMostScfFor = dyn_cast<scf::ForOp>(a_loop_nest[outermostScfFor]);
3238
+ auto outerMostScfFor =
3239
+ dyn_cast<scf::ForOp>(a_loop_nest[outermostScfFor]->getParentOp ());
3224
3240
assert (outerMostScfFor);
3225
3241
if (auto constLB = getConstantIntValue (outerMostScfFor.getLowerBound ()))
3226
3242
if (*constLB < 1 )
@@ -3232,12 +3248,31 @@ class AIRFuseChannels
3232
3248
return notMergeable;
3233
3249
}
3234
3250
// Merge by unpeeling into UB.
3235
- auto outerMostScfFor = dyn_cast<scf::ForOp>(b_loop_nest[outermostScfFor]);
3251
+ auto outerMostScfFor =
3252
+ dyn_cast<scf::ForOp>(b_loop_nest[outermostScfFor]->getParentOp ());
3236
3253
assert (outerMostScfFor);
3237
3254
return mergeableToUB;
3238
3255
}
3239
3256
return mergeableToLB;
3240
3257
}
3258
+ std::tuple<bool , std::string>
3259
+ checkIfTemporalMergeableImpl (std::vector<Block *> a_loop_nest,
3260
+ std::vector<Block *> b_loop_nest) {
3261
+ std::tuple<bool , std::string> notMergeable = {false , " " };
3262
+ std::tuple<bool , std::string> mergeableToLB = {true , " LB" };
3263
+ std::tuple<bool , std::string> mergeableToUB = {true , " UB" };
3264
+ std::tuple<bool , std::string> mergeableNoForLoop = {true , " NFL" };
3265
+ if (std::abs ((int )a_loop_nest.size () - (int )b_loop_nest.size ()) == 1 )
3266
+ return checkIfTemporalMergeableByScfForImpl (a_loop_nest, b_loop_nest);
3267
+ else if (a_loop_nest.size () != b_loop_nest.size ())
3268
+ return notMergeable;
3269
+
3270
+ for (unsigned i = 0 ; i < a_loop_nest.size (); i++) {
3271
+ if (!areEquivalentControlLoops (a_loop_nest[i], b_loop_nest[i]))
3272
+ return notMergeable;
3273
+ }
3274
+ return mergeableNoForLoop;
3275
+ }
3241
3276
Value getHierOperandFromHierBlockArgument (BlockArgument arg) {
3242
3277
if (!arg)
3243
3278
return nullptr ;
@@ -3284,9 +3319,10 @@ class AIRFuseChannels
3284
3319
}
3285
3320
// Check of two air.channels are mergeable in time, by fusing into a shared
3286
3321
// scf.for loop. Returns a tuple of bool of whether mergeable, and string of
3287
- // fusing into for loop lower bound (LB) or upper bound (UB).
3288
- std::tuple<bool , std::string> checkIfScfForMergeable (air::ChannelOp chan_a,
3289
- air::ChannelOp chan_b) {
3322
+ // fusing into for loop lower bound (LB) or upper bound (UB), or fuse with no
3323
+ // for loop (NFL).
3324
+ std::tuple<bool , std::string>
3325
+ checkIfTemporalMergeable (air::ChannelOp chan_a, air::ChannelOp chan_b) {
3290
3326
std::vector<air::ChannelPutOp> a_puts =
3291
3327
getChannelPutOpThroughSymbol (chan_a);
3292
3328
std::vector<air::ChannelPutOp> b_puts =
@@ -3297,6 +3333,8 @@ class AIRFuseChannels
3297
3333
getChannelGetOpThroughSymbol (chan_b);
3298
3334
std::tuple<bool , std::string> notMergeable = {false , " " };
3299
3335
std::tuple<bool , std::string> mergeableToLB = {true , " LB" };
3336
+ std::tuple<bool , std::string> mergeableToUB = {true , " UB" };
3337
+ std::tuple<bool , std::string> mergeableNoForLoop = {true , " NFL" };
3300
3338
if (a_puts.size () != b_puts.size ())
3301
3339
return notMergeable;
3302
3340
if (a_puts.size () != 1 )
@@ -3356,35 +3394,73 @@ class AIRFuseChannels
3356
3394
(!areTheSameSSAValueLists (aSizes, bSizes)) ||
3357
3395
(!areTheSameSSAValueLists (aStrides, bStrides)))
3358
3396
return notMergeable;
3397
+ std::vector<std::tuple<bool , std::string>> putResults;
3359
3398
for (unsigned i = 0 ; i < a_puts.size (); i++) {
3360
3399
auto a_put_loop_nest = getParentLoopNest (a_puts[i].getOperation ());
3361
3400
auto b_put_loop_nest = getParentLoopNest (b_puts[i].getOperation ());
3362
- return checkIfScfForMergeableImpl (a_put_loop_nest, b_put_loop_nest);
3401
+ putResults.push_back (
3402
+ checkIfTemporalMergeableImpl (a_put_loop_nest, b_put_loop_nest));
3363
3403
}
3404
+ std::vector<std::tuple<bool , std::string>> getResults;
3364
3405
for (unsigned i = 0 ; i < a_gets.size (); i++) {
3365
3406
auto a_get_loop_nest = getParentLoopNest (a_gets[i].getOperation ());
3366
3407
auto b_get_loop_nest = getParentLoopNest (b_gets[i].getOperation ());
3367
- return checkIfScfForMergeableImpl (a_get_loop_nest, b_get_loop_nest);
3368
- }
3369
- return mergeableToLB;
3408
+ getResults.push_back (
3409
+ checkIfTemporalMergeableImpl (a_get_loop_nest, b_get_loop_nest));
3410
+ }
3411
+ bool overallUBMergeable = true ;
3412
+ bool overallLBMergeable = true ;
3413
+ bool overallNFLMergeable = true ;
3414
+ for (auto putRes : putResults) {
3415
+ if (!std::get<0 >(putRes))
3416
+ return notMergeable;
3417
+ overallUBMergeable &= (std::get<1 >(putRes) == " UB" );
3418
+ overallLBMergeable &= (std::get<1 >(putRes) == " LB" );
3419
+ overallNFLMergeable &= (std::get<1 >(putRes) == " NFL" );
3420
+ }
3421
+ for (auto getRes : getResults) {
3422
+ if (!std::get<0 >(getRes))
3423
+ return notMergeable;
3424
+ overallUBMergeable &= (std::get<1 >(getRes) == " UB" );
3425
+ overallLBMergeable &= (std::get<1 >(getRes) == " LB" );
3426
+ overallNFLMergeable &= (std::get<1 >(getRes) == " NFL" );
3427
+ }
3428
+ if (overallNFLMergeable)
3429
+ return mergeableNoForLoop;
3430
+ else if (overallLBMergeable)
3431
+ return mergeableToLB;
3432
+ else if (overallUBMergeable)
3433
+ return mergeableToUB;
3434
+ return notMergeable;
3370
3435
}
3371
- std::vector<Operation *> getParentLoopNest (Operation *op) {
3372
- std::vector<Operation *> parent_loop_nest;
3436
+ std::vector<Block *> getParentLoopNest (Operation *op) {
3437
+ std::vector<Block *> parent_loop_nest;
3373
3438
Operation *parent = op;
3374
3439
while (parent) {
3375
- if (isa<scf::ForOp>(parent))
3376
- parent_loop_nest.push_back (parent);
3377
- else if (isa<scf::ParallelOp>(parent))
3378
- parent_loop_nest.push_back (parent);
3379
- else if (isa<air::HierarchyInterface>(parent))
3380
- parent_loop_nest.push_back (parent);
3381
- else if (isa<affine::AffineIfOp>(parent))
3382
- parent_loop_nest.push_back (parent);
3440
+ if (auto forOp = dyn_cast<scf::ForOp>(parent))
3441
+ parent_loop_nest.push_back (forOp.getBody ());
3442
+ else if (auto parOp = dyn_cast<scf::ParallelOp>(parent))
3443
+ parent_loop_nest.push_back (parOp.getBody ());
3444
+ else if (auto hierOp = dyn_cast<air::HierarchyInterface>(parent))
3445
+ parent_loop_nest.push_back (&hierOp->getRegion (0 ).front ());
3446
+ else if (auto aifOp = dyn_cast<affine::AffineIfOp>(parent)) {
3447
+ if (aifOp.getThenBlock ()->findAncestorOpInBlock (*op))
3448
+ parent_loop_nest.push_back (aifOp.getThenBlock ());
3449
+ else if (aifOp.hasElse () &&
3450
+ aifOp.getElseBlock ()->findAncestorOpInBlock (*op))
3451
+ parent_loop_nest.push_back (aifOp.getElseBlock ());
3452
+ }
3383
3453
parent = parent->getParentOp ();
3384
3454
}
3385
3455
return parent_loop_nest;
3386
3456
}
3387
- bool areEquivalentControlLoops (Operation *a, Operation *b) {
3457
+ bool areEquivalentControlLoops (Block *aBlock, Block *bBlock) {
3458
+ Operation *a = aBlock->getParentOp ();
3459
+ Operation *b = bBlock->getParentOp ();
3460
+ if (!a)
3461
+ return false ;
3462
+ if (!b)
3463
+ return false ;
3388
3464
if (isa<scf::ForOp>(a) && isa<scf::ForOp>(b)) {
3389
3465
auto a_for = dyn_cast<scf::ForOp>(a);
3390
3466
auto b_for = dyn_cast<scf::ForOp>(b);
@@ -3437,11 +3513,27 @@ class AIRFuseChannels
3437
3513
}
3438
3514
return true ;
3439
3515
} else if (isa<air::HierarchyInterface>(a) &&
3440
- isa<air::HierarchyInterface>(a )) {
3516
+ isa<air::HierarchyInterface>(b )) {
3441
3517
if (a == b)
3442
3518
return true ;
3443
- } else if (isa<affine::AffineIfOp>(a) || isa<affine::AffineIfOp>(b))
3444
- return false ;
3519
+ auto aHierSym =
3520
+ a->getAttrOfType <StringAttr>(SymbolTable::getSymbolAttrName ());
3521
+ auto bHierSym =
3522
+ b->getAttrOfType <StringAttr>(SymbolTable::getSymbolAttrName ());
3523
+ if (aHierSym && bHierSym && aHierSym.str () == bHierSym.str ())
3524
+ return true ;
3525
+ } else if (isa<affine::AffineIfOp>(a) || isa<affine::AffineIfOp>(b)) {
3526
+ if (a == b)
3527
+ return false ; // Sharing the same affine.if means spatially parallel
3528
+ // ops. Cannot merge by for loop (i.e. in time).
3529
+ auto aIf = dyn_cast<affine::AffineIfOp>(a);
3530
+ auto bIf = dyn_cast<affine::AffineIfOp>(b);
3531
+ if (aBlock == aIf.getThenBlock () && bBlock == bIf.getThenBlock ())
3532
+ return true ;
3533
+ if (aIf.hasElse () && bIf.hasElse () && aBlock == aIf.getElseBlock () &&
3534
+ bBlock == bIf.getElseBlock ())
3535
+ return true ;
3536
+ }
3445
3537
return false ;
3446
3538
}
3447
3539
void mergeChannelOps (air::ChannelInterface a, air::ChannelInterface b) {
@@ -3459,8 +3551,9 @@ class AIRFuseChannels
3459
3551
async_a.addAsyncDependency (
3460
3552
dyn_cast<air::AsyncOpInterface>(new_b).getAsyncToken ());
3461
3553
}
3462
- void mergeChannelOpsByScfFor (air::ChannelInterface a, air::ChannelInterface b,
3463
- std::string mergeByLBOrUB) {
3554
+ void mergeChannelOpsTemporally (air::ChannelInterface a,
3555
+ air::ChannelInterface b,
3556
+ std::string mergeByLBOrUB) {
3464
3557
scf::ForOp parentForOp = a->getParentOfType <scf::ForOp>();
3465
3558
while (parentForOp && parentForOp->getParentOfType <scf::ForOp>()) {
3466
3559
parentForOp = parentForOp->getParentOfType <scf::ForOp>();
@@ -3498,8 +3591,8 @@ class AIRFuseChannels
3498
3591
mergeChannelOps (a_gets[i], b_gets[i]);
3499
3592
}
3500
3593
}
3501
- void mergeChannelOpsByScfFor (air::ChannelOp chan_a, air::ChannelOp chan_b,
3502
- std::string mergeByLBOrUB) {
3594
+ void mergeChannelOpsTemporally (air::ChannelOp chan_a, air::ChannelOp chan_b,
3595
+ std::string mergeByLBOrUB) {
3503
3596
std::vector<air::ChannelPutOp> a_puts =
3504
3597
getChannelPutOpThroughSymbol (chan_a);
3505
3598
std::vector<air::ChannelPutOp> b_puts =
@@ -3509,10 +3602,10 @@ class AIRFuseChannels
3509
3602
std::vector<air::ChannelGetOp> b_gets =
3510
3603
getChannelGetOpThroughSymbol (chan_b);
3511
3604
if (!b_puts[0 ]->getParentOfType <air::HerdOp>()) {
3512
- mergeChannelOpsByScfFor (a_puts[0 ], b_puts[0 ], mergeByLBOrUB);
3605
+ mergeChannelOpsTemporally (a_puts[0 ], b_puts[0 ], mergeByLBOrUB);
3513
3606
}
3514
3607
if (!b_gets[0 ]->getParentOfType <air::HerdOp>()) {
3515
- mergeChannelOpsByScfFor (a_gets[0 ], b_gets[0 ], mergeByLBOrUB);
3608
+ mergeChannelOpsTemporally (a_gets[0 ], b_gets[0 ], mergeByLBOrUB);
3516
3609
}
3517
3610
}
3518
3611
Operation *cloneOpAndOperands (OpBuilder builder, IRMapping remap,
@@ -3550,16 +3643,18 @@ class AIRFuseChannels
3550
3643
if (a_loop_nest.size () != b_loop_nest.size ())
3551
3644
return ;
3552
3645
for (unsigned i = 0 ; i < a_loop_nest.size (); i++) {
3553
- if (auto a_for = dyn_cast<scf::ForOp>(a_loop_nest[i])) {
3554
- if (auto b_for = dyn_cast<scf::ForOp>(b_loop_nest[i])) {
3646
+ if (auto a_for = dyn_cast<scf::ForOp>(a_loop_nest[i]-> getParentOp () )) {
3647
+ if (auto b_for = dyn_cast<scf::ForOp>(b_loop_nest[i]-> getParentOp () )) {
3555
3648
for (unsigned j = 0 ; j < a_for.getBody ()->getNumArguments (); j++) {
3556
3649
remap.map (b_for.getBody ()->getArgument (j),
3557
3650
a_for.getBody ()->getArgument (j));
3558
3651
}
3559
3652
}
3560
3653
}
3561
- if (auto a_par = dyn_cast<scf::ParallelOp>(a_loop_nest[i])) {
3562
- if (auto b_par = dyn_cast<scf::ParallelOp>(b_loop_nest[i])) {
3654
+ if (auto a_par =
3655
+ dyn_cast<scf::ParallelOp>(a_loop_nest[i]->getParentOp ())) {
3656
+ if (auto b_par =
3657
+ dyn_cast<scf::ParallelOp>(b_loop_nest[i]->getParentOp ())) {
3563
3658
for (unsigned j = 0 ; j < a_par.getBody ()->getNumArguments (); j++) {
3564
3659
remap.map (b_par.getBody ()->getArgument (j),
3565
3660
a_par.getBody ()->getArgument (j));
0 commit comments