@@ -1234,8 +1234,9 @@ FailureOr<Value> tileChannelOpByFactor(
12341234 Value zeroIdx = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
12351235 // Create and apply affine map onto the split channel ops.
12361236 SmallVector<Value> tokens;
1237- int memorySpace = dyn_cast<MemRefType>(originalChanOp.getMemref ().getType ())
1238- .getMemorySpaceAsInt ();
1237+ int memorySpace =
1238+ dyn_cast<BaseMemRefType>(originalChanOp.getMemref ().getType ())
1239+ .getMemorySpaceAsInt ();
12391240 for (int i = 0 ; i < factor; i++) {
12401241 // Get affine map and split size from splitInfo.
12411242 auto &[splitInfoDimOnOffsets, splitInfoAffineMap, splitInfoSplitOffset,
@@ -2354,9 +2355,11 @@ struct OverrideMemorySpacePattern : public OpRewritePattern<memref::AllocOp> {
23542355 parent = alloc->getParentOfType <air::SegmentOp>();
23552356 else if (clScope == " launch" )
23562357 parent = alloc->getParentOfType <air::LaunchOp>();
2358+ else if (clScope == " func" )
2359+ parent = alloc->getParentOfType <func::FuncOp>();
23572360 else
23582361 return alloc->emitOpError (
2359- " Invalid clScope value: expected one of herd/segment/launch" );
2362+ " Invalid clScope value: expected one of herd/segment/launch/func " );
23602363
23612364 if (!parent)
23622365 return failure ();
@@ -2440,12 +2443,12 @@ class AIROverrideMemRefMemorySpacePass
24402443};
24412444
24422445void AIROverrideMemRefMemorySpacePass::runOnOperation () {
2443- func::FuncOp funcOp = getOperation ();
2446+ auto moduleOp = getOperation ();
24442447 MLIRContext *context = &getContext ();
24452448
24462449 RewritePatternSet patterns (context);
24472450 patterns.add <OverrideMemorySpacePattern>(context, clScope, clMemorySpace);
2448- (void )applyPatternsGreedily (funcOp , std::move (patterns));
2451+ (void )applyPatternsGreedily (moduleOp , std::move (patterns));
24492452 RewritePatternSet fixResTypePatterns (context);
24502453 if (clScope == " herd" ) {
24512454 fixResTypePatterns.add <correctViewLikeOpIOMemorySpacesInScope<air::HerdOp>>(
@@ -2456,8 +2459,11 @@ void AIROverrideMemRefMemorySpacePass::runOnOperation() {
24562459 } else if (clScope == " launch" ) {
24572460 fixResTypePatterns
24582461 .add <correctViewLikeOpIOMemorySpacesInScope<air::LaunchOp>>(context);
2462+ } else if (clScope == " func" ) {
2463+ fixResTypePatterns
2464+ .add <correctViewLikeOpIOMemorySpacesInScope<func::FuncOp>>(context);
24592465 }
2460- (void )applyPatternsGreedily (funcOp , std::move (fixResTypePatterns));
2466+ (void )applyPatternsGreedily (moduleOp , std::move (fixResTypePatterns));
24612467}
24622468
24632469} // namespace xilinx
0 commit comments