Skip to content

Commit afacecd

Browse files
authored
Merge branch 'main' into fixup_eltwise_add_with_l2_missing_deallocs
2 parents 6124231 + b62016b commit afacecd

File tree

7 files changed

+186
-97
lines changed

7 files changed

+186
-97
lines changed

mlir/include/air/Transform/AIRMiscPasses.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ std::unique_ptr<mlir::Pass> createAIRUnrollOuterPerfectlyNestedLoopsPass();
3535
std::unique_ptr<mlir::Pass> createAIRUnrollOuterPerfectlyNestedLoopsPass(
3636
AIRUnrollOuterPerfectlyNestedLoopsPassOptions options);
3737
std::unique_ptr<mlir::Pass> createAIRSplitL2MemrefForBufferConstraintPass();
38-
std::unique_ptr<Pass> createAIRForceL1MemrefInHerdPass();
38+
std::unique_ptr<Pass> createAIROverrideMemRefMemorySpacePass();
39+
std::unique_ptr<mlir::Pass> createAIROverrideMemRefMemorySpacePass(
40+
AIROverrideMemRefMemorySpaceOptions options);
3941

4042
} // namespace air
4143
} // namespace xilinx

mlir/include/air/Transform/PassDetail.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ namespace air {
7272
#define GEN_PASS_DEF_AIRSHRINKMEMREFSIZESBYACCESS
7373
#define GEN_PASS_DEF_AIRSPLITL2MEMREFFORBUFFERCONSTRAINTPASS
7474
#define GEN_PASS_DEF_DMATOCHANNEL
75-
#define GEN_PASS_DEF_AIRFORCEL1MEMREFINHERDPASS
75+
#define GEN_PASS_DEF_AIROVERRIDEMEMREFMEMORYSPACE
7676
#include "air/Transform/Passes.h.inc"
7777

7878
} // namespace air

mlir/include/air/Transform/Passes.td

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,12 +1503,19 @@ def DmaToChannel : Pass<"air-dma-to-channel", "ModuleOp"> {
15031503
}];
15041504
}
15051505

1506-
def AIRForceL1MemrefInHerdPass : Pass<"air-force-l1-memref-in-herd", "func::FuncOp"> {
1507-
let summary = "Force all memrefs allocated within air.herd to have memory space L1.";
1508-
let constructor = "xilinx::air::createAIRForceL1MemrefInHerdPass()";
1506+
def AIROverrideMemRefMemorySpace : Pass<"air-override-memref-memory-space", "func::FuncOp"> {
1507+
let summary = "Force all memrefs allocated within code region to have the specified memory space.";
1508+
let constructor = "xilinx::air::createAIROverrideMemRefMemorySpacePass()";
15091509
let description = [{
1510-
Experimental pass. Force all memrefs allocated within air.herd to have memory space L1.
1510+
Experimental pass. Force all memrefs allocated within a specified code region to have the specified memory space.
15111511
}];
1512+
let options = [
1513+
Option<"clMemorySpace", "memory-space", "unsigned", /*default=*/"0",
1514+
"Memory space to override to.">,
1515+
Option<"clScope", "scope", "std::string",
1516+
/*default=*/"\"launch\"",
1517+
"AIR hierarchy scope to perform the transform under. Must be one of [herd, segment, launch].">
1518+
];
15121519
}
15131520

15141521
#endif // AIR_CONVERSION_PASSES

mlir/lib/Transform/AIRMiscPasses.cpp

Lines changed: 93 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,81 +1956,128 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
19561956
air::renumberMemcpyIfOps(&func.getBody());
19571957
}
19581958

