Skip to content

Commit 0ff42e8

Browse files
authored
AIRChannelCanonicalization: fixup a bug where canonicalization is skipped by mistake (Xilinx#1053)
* Fixup a bug where, when num_dims > max_num_dims and highest-dimension repeat is enabled, wrap-and-stride canonicalization is skipped * Fixup test for AIE1 target: AIE1 doesn't support md dma at shim
1 parent e56f28e commit 0ff42e8

File tree

6 files changed

+107
-101
lines changed

6 files changed

+107
-101
lines changed

mlir/include/air/Transform/AIRDependencyScheduleOpt.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ void populateAIRLoopIndexCanonicalizationPatterns(RewritePatternSet &patterns);
7373
// Populate patterns for canonicalizing offsets, sizes and strides in air
7474
// channel_interface operations.
7575
void populateAIRCanonicalizeChannelWrapAndStridePatterns(
76-
RewritePatternSet &patterns, int &maxSize, bool &enableRepeatAtHighestDim);
76+
RewritePatternSet &patterns, int &maxSize, int &maxNumDims,
77+
bool &enableRepeatAtHighestDim);
7778

7879
// Apply AIRSpecializeChannelWrapAndStridePattern on region.
7980
void applyAIRSpecializeChannelWrapAndStridePattern(

mlir/include/air/Util/Util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ int findLargestFactor(int num, int max);
166166

167167
// Canonicalize wrap and stride lists, by removing redundant dimensions.
168168
LogicalResult canonicalizeWrapAndStrideList(
169-
OpBuilder builder, SmallVector<Value> &offsets, SmallVector<Value> &sizes,
169+
OpBuilder &builder, SmallVector<Value> &offsets, SmallVector<Value> &sizes,
170170
SmallVector<Value> &strides, int memref_volume, int maxSize = -1);
171171

172172
// If wrap-and-stride lists are empty, populate them with default data access

mlir/lib/Conversion/AIRToAIEPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,10 +736,11 @@ void lowerAirExecute(AIE::DeviceOp d) {
736736
auto ctx = d->getContext();
737737
RewritePatternSet patterns(ctx);
738738
int maxSize = isa<AIE::AIE1TargetModel>(AIE::getTargetModel(d)) ? -1 : 1023;
739+
int maxNumDims = isa<AIE::AIE1TargetModel>(AIE::getTargetModel(d)) ? 1 : 4;
739740
patterns.insert<LowerAIRExecutePattern>(ctx);
740741
bool enableRepeatAtHighestDim = false;
741742
air::populateAIRCanonicalizeChannelWrapAndStridePatterns(
742-
patterns, maxSize, enableRepeatAtHighestDim);
743+
patterns, maxSize, maxNumDims, enableRepeatAtHighestDim);
743744
(void)applyPatternsGreedily(d, std::move(patterns));
744745
}
745746

mlir/lib/Transform/AIRDependencyScheduleOpt.cpp

Lines changed: 75 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -2247,123 +2247,100 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
22472247
private:
22482248
};
22492249

