@@ -2247,123 +2247,100 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
22472247private:
22482248};
22492249
2250- struct AIRCanonicalizeChannelPutOpWrapAndStrideList
2251- : public OpRewritePattern<air::ChannelPutOp> {
2252- using OpRewritePattern<air::ChannelPutOp>::OpRewritePattern;
2253-
2254- AIRCanonicalizeChannelPutOpWrapAndStrideList (MLIRContext *ctx, int &maxSize,
2255- bool &enableRepeatAtHighestDim)
2256- : OpRewritePattern(ctx), maxSize(maxSize),
2250+ // / This pattern canonicalizes the offset/size/stride lists of `OpT` channel
2251+ // / put/get operations.
2252+ // /
2253+ // / **Main transformations**:
2254+ // / 1. Detect whether a "highest-dimension repeat" pattern is active (special
2255+ // / case where the highest dimension repeats and requires padding to
2256+ // / `maxNumDims`).
2257+ // / 2. Canonicalize the wrap-and-stride list by invoking
2258+ // / `canonicalizeWrapAndStrideList()`, which normalizes
2259+ // / offsets/sizes/strides.
2260+ // / 3. If highest-dimension repeat is active, extend the rank of
2261+ // / offsets/sizes/strides to `maxNumDims` by inserting zeros/ones.
2262+ // / 4. Recreate the `OpT` operation with the canonicalized parameters and
2263+ // / replace the original operation.
2264+ template <typename OpT>
2265+ struct AIRCanonicalizeChannelPutGetOpWrapAndStrideList
2266+ : public OpRewritePattern<OpT> {
2267+ using OpRewritePattern<OpT>::OpRewritePattern;
2268+
2269+ AIRCanonicalizeChannelPutGetOpWrapAndStrideList (
2270+ MLIRContext *ctx, int &maxSize, int &maxNumDims,
2271+ bool &enableRepeatAtHighestDim)
2272+ : OpRewritePattern<OpT>(ctx), maxSize(maxSize), maxNumDims(maxNumDims),
22572273 enableRepeatAtHighestDim (enableRepeatAtHighestDim) {}
22582274
2259- LogicalResult matchAndRewrite (air::ChannelPutOp op,
2275+ LogicalResult matchAndRewrite (OpT op,
22602276 PatternRewriter &rewriter) const override {
2261-
2277+ // Collect async token types and dependencies if the op is asynchronous.
22622278 SmallVector<Value, 1 > deps;
22632279 SmallVector<Type, 1 > tys;
22642280 if (isAsyncOp (op)) {
22652281 tys.push_back (air::AsyncTokenType::get (op->getContext ()));
22662282 deps = op.getAsyncDependencies ();
22672283 }
22682284
2285+ // Extract offsets, sizes, and strides from the op.
22692286 SmallVector<Value> offsets = op.getOffsets ();
22702287 SmallVector<Value> sizes = op.getSizes ();
22712288 SmallVector<Value> strides = op.getStrides ();
22722289
2273- // When highest-dimension repeat is active, (1) enableRepeatAtHighestDim
2274- // option is switched on, (2) wrap-and-stride list isn't empty (i.e. data
2275- // isn't 1-d streamed in), (3) highest stride is zero, and (4) highest wrap
2276- // is not one.
2290+ // Detect if highest-dimension repeat logic should be applied.
2291+ // This is true when:
2292+ // (1) The option enableRepeatAtHighestDim is set,
2293+ // (2) The stride list is not empty,
2294+ // (3) The highest (first) stride is 0, indicating repeat dimension,
2295+ // (4) The highest (first) size is not 1, indicating non-zero repetition.
22772296 bool highestDimRepeatActive = enableRepeatAtHighestDim &&
22782297 !strides.empty () &&
22792298 *getConstantIntValue (strides.front ()) == 0 &&
22802299 *getConstantIntValue (sizes.front ()) != 1 ;
2281-
2282- if (highestDimRepeatActive) {
2283- // If repeat is enabled at the highest dimension, then the highest
2284- // dimension must be preserved.
2300+ // If highest-dimension repeat is active but the op already has the maximum
2301+ // number of dimensions, no rewrite is needed.
2302+ if (highestDimRepeatActive && (int )offsets.size () == maxNumDims) {
22852303 return failure ();
2286- }
2287-
2288- if (failed (canonicalizeWrapAndStrideList (
2289- rewriter, offsets, sizes, strides,
2290- air::getTensorVolume (op.getMemref ().getType ()), maxSize)))
2291- return failure ();
2292-
2293- auto new_op = rewriter.create <air::ChannelPutOp>(
2294- op->getLoc (), tys, deps, op.getChanName (), op.getIndices (),
2295- op.getMemref (), offsets, sizes, strides);
2296- new_op->setAttrs (op->getDiscardableAttrDictionary ());
2297- for (unsigned i = 0 ; i < op->getResults ().size (); i++)
2298- op->getResults ()[i].replaceAllUsesWith (new_op->getResults ()[i]);
2299-
2300- rewriter.eraseOp (op);
2301-
2302- return success ();
2303- }
2304-
2305- private:
2306- int &maxSize;
2307- bool &enableRepeatAtHighestDim;
2308- };
2309-
2310- struct AIRCanonicalizeChannelGetOpWrapAndStrideList
2311- : public OpRewritePattern<air::ChannelGetOp> {
2312- using OpRewritePattern<air::ChannelGetOp>::OpRewritePattern;
2313-
2314- AIRCanonicalizeChannelGetOpWrapAndStrideList (MLIRContext *ctx, int &maxSize,
2315- bool &enableRepeatAtHighestDim)
2316- : OpRewritePattern(ctx), maxSize(maxSize),
2317- enableRepeatAtHighestDim (enableRepeatAtHighestDim) {}
2318-
2319- LogicalResult matchAndRewrite (air::ChannelGetOp op,
2320- PatternRewriter &rewriter) const override {
2321-
2322- SmallVector<Value, 1 > deps;
2323- SmallVector<Type, 1 > tys;
2324- if (isAsyncOp (op)) {
2325- tys.push_back (air::AsyncTokenType::get (op->getContext ()));
2326- deps = op.getAsyncDependencies ();
2327- }
2328-
2329- SmallVector<Value> offsets = op.getOffsets ();
2330- SmallVector<Value> sizes = op.getSizes ();
2331- SmallVector<Value> strides = op.getStrides ();
2332-
2333- // When highest-dimension repeat is active, (1) enableRepeatAtHighestDim
2334- // option is switched on, (2) wrap-and-stride list isn't empty (i.e. data
2335- // isn't 1-d streamed in), (3) highest stride is zero, and (4) highest wrap
2336- // is not one.
2337- bool highestDimRepeatActive = enableRepeatAtHighestDim &&
2338- !strides.empty () &&
2339- *getConstantIntValue (strides.front ()) == 0 &&
2340- *getConstantIntValue (sizes.front ()) != 1 ;
2304+ } else {
2305+ // Canonicalize offsets/sizes/strides using a helper function.
2306+ if (failed (canonicalizeWrapAndStrideList (
2307+ rewriter, offsets, sizes, strides,
2308+ air::getTensorVolume (op.getMemref ().getType ()), maxSize)))
2309+ return failure ();
23412310
2342- if (highestDimRepeatActive) {
2343- // If repeat is enabled at the highest dimension, then the highest
2344- // dimension must be preserved.
2345- return failure ();
2311+ // When highest-dimension repeat is active, pad offsets/sizes/strides to
2312+ // match maxNumDims by inserting:
2313+ // - offset = 0
2314+ // - size = 1
2315+ // - stride = 0
2316+ if (highestDimRepeatActive) {
2317+ auto zeroIdx = rewriter.create <arith::ConstantIndexOp>(op->getLoc (), 0 );
2318+ auto oneIdx = rewriter.create <arith::ConstantIndexOp>(op->getLoc (), 1 );
2319+ while ((int )offsets.size () < maxNumDims) {
2320+ offsets.insert (offsets.begin () + 1 , zeroIdx);
2321+ }
2322+ while ((int )sizes.size () < maxNumDims) {
2323+ sizes.insert (sizes.begin () + 1 , oneIdx);
2324+ }
2325+ while ((int )strides.size () < maxNumDims) {
2326+ strides.insert (strides.begin () + 1 , zeroIdx);
2327+ }
2328+ }
23462329 }
23472330
2348- if (failed (canonicalizeWrapAndStrideList (
2349- rewriter, offsets, sizes, strides,
2350- air::getTensorVolume (op.getMemref ().getType ()), maxSize)))
2351- return failure ();
2352-
2353- auto new_op = rewriter.create <air::ChannelGetOp>(
2354- op->getLoc (), tys, deps, op.getChanName (), op.getIndices (),
2355- op.getMemref (), offsets, sizes, strides);
2356- new_op->setAttrs (op->getDiscardableAttrDictionary ());
2357- for (unsigned i = 0 ; i < op->getResults ().size (); i++)
2358- op->getResults ()[i].replaceAllUsesWith (new_op->getResults ()[i]);
2359-
2360- rewriter.eraseOp (op);
2331+ // Create a new op with the canonicalized attributes and operands.
2332+ auto attrs = op->getDiscardableAttrDictionary ();
2333+ auto new_op = rewriter.replaceOpWithNewOp <OpT>(
2334+ op, tys, deps, op.getChanName (), op.getIndices (), op.getMemref (),
2335+ offsets, sizes, strides);
2336+ new_op->setAttrs (attrs);
23612337
23622338 return success ();
23632339 }
23642340
23652341private:
23662342 int &maxSize;
2343+ int &maxNumDims;
23672344 bool &enableRepeatAtHighestDim;
23682345};
23692346
@@ -3171,7 +3148,7 @@ LogicalResult AIRSpecializeChannelWrapAndStrideImpl(
31713148 // Canonicalize wrap and stride list to remove redundant dimensions
31723149 RewritePatternSet preproc_wns_patterns (ctx);
31733150 populateAIRCanonicalizeChannelWrapAndStridePatterns (
3174- preproc_wns_patterns, maxSize, enableRepeatAtHighestDim);
3151+ preproc_wns_patterns, maxSize, maxNumDims, enableRepeatAtHighestDim);
31753152 (void )applyPatternsGreedily (*region, std::move (preproc_wns_patterns));
31763153
31773154 RewritePatternSet patterns (ctx);
@@ -3196,10 +3173,9 @@ LogicalResult AIRSpecializeChannelWrapAndStrideImpl(
31963173
31973174 // Canonicalize wrap and stride list to remove redundant dimensions
31983175 RewritePatternSet cano_patterns (ctx);
3199- populateAIRCanonicalizeChannelWrapAndStridePatterns (cano_patterns, maxSize,
3200- enableRepeatAtHighestDim);
3176+ populateAIRCanonicalizeChannelWrapAndStridePatterns (
3177+ cano_patterns, maxSize, maxNumDims, enableRepeatAtHighestDim);
32013178 ExecuteOp::getCanonicalizationPatterns (cano_patterns, ctx);
3202- // WaitAllOp::getCanonicalizationPatterns(cano_patterns, ctx);
32033179 (void )applyPatternsGreedily (*region, std::move (cano_patterns));
32043180
32053181 return success ();
@@ -6733,11 +6709,13 @@ void populateAIRLoopIndexCanonicalizationPatterns(RewritePatternSet &patterns) {
67336709}
67346710
67356711void populateAIRCanonicalizeChannelWrapAndStridePatterns (
6736- RewritePatternSet &patterns, int &maxSize, bool &enableRepeatAtHighestDim) {
6712+ RewritePatternSet &patterns, int &maxSize, int &maxNumDims,
6713+ bool &enableRepeatAtHighestDim) {
67376714 MLIRContext *ctx = patterns.getContext ();
6738- patterns.insert <AIRCanonicalizeChannelPutOpWrapAndStrideList,
6739- AIRCanonicalizeChannelGetOpWrapAndStrideList>(
6740- ctx, maxSize, enableRepeatAtHighestDim);
6715+ patterns.insert <
6716+ AIRCanonicalizeChannelPutGetOpWrapAndStrideList<air::ChannelPutOp>,
6717+ AIRCanonicalizeChannelPutGetOpWrapAndStrideList<air::ChannelGetOp>>(
6718+ ctx, maxSize, maxNumDims, enableRepeatAtHighestDim);
67416719}
67426720
67436721void applyAIRSpecializeChannelWrapAndStridePattern (
0 commit comments