1959-
// Experimental: This pattern forces all memrefs allocated within the air.herd
1960-
// to be L1.
1961-
struct ForceL1MemrefInHerdPattern : public OpRewritePattern<memref::AllocOp> {
1962-
using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
1959+
// Experimental pattern to override the memory space of `memref.alloc`
1960+
// operations when they appear inside a specified parent scope (e.g. herd,
1961+
// segment).
1962+
struct OverrideMemorySpacePattern : public OpRewritePattern<memref::AllocOp> {
1963+
OverrideMemorySpacePattern(MLIRContext *ctx, StringRef scope, int memSpace)
1964+
: OpRewritePattern<memref::AllocOp>(ctx), clScope(scope),
1965+
clMemorySpace(memSpace) {}
19631966

19641967
LogicalResult matchAndRewrite(memref::AllocOp alloc,
19651968
PatternRewriter &rewriter) const override {
1966-
auto parentHerdOp = alloc->getParentOfType<air::HerdOp>();
1967-
if (!parentHerdOp)
1969+
Operation *parent = nullptr;
1970+
1971+
if (clScope == "herd")
1972+
parent = alloc->getParentOfType<air::HerdOp>();
1973+
else if (clScope == "segment")
1974+
parent = alloc->getParentOfType<air::SegmentOp>();
1975+
else if (clScope == "launch")
1976+
parent = alloc->getParentOfType<air::LaunchOp>();
1977+
else
1978+
return alloc->emitOpError(
1979+
"Invalid clScope value: expected one of herd/segment/launch");
1980+
1981+
if (!parent)
19681982
return failure();
19691983

1970-
auto memref = dyn_cast<MemRefType>(alloc.getMemref().getType());
1971-
if (!memref)
1984+
auto memrefTy = dyn_cast<MemRefType>(alloc.getMemref().getType());
1985+
if (!memrefTy)
19721986
return failure();
1973-
if (memref.getMemorySpaceAsInt() == (int)air::MemorySpace::L1)
1987+
if ((int)memrefTy.getMemorySpaceAsInt() == clMemorySpace)
19741988
return failure();
19751989

19761990
auto newMemrefType =
1977-
MemRefType::get(memref.getShape(), memref.getElementType(),
1978-
memref.getLayout().getAffineMap(),
1979-
rewriter.getI32IntegerAttr((int)air::MemorySpace::L1));
1991+
MemRefType::get(memrefTy.getShape(), memrefTy.getElementType(),
1992+
memrefTy.getLayout().getAffineMap(),
1993+
rewriter.getI32IntegerAttr(clMemorySpace));
19801994

19811995
rewriter.replaceOpWithNewOp<memref::AllocOp>(alloc, newMemrefType);
19821996

19831997
return success();
19841998
}
1999+
2000+
private:
2001+
StringRef clScope; // Parent operation type to match
2002+
int clMemorySpace; // Target memory space value to assign
19852003
};
1986-
struct correctMemrefSubviewIOMemorySpaces
1987-
: public OpRewritePattern<memref::SubViewOp> {
1988-
using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
19892004

1990-
LogicalResult matchAndRewrite(memref::SubViewOp subview,
1991-
PatternRewriter &rewriter) const override {
1992-
auto srcTy = dyn_cast<MemRefType>(subview.getViewSource().getType());
1993-
auto destTy = dyn_cast<MemRefType>(subview.getResult().getType());
1994-
1995-
auto subviewOutputType =
1996-
llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
1997-
srcTy, subview.getMixedOffsets(), subview.getMixedSizes(),
1998-
subview.getMixedStrides()));
1999-
if (destTy == subviewOutputType)
2000-
return failure();
2005+
// Pattern to correct memory spaces of view-like operations within a given
2006+
// scope, following the application of OverrideMemorySpacePattern.
2007+
template <typename OpTy>
2008+
struct correctViewLikeOpIOMemorySpacesInScope : public OpRewritePattern<OpTy> {
2009+
using OpRewritePattern<OpTy>::OpRewritePattern;
20012010

2002-
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
2003-
subview, subviewOutputType, subview.getViewSource(),
2004-
subview.getMixedOffsets(), subview.getMixedSizes(),
2005-
subview.getMixedStrides());
2011+
LogicalResult matchAndRewrite(OpTy IFAOp,
2012+
PatternRewriter &rewriter) const override {
20062013

2014+
if (!IFAOp->template hasTrait<OpTrait::IsIsolatedFromAbove>())
2015+
return failure();
2016+
llvm::DenseMap<ViewLikeOpInterface, SmallVector<OpResult>> viewLikeOpsToRes;
2017+
IFAOp->template walk([&](ViewLikeOpInterface viewLike) {
2018+
auto srcTy = dyn_cast<MemRefType>(viewLike.getViewSource().getType());
2019+
if (!srcTy)
2020+
return;
2021+
for (auto res : viewLike->getResults()) {
2022+
auto destTy = dyn_cast<MemRefType>(res.getType());
2023+
if (!destTy)
2024+
return;
2025+
if (srcTy.getMemorySpaceAsInt() == destTy.getMemorySpaceAsInt())
2026+
continue;
2027+
viewLikeOpsToRes[viewLike].push_back(res);
2028+
}
2029+
});
2030+
for (auto [viewLike, results] : viewLikeOpsToRes) {
2031+
for (OpResult res : results) {
2032+
auto srcTy = dyn_cast<MemRefType>(viewLike.getViewSource().getType());
2033+
auto destTy = dyn_cast<MemRefType>(res.getType());
2034+
MemRefType::Builder builder(destTy);
2035+
builder.setMemorySpace(srcTy.getMemorySpace());
2036+
rewriter.modifyOpInPlace(viewLike, [&]() { res.setType(builder); });
2037+
}
2038+
}
20072039
return success();
20082040
}
20092041
};
20102042

