@@ -843,6 +843,7 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
843
843
844
844
// Create loop to iterate every offset in gatherOffset.
845
845
auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
846
+ Value upperBound = rewriter.create <arith::ConstantIndexOp>(loc, offsetSize).getResult ();
846
847
if (op.hasMask ()) {
847
848
SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims ();
848
849
OpFoldResult gatherMaskDim = mixedDims[gatherDim];
@@ -854,9 +855,17 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
854
855
// If the gather mask dimension is a constant, we can use it directly.
855
856
unsigned gatherMaskDimValue = gatherMaskDimIndex.value ();
856
857
offsetSize = std::min (offsetSize, gatherMaskDimValue);
858
+ upperBound = rewriter.create <arith::ConstantIndexOp>(loc, offsetSize).getResult ();
859
+ } else {
860
+ // Use arith::MinSIOp to get the minimum value of gatherMaskDim
861
+ // and offsetSize.
862
+ auto gatherMaskDimVal = cast<Value>(gatherMaskDim);
863
+ auto offsetSizeVal =
864
+ rewriter.create <arith::ConstantIndexOp>(loc, offsetSize);
865
+ upperBound = rewriter.create <arith::MinSIOp>(loc, gatherMaskDimVal,
866
+ offsetSizeVal).getResult ();
857
867
}
858
868
}
859
- auto upperBound = rewriter.create <arith::ConstantIndexOp>(loc, offsetSize);
860
869
auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
861
870
auto loop = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step);
862
871
@@ -870,23 +879,6 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
870
879
871
880
// Load the offsetElt first.
872
881
Value inductionVar = loop.getInductionVar ();
873
-
874
- // When there's mask for gather dimension, we need to guard the load with
875
- // gather mask.
876
- if (op.hasMask ()) {
877
- SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims ();
878
- OpFoldResult gatherMaskDim = mixedDims[gatherDim];
879
- if (auto gatherMaskDimVal = dyn_cast<Value>(gatherMaskDim)) {
880
- auto cmp = rewriter.create <arith::CmpIOp>(
881
- loc, arith::CmpIPredicate::slt, inductionVar, gatherMaskDimVal);
882
- auto ifOp = rewriter.create <scf::IfOp>(
883
- loc, cmp,
884
- [&](OpBuilder &b, Location loc) { b.create <scf::YieldOp>(loc); },
885
- [&](OpBuilder &b, Location loc) { b.create <scf::YieldOp>(loc); });
886
- // Set insertion point to the then body of the ifOp.
887
- rewriter.setInsertionPointToStart (ifOp.thenBlock ());
888
- }
889
- }
890
882
auto gatherOffsetElt = rewriter.create <tensor::ExtractOp>(
891
883
loc, gatherOffset, ValueRange{inductionVar});
892
884
@@ -998,6 +990,7 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
998
990
999
991
// Create loop to iterate every offset in gatherOffset.
1000
992
auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
993
+ Value upperBound = rewriter.create <arith::ConstantIndexOp>(loc, offsetSize).getResult ();
1001
994
if (op.hasMask ()) {
1002
995
SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims ();
1003
996
OpFoldResult gatherMaskDim = mixedDims[gatherDim];
@@ -1009,9 +1002,17 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
1009
1002
// If the gather mask dimension is a constant, we can use it directly.
1010
1003
unsigned gatherMaskDimValue = gatherMaskDimIndex.value ();
1011
1004
offsetSize = std::min (offsetSize, gatherMaskDimValue);
1005
+ upperBound = rewriter.create <arith::ConstantIndexOp>(loc, offsetSize).getResult ();
1006
+ } else {
1007
+ // Use arith::MinSIOp to get the minimum value of gatherMaskDim
1008
+ // and offsetSize.
1009
+ auto gatherMaskDimVal = cast<Value>(gatherMaskDim);
1010
+ auto offsetSizeVal =
1011
+ rewriter.create <arith::ConstantIndexOp>(loc, offsetSize);
1012
+ upperBound = rewriter.create <arith::MinSIOp>(loc, gatherMaskDimVal,
1013
+ offsetSizeVal).getResult ();
1012
1014
}
1013
1015
}
1014
- auto upperBound = rewriter.create <arith::ConstantIndexOp>(loc, offsetSize);
1015
1016
auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
1016
1017
auto loop = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step);
1017
1018
@@ -1021,24 +1022,6 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
1021
1022
// Load the offsetElt first.
1022
1023
Value inductionVar = loop.getInductionVar ();
1023
1024
1024
- // When there's mask for gather dimension, we need to guard the load with
1025
- // gather mask.
1026
- if (op.hasMask ()) {
1027
- SmallVector<OpFoldResult> mixedDims = op.getMixedMaskDims ();
1028
- OpFoldResult gatherMaskDim = mixedDims[gatherDim];
1029
- if (auto gatherMaskDimVal = dyn_cast<Value>(gatherMaskDim)) {
1030
-
1031
- auto cmp = rewriter.create <arith::CmpIOp>(
1032
- loc, arith::CmpIPredicate::slt, inductionVar, gatherMaskDimVal);
1033
- auto ifOp = rewriter.create <scf::IfOp>(
1034
- loc, cmp,
1035
- [&](OpBuilder &b, Location loc) { b.create <scf::YieldOp>(loc); },
1036
- [&](OpBuilder &b, Location loc) { b.create <scf::YieldOp>(loc); });
1037
- // Set insertion point to the then body of the ifOp.
1038
- rewriter.setInsertionPointToStart (ifOp.thenBlock ());
1039
- }
1040
- }
1041
-
1042
1025
auto gatherOffsetElt = rewriter.create <tensor::ExtractOp>(
1043
1026
loc, gatherOffset, ValueRange{inductionVar});
1044
1027
0 commit comments