Skip to content

Commit 85f2da1

Browse files
author
Xiang Li
committed
Use min to set the loop upper bound instead of an if guard inside the loop when the mask limit is dynamic.
1 parent f63b2f9 commit 85f2da1

File tree

5 files changed

+221
-222
lines changed

5 files changed

+221
-222
lines changed

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 20 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -843,6 +843,7 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
843843

844844
// Create loop to iterate every offset in gatherOffset.
845845
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
846+
Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, offsetSize).getResult();
846847
if (op.hasMask()) {
847848
SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims();
848849
OpFoldResult gatherMaskDim = mixedDims[gatherDim];
@@ -854,9 +855,17 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
854855
// If the gather mask dimension is a constant, we can use it directly.
855856
unsigned gatherMaskDimValue = gatherMaskDimIndex.value();
856857
offsetSize = std::min(offsetSize, gatherMaskDimValue);
858+
upperBound = rewriter.create<arith::ConstantIndexOp>(loc, offsetSize).getResult();
859+
} else {
860+
// Use arith::MinSIOp to get the minimum value of gatherMaskDim
861+
// and offsetSize.
862+
auto gatherMaskDimVal = cast<Value>(gatherMaskDim);
863+
auto offsetSizeVal =
864+
rewriter.create<arith::ConstantIndexOp>(loc, offsetSize);
865+
upperBound = rewriter.create<arith::MinSIOp>(loc, gatherMaskDimVal,
866+
offsetSizeVal).getResult();
857867
}
858868
}
859-
auto upperBound = rewriter.create<arith::ConstantIndexOp>(loc, offsetSize);
860869
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
861870
auto loop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
862871

@@ -870,23 +879,6 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
870879

871880
// Load the offsetElt first.
872881
Value inductionVar = loop.getInductionVar();
873-
874-
// When there's mask for gather dimension, we need to guard the load with
875-
// gather mask.
876-
if (op.hasMask()) {
877-
SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims();
878-
OpFoldResult gatherMaskDim = mixedDims[gatherDim];
879-
if (auto gatherMaskDimVal = dyn_cast<Value>(gatherMaskDim)) {
880-
auto cmp = rewriter.create<arith::CmpIOp>(
881-
loc, arith::CmpIPredicate::slt, inductionVar, gatherMaskDimVal);
882-
auto ifOp = rewriter.create<scf::IfOp>(
883-
loc, cmp,
884-
[&](OpBuilder &b, Location loc) { b.create<scf::YieldOp>(loc); },
885-
[&](OpBuilder &b, Location loc) { b.create<scf::YieldOp>(loc); });
886-
// Set insertion point to the then body of the ifOp.
887-
rewriter.setInsertionPointToStart(ifOp.thenBlock());
888-
}
889-
}
890882
auto gatherOffsetElt = rewriter.create<tensor::ExtractOp>(
891883
loc, gatherOffset, ValueRange{inductionVar});
892884

@@ -998,6 +990,7 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
998990

999991
// Create loop to iterate every offset in gatherOffset.
1000992
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
993+
Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, offsetSize).getResult();
1001994
if (op.hasMask()) {
1002995
SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims();
1003996
OpFoldResult gatherMaskDim = mixedDims[gatherDim];
@@ -1009,9 +1002,17 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
10091002
// If the gather mask dimension is a constant, we can use it directly.
10101003
unsigned gatherMaskDimValue = gatherMaskDimIndex.value();
10111004
offsetSize = std::min(offsetSize, gatherMaskDimValue);
1005+
upperBound = rewriter.create<arith::ConstantIndexOp>(loc, offsetSize).getResult();
1006+
} else {
1007+
// Use arith::MinSIOp to get the minimum value of gatherMaskDim
1008+
// and offsetSize.
1009+
auto gatherMaskDimVal = cast<Value>(gatherMaskDim);
1010+
auto offsetSizeVal =
1011+
rewriter.create<arith::ConstantIndexOp>(loc, offsetSize);
1012+
upperBound = rewriter.create<arith::MinSIOp>(loc, gatherMaskDimVal,
1013+
offsetSizeVal).getResult();
10121014
}
10131015
}
1014-
auto upperBound = rewriter.create<arith::ConstantIndexOp>(loc, offsetSize);
10151016
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
10161017
auto loop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
10171018

@@ -1021,24 +1022,6 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
10211022
// Load the offsetElt first.
10221023
Value inductionVar = loop.getInductionVar();
10231024

1024-
// When there's mask for gather dimension, we need to guard the load with
1025-
// gather mask.
1026-
if (op.hasMask()) {
1027-
SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims();
1028-
OpFoldResult gatherMaskDim = mixedDims[gatherDim];
1029-
if (auto gatherMaskDimVal = dyn_cast<Value>(gatherMaskDim)) {
1030-
1031-
auto cmp = rewriter.create<arith::CmpIOp>(
1032-
loc, arith::CmpIPredicate::slt, inductionVar, gatherMaskDimVal);
1033-
auto ifOp = rewriter.create<scf::IfOp>(
1034-
loc, cmp,
1035-
[&](OpBuilder &b, Location loc) { b.create<scf::YieldOp>(loc); },
1036-
[&](OpBuilder &b, Location loc) { b.create<scf::YieldOp>(loc); });
1037-
// Set insertion point to the then body of the ifOp.
1038-
rewriter.setInsertionPointToStart(ifOp.thenBlock());
1039-
}
1040-
}
1041-
10421025
auto gatherOffsetElt = rewriter.create<tensor::ExtractOp>(
10431026
loc, gatherOffset, ValueRange{inductionVar});
10441027

0 commit comments

Comments
 (0)