Skip to content

Commit 2673851

Browse files
authored
AIROverrideMemrefMemorySpace: operating on ModuleOp (#1099)
1 parent 0e2d12f commit 2673851

File tree

5 files changed

+17
-11
lines changed

5 files changed

+17
-11
lines changed

mlir/include/air/Transform/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,7 +1533,7 @@ def DmaToChannel : Pass<"air-dma-to-channel", "ModuleOp"> {
15331533
}];
15341534
}
15351535

1536-
def AIROverrideMemRefMemorySpace : Pass<"air-override-memref-memory-space", "func::FuncOp"> {
1536+
def AIROverrideMemRefMemorySpace : Pass<"air-override-memref-memory-space", "ModuleOp"> {
15371537
let summary = "Force all memrefs allocated within code region to have the specified memory space.";
15381538
let constructor = "xilinx::air::createAIROverrideMemRefMemorySpacePass()";
15391539
let description = [{
@@ -1544,7 +1544,7 @@ def AIROverrideMemRefMemorySpace : Pass<"air-override-memref-memory-space", "fun
15441544
"Memory space to override to.">,
15451545
Option<"clScope", "scope", "std::string",
15461546
/*default=*/"\"launch\"",
1547-
"AIR hierarchy scope to perform the transform under. Must be one of [herd, segment, launch].">
1547+
"AIR hierarchy scope to perform the transform under. Must be one of [herd, segment, launch, func].">
15481548
];
15491549
}
15501550

mlir/lib/Transform/AIRMiscPasses.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,8 +1234,9 @@ FailureOr<Value> tileChannelOpByFactor(
12341234
Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
12351235
// Create and apply affine map onto the split channel ops.
12361236
SmallVector<Value> tokens;
1237-
int memorySpace = dyn_cast<MemRefType>(originalChanOp.getMemref().getType())
1238-
.getMemorySpaceAsInt();
1237+
int memorySpace =
1238+
dyn_cast<BaseMemRefType>(originalChanOp.getMemref().getType())
1239+
.getMemorySpaceAsInt();
12391240
for (int i = 0; i < factor; i++) {
12401241
// Get affine map and split size from splitInfo.
12411242
auto &[splitInfoDimOnOffsets, splitInfoAffineMap, splitInfoSplitOffset,
@@ -2354,9 +2355,11 @@ struct OverrideMemorySpacePattern : public OpRewritePattern<memref::AllocOp> {
23542355
parent = alloc->getParentOfType<air::SegmentOp>();
23552356
else if (clScope == "launch")
23562357
parent = alloc->getParentOfType<air::LaunchOp>();
2358+
else if (clScope == "func")
2359+
parent = alloc->getParentOfType<func::FuncOp>();
23572360
else
23582361
return alloc->emitOpError(
2359-
"Invalid clScope value: expected one of herd/segment/launch");
2362+
"Invalid clScope value: expected one of herd/segment/launch/func");
23602363

23612364
if (!parent)
23622365
return failure();
@@ -2440,12 +2443,12 @@ class AIROverrideMemRefMemorySpacePass
24402443
};
24412444

24422445
void AIROverrideMemRefMemorySpacePass::runOnOperation() {
2443-
func::FuncOp funcOp = getOperation();
2446+
auto moduleOp = getOperation();
24442447
MLIRContext *context = &getContext();
24452448

24462449
RewritePatternSet patterns(context);
24472450
patterns.add<OverrideMemorySpacePattern>(context, clScope, clMemorySpace);
2448-
(void)applyPatternsGreedily(funcOp, std::move(patterns));
2451+
(void)applyPatternsGreedily(moduleOp, std::move(patterns));
24492452
RewritePatternSet fixResTypePatterns(context);
24502453
if (clScope == "herd") {
24512454
fixResTypePatterns.add<correctViewLikeOpIOMemorySpacesInScope<air::HerdOp>>(
@@ -2456,8 +2459,11 @@ void AIROverrideMemRefMemorySpacePass::runOnOperation() {
24562459
} else if (clScope == "launch") {
24572460
fixResTypePatterns
24582461
.add<correctViewLikeOpIOMemorySpacesInScope<air::LaunchOp>>(context);
2462+
} else if (clScope == "func") {
2463+
fixResTypePatterns
2464+
.add<correctViewLikeOpIOMemorySpacesInScope<func::FuncOp>>(context);
24592465
}
2460-
(void)applyPatternsGreedily(funcOp, std::move(fixResTypePatterns));
2466+
(void)applyPatternsGreedily(moduleOp, std::move(fixResTypePatterns));
24612467
}
24622468

24632469
} // namespace xilinx

test/xrt/31_triton_blk_ptr_eltwise_mul/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
"buffer-results-to-out-params",
6565
"air-par-to-herd{depth=-1}",
6666
"air-insert-launch-around-herd{insert-segment=false}",
67-
"func.func(air-override-memref-memory-space{scope=herd memory-space=2})",
67+
"air-override-memref-memory-space{scope=herd memory-space=2}",
6868
"air-copy-to-dma",
6969
"canonicalize",
7070
"cse",

test/xrt/32_triton_matmul/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
"one-shot-bufferize",
6363
f"func.func(air-wrap-func-with-parallel{{loop-bounds={launch_size[0]},{launch_size[1]},{launch_size[2]}}})",
6464
"air-par-to-launch{depth=0 has-air-segment=true}",
65-
"func.func(air-override-memref-memory-space{scope=launch memory-space=1})",
65+
"air-override-memref-memory-space{scope=launch memory-space=1}",
6666
]
6767
)
6868
+ ")"

test/xrt/33_triton_matmul_ver2/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
"cse",
6868
"air-par-to-herd{depth=-1}",
6969
"air-insert-launch-around-herd{insert-segment=false}",
70-
"func.func(air-override-memref-memory-space{scope=herd memory-space=2})",
70+
"air-override-memref-memory-space{scope=herd memory-space=2}",
7171
"air-copy-to-dma",
7272
"canonicalize",
7373
"cse",

0 commit comments

Comments
 (0)