2250-
struct AIRCanonicalizeChannelPutOpWrapAndStrideList
2251-
: public OpRewritePattern<air::ChannelPutOp> {
2252-
using OpRewritePattern<air::ChannelPutOp>::OpRewritePattern;
2253-
2254-
AIRCanonicalizeChannelPutOpWrapAndStrideList(MLIRContext *ctx, int &maxSize,
2255-
bool &enableRepeatAtHighestDim)
2256-
: OpRewritePattern(ctx), maxSize(maxSize),
2250+
/// This pattern canonicalizes the offset/size/stride lists of `OpT` channel
2251+
/// put/get operations.
2252+
///
2253+
/// **Main transformations**:
2254+
/// 1. Detect whether a "highest-dimension repeat" pattern is active (special
2255+
/// case where the highest dimension repeats and requires padding to
2256+
/// `maxNumDims`).
2257+
/// 2. Canonicalize the wrap-and-stride list by invoking
2258+
/// `canonicalizeWrapAndStrideList()`, which normalizes
2259+
/// offsets/sizes/strides.
2260+
/// 3. If highest-dimension repeat is active, extend the rank of
2261+
/// offsets/sizes/strides to `maxNumDims` by inserting zeros/ones.
2262+
/// 4. Recreate the `OpT` operation with the canonicalized parameters and
2263+
/// replace the original operation.
2264+
template <typename OpT>
2265+
struct AIRCanonicalizeChannelPutGetOpWrapAndStrideList
2266+
: public OpRewritePattern<OpT> {
2267+
using OpRewritePattern<OpT>::OpRewritePattern;
2268+
2269+
AIRCanonicalizeChannelPutGetOpWrapAndStrideList(
2270+
MLIRContext *ctx, int &maxSize, int &maxNumDims,
2271+
bool &enableRepeatAtHighestDim)
2272+
: OpRewritePattern<OpT>(ctx), maxSize(maxSize), maxNumDims(maxNumDims),
22572273
enableRepeatAtHighestDim(enableRepeatAtHighestDim) {}
22582274

2259-
LogicalResult matchAndRewrite(air::ChannelPutOp op,
2275+
LogicalResult matchAndRewrite(OpT op,
22602276
PatternRewriter &rewriter) const override {
2261-
2277+
// Collect async token types and dependencies if the op is asynchronous.
22622278
SmallVector<Value, 1> deps;
22632279
SmallVector<Type, 1> tys;
22642280
if (isAsyncOp(op)) {
22652281
tys.push_back(air::AsyncTokenType::get(op->getContext()));
22662282
deps = op.getAsyncDependencies();
22672283
}
22682284

2285+
// Extract offsets, sizes, and strides from the op.
22692286
SmallVector<Value> offsets = op.getOffsets();
22702287
SmallVector<Value> sizes = op.getSizes();
22712288
SmallVector<Value> strides = op.getStrides();
22722289

2273-
// When highest-dimension repeat is active, (1) enableRepeatAtHighestDim
2274-
// option is switched on, (2) wrap-and-stride list isn't empty (i.e. data
2275-
// isn't 1-d streamed in), (3) highest stride is zero, and (4) highest wrap
2276-
// is not one.
2290+
// Detect if highest-dimension repeat logic should be applied.
2291+
// This is true when:
2292+
// (1) The option enableRepeatAtHighestDim is set,
2293+
// (2) The stride list is not empty,
2294+
// (3) The highest (first) stride is 0, indicating repeat dimension,
2295+
// (4) The highest (first) size is not 1, indicating non-zero repetition.
22772296
bool highestDimRepeatActive = enableRepeatAtHighestDim &&
22782297
!strides.empty() &&
22792298
*getConstantIntValue(strides.front()) == 0 &&
22802299
*getConstantIntValue(sizes.front()) != 1;
2281-
2282-
if (highestDimRepeatActive) {
2283-
// If repeat is enabled at the highest dimension, then the highest
2284-
// dimension must be preserved.
2300+
// If highest-dimension repeat is active but the op already has the maximum
2301+
// number of dimensions, no rewrite is needed.
2302+
if (highestDimRepeatActive && (int)offsets.size() == maxNumDims) {
22852303
return failure();
2286-
}
2287-
2288-
if (failed(canonicalizeWrapAndStrideList(
2289-
rewriter, offsets, sizes, strides,
2290-
air::getTensorVolume(op.getMemref().getType()), maxSize)))
2291-
return failure();
2292-
2293-
auto new_op = rewriter.create<air::ChannelPutOp>(
2294-
op->getLoc(), tys, deps, op.getChanName(), op.getIndices(),
2295-
op.getMemref(), offsets, sizes, strides);
2296-
new_op->setAttrs(op->getDiscardableAttrDictionary());
2297-
for (unsigned i = 0; i < op->getResults().size(); i++)
2298-
op->getResults()[i].replaceAllUsesWith(new_op->getResults()[i]);
2299-
2300-
rewriter.eraseOp(op);
2301-
2302-
return success();
2303-
}
2304-
2305-
private:
2306-
int &maxSize;
2307-
bool &enableRepeatAtHighestDim;
2308-
};
2309-
2310-
struct AIRCanonicalizeChannelGetOpWrapAndStrideList
2311-
: public OpRewritePattern<air::ChannelGetOp> {
2312-
using OpRewritePattern<air::ChannelGetOp>::OpRewritePattern;
2313-
2314-
AIRCanonicalizeChannelGetOpWrapAndStrideList(MLIRContext *ctx, int &maxSize,
2315-
bool &enableRepeatAtHighestDim)
2316-
: OpRewritePattern(ctx), maxSize(maxSize),
2317-
enableRepeatAtHighestDim(enableRepeatAtHighestDim) {}
2318-
2319-
LogicalResult matchAndRewrite(air::ChannelGetOp op,
2320-
PatternRewriter &rewriter) const override {
2321-
2322-
SmallVector<Value, 1> deps;
2323-
SmallVector<Type, 1> tys;
2324-
if (isAsyncOp(op)) {
2325-
tys.push_back(air::AsyncTokenType::get(op->getContext()));
2326-
deps = op.getAsyncDependencies();
2327-
}
2328-
2329-
SmallVector<Value> offsets = op.getOffsets();
2330-
SmallVector<Value> sizes = op.getSizes();
2331-
SmallVector<Value> strides = op.getStrides();
2332-
2333-
// When highest-dimension repeat is active, (1) enableRepeatAtHighestDim
2334-
// option is switched on, (2) wrap-and-stride list isn't empty (i.e. data
2335-
// isn't 1-d streamed in), (3) highest stride is zero, and (4) highest wrap
2336-
// is not one.
2337-
bool highestDimRepeatActive = enableRepeatAtHighestDim &&
2338-
!strides.empty() &&
2339-
*getConstantIntValue(strides.front()) == 0 &&
2340-
*getConstantIntValue(sizes.front()) != 1;
2304+
} else {
2305+
// Canonicalize offsets/sizes/strides using a helper function.
2306+
if (failed(canonicalizeWrapAndStrideList(
2307+
rewriter, offsets, sizes, strides,
2308+
air::getTensorVolume(op.getMemref().getType()), maxSize)))
2309+
return failure();
23412310

2342-
if (highestDimRepeatActive) {
2343-
// If repeat is enabled at the highest dimension, then the highest
2344-
// dimension must be preserved.
2345-
return failure();
2311+
// When highest-dimension repeat is active, pad offsets/sizes/strides to
2312+
// match maxNumDims by inserting:
2313+
// - offset = 0
2314+
// - size = 1
2315+
// - stride = 0
2316+
if (highestDimRepeatActive) {
2317+
auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
2318+
auto oneIdx = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 1);
2319+
while ((int)offsets.size() < maxNumDims) {
2320+
offsets.insert(offsets.begin() + 1, zeroIdx);
2321+
}
2322+
while ((int)sizes.size() < maxNumDims) {
2323+
sizes.insert(sizes.begin() + 1, oneIdx);
2324+
}
2325+
while ((int)strides.size() < maxNumDims) {
2326+
strides.insert(strides.begin() + 1, zeroIdx);
2327+
}
2328+
}
23462329
}
23472330

2348-
if (failed(canonicalizeWrapAndStrideList(
2349-
rewriter, offsets, sizes, strides,
2350-
air::getTensorVolume(op.getMemref().getType()), maxSize)))
2351-
return failure();
2352-
2353-
auto new_op = rewriter.create<air::ChannelGetOp>(
2354-
op->getLoc(), tys, deps, op.getChanName(), op.getIndices(),
2355-
op.getMemref(), offsets, sizes, strides);
2356-
new_op->setAttrs(op->getDiscardableAttrDictionary());
2357-
for (unsigned i = 0; i < op->getResults().size(); i++)
2358-
op->getResults()[i].replaceAllUsesWith(new_op->getResults()[i]);
2359-
2360-
rewriter.eraseOp(op);
2331+
// Create a new op with the canonicalized attributes and operands.
2332+
auto attrs = op->getDiscardableAttrDictionary();
2333+
auto new_op = rewriter.replaceOpWithNewOp<OpT>(
2334+
op, tys, deps, op.getChanName(), op.getIndices(), op.getMemref(),
2335+
offsets, sizes, strides);
2336+
new_op->setAttrs(attrs);
23612337

23622338
return success();
23632339
}
23642340

23652341
private:
23662342
int &maxSize;
2343+
int &maxNumDims;
23672344
bool &enableRepeatAtHighestDim;
23682345
};
23692346

@@ -3171,7 +3148,7 @@ LogicalResult AIRSpecializeChannelWrapAndStrideImpl(
31713148
// Canonicalize wrap and stride list to remove redundant dimensions
31723149
RewritePatternSet preproc_wns_patterns(ctx);
31733150
populateAIRCanonicalizeChannelWrapAndStridePatterns(
3174-
preproc_wns_patterns, maxSize, enableRepeatAtHighestDim);
3151+
preproc_wns_patterns, maxSize, maxNumDims, enableRepeatAtHighestDim);
31753152
(void)applyPatternsGreedily(*region, std::move(preproc_wns_patterns));
31763153

31773154
RewritePatternSet patterns(ctx);
@@ -3196,10 +3173,9 @@ LogicalResult AIRSpecializeChannelWrapAndStrideImpl(
31963173

31973174
// Canonicalize wrap and stride list to remove redundant dimensions
31983175
RewritePatternSet cano_patterns(ctx);
3199-
populateAIRCanonicalizeChannelWrapAndStridePatterns(cano_patterns, maxSize,
3200-
enableRepeatAtHighestDim);
3176+
populateAIRCanonicalizeChannelWrapAndStridePatterns(
3177+
cano_patterns, maxSize, maxNumDims, enableRepeatAtHighestDim);
32013178
ExecuteOp::getCanonicalizationPatterns(cano_patterns, ctx);
3202-
// WaitAllOp::getCanonicalizationPatterns(cano_patterns, ctx);
32033179
(void)applyPatternsGreedily(*region, std::move(cano_patterns));
32043180

32053181
return success();
@@ -6733,11 +6709,13 @@ void populateAIRLoopIndexCanonicalizationPatterns(RewritePatternSet &patterns) {
67336709
}
67346710

67356711
void populateAIRCanonicalizeChannelWrapAndStridePatterns(
6736-
RewritePatternSet &patterns, int &maxSize, bool &enableRepeatAtHighestDim) {
6712+
RewritePatternSet &patterns, int &maxSize, int &maxNumDims,
6713+
bool &enableRepeatAtHighestDim) {
67376714
MLIRContext *ctx = patterns.getContext();
6738-
patterns.insert<AIRCanonicalizeChannelPutOpWrapAndStrideList,
6739-
AIRCanonicalizeChannelGetOpWrapAndStrideList>(
6740-
ctx, maxSize, enableRepeatAtHighestDim);
6715+
patterns.insert<
6716+
AIRCanonicalizeChannelPutGetOpWrapAndStrideList<air::ChannelPutOp>,
6717+
AIRCanonicalizeChannelPutGetOpWrapAndStrideList<air::ChannelGetOp>>(
6718+
ctx, maxSize, maxNumDims, enableRepeatAtHighestDim);
67416719
}
67426720

67436721
void applyAIRSpecializeChannelWrapAndStridePattern(

mlir/lib/Util/Util.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1086,11 +1086,12 @@ int air::findLargestFactor(int num, int max) {
10861086

10871087
// Canonicalize wrap and stride lists by removing redundant dimensions.
10881088
LogicalResult air::canonicalizeWrapAndStrideList(
1089-
OpBuilder builder, SmallVector<Value> &offsets, SmallVector<Value> &sizes,
1089+
OpBuilder &builder, SmallVector<Value> &offsets, SmallVector<Value> &sizes,
10901090
SmallVector<Value> &strides, int memref_volume, int maxSize) {
10911091
// AIE2 hardware constraints. TODO: import these info from target model.
10921092
const int AIE2_STRIDE_UPPER_BOUND = 1048576;
10931093
bool listsHaveChanged = false;
1094+
OpBuilder::InsertionGuard guard(builder);
10941095
// Match offsets size with sizes and strides
10951096
auto max_dim_size =
10961097
std::max(std::max(offsets.size(), sizes.size()), strides.size());

mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,4 +855,29 @@ module {
855855
}
856856
return
857857
}
858+
859+
// Canonicalizing repeat dimension at highest dimension.
860+
861+
// CHECK-LABEL: func15
862+
// CHECK: air.channel.put async{{.*}}@channel_0[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}, %c320{{.*}}] [%c2{{.*}}, %c1{{.*}}, %c512{{.*}}, %c64{{.*}}] [%c0{{.*}}, %c0{{.*}}, %c512{{.*}}, %c1{{.*}}])
863+
// NPUTILED-LABEL: func15
864+
// NPUTILED: air.channel.put async{{.*}}@channel_0[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}, %c320{{.*}}] [%c2{{.*}}, %c1{{.*}}, %c512{{.*}}, %c64{{.*}}] [%c0{{.*}}, %c0{{.*}}, %c512{{.*}}, %c1{{.*}}])
865+
// AIE1-LABEL: func15
866+
// AIE1: air.channel.put async{{.*}}@channel_0[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c320{{.*}}] [%c2{{.*}}, %c512{{.*}}, %c64{{.*}}] [%c0{{.*}}, %c512{{.*}}, %c1{{.*}}])
867+
868+
func.func @func15(%arg0: memref<512x512xbf16>) {
869+
%0 = air.launch async () in () args(%arg8=%arg0) : memref<512x512xbf16> {
870+
%c65536 = arith.constant 65536 : index
871+
%c4 = arith.constant 4 : index
872+
%c256 = arith.constant 256 : index
873+
%c64 = arith.constant 64 : index
874+
%c128 = arith.constant 128 : index
875+
%c512 = arith.constant 512 : index
876+
%c0 = arith.constant 0 : index
877+
%c1 = arith.constant 1 : index
878+
%c2 = arith.constant 2 : index
879+
%1 = air.channel.put async @channel_0[%c0, %c0] (%arg8[%c0, %c0, %c1, %c0, %c256] [%c2, %c4, %c1, %c128, %c64] [%c0, %c65536, %c64, %c512, %c1]) {id = 6 : i32, metadataArray = [{base = "air_channel_13_0", index = 0 : i32}, {base = "air_channel_13_1", index = 1 : i32}, {base = "air_channel_13_2", index = 2 : i32}, {base = "air_channel_13_3", index = 3 : i32}]} : (memref<512x512xbf16>)
880+
}
881+
return
882+
}
858883
}

0 commit comments

Comments
 (0)