2011-
// An experimental pass forcing all memrefs allocated within an air.herd to have
2012-
// memory space L1.
2013-
class AIRForceL1MemrefInHerdPass
2014-
: public air::impl::AIRForceL1MemrefInHerdPassBase<
2015-
AIRForceL1MemrefInHerdPass> {
2043+
// An experimental pass forcing all memrefs allocated within a specified air
2044+
// code region to have the specified memory space.
2045+
class AIROverrideMemRefMemorySpacePass
2046+
: public air::impl::AIROverrideMemRefMemorySpaceBase<
2047+
AIROverrideMemRefMemorySpacePass> {
20162048

20172049
public:
2018-
AIRForceL1MemrefInHerdPass() = default;
2019-
AIRForceL1MemrefInHerdPass(const AIRForceL1MemrefInHerdPass &pass){};
2050+
AIROverrideMemRefMemorySpacePass() = default;
2051+
AIROverrideMemRefMemorySpacePass(
2052+
const AIROverrideMemRefMemorySpacePass &pass){};
2053+
AIROverrideMemRefMemorySpacePass(
2054+
const ::xilinx::air::AIROverrideMemRefMemorySpaceOptions &options)
2055+
: AIROverrideMemRefMemorySpaceBase(options) {}
20202056

20212057
void runOnOperation() override;
20222058

20232059
private:
20242060
};
20252061

2026-
void AIRForceL1MemrefInHerdPass::runOnOperation() {
2062+
void AIROverrideMemRefMemorySpacePass::runOnOperation() {
20272063
func::FuncOp funcOp = getOperation();
20282064
MLIRContext *context = &getContext();
20292065

20302066
RewritePatternSet patterns(context);
2031-
patterns.add<ForceL1MemrefInHerdPattern, correctMemrefSubviewIOMemorySpaces>(
2032-
context);
2067+
patterns.add<OverrideMemorySpacePattern>(context, clScope, clMemorySpace);
20332068
(void)applyPatternsGreedily(funcOp, std::move(patterns));
2069+
RewritePatternSet fixResTypePatterns(context);
2070+
if (clScope == "herd") {
2071+
fixResTypePatterns.add<correctViewLikeOpIOMemorySpacesInScope<air::HerdOp>>(
2072+
context);
2073+
} else if (clScope == "segment") {
2074+
fixResTypePatterns
2075+
.add<correctViewLikeOpIOMemorySpacesInScope<air::SegmentOp>>(context);
2076+
} else if (clScope == "launch") {
2077+
fixResTypePatterns
2078+
.add<correctViewLikeOpIOMemorySpacesInScope<air::LaunchOp>>(context);
2079+
}
2080+
(void)applyPatternsGreedily(funcOp, std::move(fixResTypePatterns));
20342081
}
20352082

20362083
} // anonymous namespace
@@ -2092,8 +2139,12 @@ std::unique_ptr<Pass> createAIRSplitL2MemrefForBufferConstraintPass() {
20922139
return std::make_unique<AIRSplitL2MemrefForBufferConstraintPass>();
20932140
}
20942141

2095-
std::unique_ptr<Pass> createAIRForceL1MemrefInHerdPass() {
2096-
return std::make_unique<AIRForceL1MemrefInHerdPass>();
2142+
std::unique_ptr<Pass> createAIROverrideMemRefMemorySpacePass() {
2143+
return std::make_unique<AIROverrideMemRefMemorySpacePass>();
2144+
}
2145+
std::unique_ptr<Pass> createAIROverrideMemRefMemorySpacePass(
2146+
AIROverrideMemRefMemorySpaceOptions options) {
2147+
return std::make_unique<AIROverrideMemRefMemorySpacePass>(options);
20972148
}
20982149

20992150
} // namespace air

mlir/test/Transform/AIRMiscPasses/air_force_l1_memref_in_herd.mlir

Lines changed: 0 additions & 48 deletions
This file was deleted.
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===- air_override_memref_memory_space.mlir -------------------*- MLIR -*-===//
2+
//
3+
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
// RUN: air-opt %s -air-override-memref-memory-space="scope=herd memory-space=2" | FileCheck %s
9+
// RUN: air-opt %s -air-override-memref-memory-space="scope=launch memory-space=2" | FileCheck %s --check-prefix=LAUNCH
10+
// RUN: air-opt %s -air-override-memref-memory-space="scope=launch memory-space=1" | FileCheck %s --check-prefix=MS1
11+
12+
module {
13+
14+
// CHECK-LABEL: func.func @func0
15+
// CHECK: memref.alloc() : memref<32x64xf32, 2 : i32>
16+
// LAUNCH-LABEL: func.func @func0
17+
// LAUNCH: memref.alloc() : memref<32x64xf32, 2 : i32>
18+
// MS1-LABEL: func.func @func0
19+
// MS1: memref.alloc() : memref<32x64xf32, 1 : i32>
20+
21+
func.func @func0(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
22+
%c2 = arith.constant 2 : index
23+
%c1 = arith.constant 1 : index
24+
%c1_0 = arith.constant 1 : index
25+
air.launch (%arg9, %arg10) in (%arg11=%c1, %arg12=%c1_0) args(%arg13=%arg0, %arg14=%arg1, %arg15=%arg2) : memref<*xf32>, memref<*xf32>, memref<*xf32> {
26+
air.segment @bare_matmul_0 args(%arg16=%arg9, %arg17=%arg10, %arg18=%arg11, %arg19=%arg12, %arg20=%arg13, %arg21=%arg14, %arg22=%arg15) : index, index, index, index, memref<*xf32>, memref<*xf32>, memref<*xf32> {
27+
%c2_1 = arith.constant 2 : index
28+
%c2_2 = arith.constant 2 : index
29+
air.herd @herd_0 tile (%arg23, %arg24) in (%arg25=%c2_1, %arg26=%c2_2) args(%arg27=%arg16, %arg28=%arg17, %arg29=%arg18, %arg30=%arg19, %arg31=%arg20, %arg32=%arg21, %arg33=%arg22) : index, index, index, index, memref<*xf32>, memref<*xf32>, memref<*xf32> {
30+
%c32 = arith.constant 32 : index
31+
%c2048 = arith.constant 2048 : index
32+
%c64 = arith.constant 64 : index
33+
%1 = arith.muli %arg23, %c2048 : index
34+
%reinterpret_cast = memref.reinterpret_cast %arg31 to offset: [%1], sizes: [32, 64], strides: [%c64, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
35+
%alloc = memref.alloc() : memref<32x64xf32>
36+
memref.copy %reinterpret_cast, %alloc : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<32x64xf32>
37+
}
38+
}
39+
}
40+
return
41+
}
42+
43+
// LAUNCH-LABEL: func.func @func1
44+
// LAUNCH: memref.alloc() : memref<8x4x4x8xf32, 2 : i32>
45+
// LAUNCH: memref.collapse_shape {{.*}} : memref<8x4x4x8xf32, 2 : i32> into memref<32x32xf32, 2 : i32>
46+
// LAUNCH: memref.alloc() : memref<4x8x8x4xf32, 2 : i32>
47+
// LAUNCH: memref.collapse_shape {{.*}} : memref<4x8x8x4xf32, 2 : i32> into memref<32x32xf32, 2 : i32>
48+
// LAUNCH: memref.alloc() : memref<32x32xf32, 2 : i32>
49+
// LAUNCH: memref.expand_shape {{.*}} : memref<32x32xf32, 2 : i32> into memref<8x4x8x4xf32, 2 : i32>
50+
// LAUNCH: memref.alloc() : memref<8x8x4x4xf32, 2 : i32>
51+
// MS1-LABEL: func.func @func1
52+
// MS1: memref.alloc() : memref<8x4x4x8xf32, 1 : i32>
53+
// MS1: memref.collapse_shape {{.*}} : memref<8x4x4x8xf32, 1 : i32> into memref<32x32xf32, 1 : i32>
54+
// MS1: memref.alloc() : memref<4x8x8x4xf32, 1 : i32>
55+
// MS1: memref.collapse_shape {{.*}} : memref<4x8x8x4xf32, 1 : i32> into memref<32x32xf32, 1 : i32>
56+
// MS1: memref.alloc() : memref<32x32xf32, 1 : i32>
57+
// MS1: memref.expand_shape {{.*}} : memref<32x32xf32, 1 : i32> into memref<8x4x8x4xf32, 1 : i32>
58+
// MS1: memref.alloc() : memref<8x8x4x4xf32, 1 : i32>
59+
60+
func.func @func1(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
61+
%c1 = arith.constant 1 : index
62+
%c2 = arith.constant 2 : index
63+
air.launch (%arg9, %arg10, %arg11) in (%arg12=%c2, %arg13=%c2, %arg14=%c1) args(%arg15=%arg0, %arg16=%arg1, %arg17=%arg2) : memref<*xf32>, memref<*xf32>, memref<*xf32> {
64+
%cst = arith.constant 0.000000e+00 : f32
65+
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<8x4x4x8xf32>
66+
%collapse_shape = memref.collapse_shape %alloc_2 [[0, 1], [2, 3]] : memref<8x4x4x8xf32> into memref<32x32xf32>
67+
%alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<4x8x8x4xf32>
68+
%collapse_shape_4 = memref.collapse_shape %alloc_3 [[0, 1], [2, 3]] : memref<4x8x8x4xf32> into memref<32x32xf32>
69+
%alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
70+
linalg.matmul ins(%collapse_shape, %collapse_shape_4 : memref<32x32xf32>, memref<32x32xf32>) outs(%alloc_5 : memref<32x32xf32>)
71+
%expand_shape = memref.expand_shape %alloc_5 [[0, 1], [2, 3]] output_shape [8, 4, 8, 4] : memref<32x32xf32> into memref<8x4x8x4xf32>
72+
%alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<8x8x4x4xf32>
73+
linalg.transpose ins(%expand_shape : memref<8x4x8x4xf32>) outs(%alloc_6 : memref<8x8x4x4xf32>) permutation = [0, 2, 1, 3]
74+
}
75+
return
76+
}
77+
}

test/xrt/31_triton_blk_ptr_eltwise_mul/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
"buffer-results-to-out-params",
6565
"air-par-to-herd{depth=-1}",
6666
"air-insert-launch-around-herd{insert-segment=false}",
67-
"func.func(air-force-l1-memref-in-herd)",
67+
"func.func(air-override-memref-memory-space{scope=herd memory-space=2})",
6868
"air-copy-to-dma",
6969
"canonicalize",
7070
"cse",

0 commit comments

Comments
 (0)