Skip to content

Commit 53e3f44

Browse files
authored
AIRToAIE: Lowering cascade air.channel.put/get to aie.put/get_cascade ops and vector.transfer_read/write ops (Xilinx#1055)
1 parent 5cbc6ac commit 53e3f44

File tree

11 files changed

+743
-89
lines changed

11 files changed

+743
-89
lines changed

mlir/include/air/Conversion/AIRToAIESchedulingUtils.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ struct MemcpyBundleAsFlow {
9999
int S2MM_memspace_as_int;
100100
int numMM2SAllocs = 0;
101101
int numS2MMAllocs = 0;
102+
std::string
103+
memcpyResourceType; // The type of mechanism used for the memcpy op,
104+
// including dma_stream, dma_packet, and cascade.
102105
LogicalResult pushBackMemcpyOpToBundle(air::DmaMemcpyNdOp memcpyOp);
103106
LogicalResult pushBackMemcpyOpToBundle(air::ChannelGetOp memcpyOp);
104107
LogicalResult pushBackMemcpyOpToBundle(air::ChannelPutOp memcpyOp);
@@ -200,11 +203,35 @@ class MemTileDMAAllocator : public DMAAllocator {
200203
air::allocation_info_t alloc, bool isMM2S);
201204
};
202205

206+
class CascadeAllocator {
207+
208+
public:
209+
CascadeAllocator() = delete;
210+
// CascadeAllocator constructor: only core-to-core (L1-level) cascade
211+
// connection supported.
212+
CascadeAllocator(AIE::DeviceOp device)
213+
: device(device), DMAMemorySpaceAsInt((int)air::MemorySpace::L1) {}
214+
FailureOr<allocation_info_t> coreCascadeAlloc(air::MemcpyInterface &memcpyOp);
215+
FailureOr<allocation_info_t> allocNewCascade(air::MemcpyInterface &memcpyOp,
216+
AIE::TileOp tile);
217+
218+
FailureOr<AIE::BufferOp> getBuffer(uint64_t, int64_t col, int64_t row,
219+
air::MemcpyInterface &memcpyOp);
220+
221+
protected:
222+
AIE::DeviceOp device;
223+
int DMAMemorySpaceAsInt;
224+
225+
public:
226+
std::vector<allocation_info_t> cascade_put_allocs, cascade_get_allocs;
227+
};
228+
203229
LogicalResult
204230
simpleDMAChannelAllocation(std::vector<MemcpyBundleAsFlow> &memcpy_flows,
205231
ShimDMAAllocator &shim_dma_alloc,
206232
MemTileDMAAllocator &memtile_dma_alloc,
207-
TileDMAAllocator &tile_dma_alloc);
233+
TileDMAAllocator &tile_dma_alloc,
234+
air::CascadeAllocator &core_cascade_alloc);
208235
template <typename T>
209236
int foundInVector(T item, std::vector<T> vec);
210237
int getSCFForLoopDepth(Operation *o);

mlir/lib/Conversion/AIRToAIEPass.cpp

Lines changed: 190 additions & 62 deletions
Large diffs are not rendered by default.

mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp

Lines changed: 177 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,28 @@ air::getLockValuePair(const AIE::AIETargetModel &targetModel,
398398
unique_write_buffers.size());
399399
}
400400

401+
// Helper function that tries to retrieve the underlying AIE::BufferOp by
402+
// unwrapping common memref wrappers (cast or subview)
403+
AIE::BufferOp getUnderlyingBufferOp(Value buffer) {
404+
// Case 1: Directly defined by an AIE::BufferOp
405+
if (auto bufferOp = buffer.getDefiningOp<AIE::BufferOp>())
406+
return bufferOp;
407+
408+
// Case 2: Defined by a cast (e.g., memref.cast)
409+
if (auto castOp = buffer.getDefiningOp<CastOpInterface>())
410+
if (auto innerBuffer = castOp->getOperand(0).getDefiningOp<AIE::BufferOp>())
411+
return innerBuffer;
412+
413+
// Case 3: Defined by a view-like op (e.g., memref.subview)
414+
if (auto viewLikeOp = buffer.getDefiningOp<ViewLikeOpInterface>())
415+
if (auto innerBuffer =
416+
viewLikeOp->getOperand(0).getDefiningOp<AIE::BufferOp>())
417+
return innerBuffer;
418+
419+
// No underlying BufferOp found
420+
return nullptr;
421+
}
422+
401423
// allocation_info_t impl.
402424

403425
bool xilinx::air::allocation_info_t::valid() { return dma_tile != nullptr; }
@@ -745,12 +767,7 @@ air::TileDMAAllocator::getBuffer(uint64_t, int64_t col, int64_t row,
745767
Value buffer = isTileInbound(memcpyOp, DMAMemorySpaceAsInt).value()
746768
? (memcpyOp.getDstMemref())
747769
: (memcpyOp.getSrcMemref());
748-
AIE::BufferOp bufferOp = buffer.getDefiningOp<AIE::BufferOp>();
749-
// Memref cast
750-
memref::CastOp castOp = buffer.getDefiningOp<memref::CastOp>();
751-
if (!bufferOp && castOp)
752-
bufferOp = castOp.getOperand().getDefiningOp<AIE::BufferOp>();
753-
return bufferOp;
770+
return getUnderlyingBufferOp(buffer);
754771
}
755772

756773
// ShimDMAAllocator impl.
@@ -1040,12 +1057,102 @@ air::MemTileDMAAllocator::getBuffer(uint64_t, int64_t col, int64_t row,
10401057
Value buffer = isTileInbound(memcpyOp, DMAMemorySpaceAsInt).value()
10411058
? (memcpyOp.getDstMemref())
10421059
: (memcpyOp.getSrcMemref());
1043-
AIE::BufferOp bufferOp = buffer.getDefiningOp<AIE::BufferOp>();
1044-
// Memref cast
1045-
memref::CastOp castOp = buffer.getDefiningOp<memref::CastOp>();
1046-
if (!bufferOp && castOp)
1047-
bufferOp = castOp.getOperand().getDefiningOp<AIE::BufferOp>();
1048-
return bufferOp;
1060+
return getUnderlyingBufferOp(buffer);
1061+
}
1062+
1063+
// CascadeAllocator impl.
1064+
1065+
// Attempts to allocate (or reuse) a cascade flow for the given memcpyOp.
1066+
FailureOr<air::allocation_info_t>
1067+
air::CascadeAllocator::coreCascadeAlloc(air::MemcpyInterface &memcpyOp) {
1068+
// Determine if the operation is a cascade put (outbound)
1069+
auto isCascadePut = isTileOutbound(memcpyOp, DMAMemorySpaceAsInt);
1070+
if (failed(isCascadePut))
1071+
return failure();
1072+
1073+
// Select allocation list based on direction
1074+
auto allocs =
1075+
isCascadePut.value() ? &cascade_put_allocs : &cascade_get_allocs;
1076+
1077+
// Retrieve the buffer and the tile where this memcpyOp operates
1078+
const int dummy{0};
1079+
auto buffer = getBuffer(dummy, -1, -1, memcpyOp);
1080+
if (failed(buffer)) {
1081+
return memcpyOp->emitOpError("failed to get buffer.");
1082+
}
1083+
auto tile = buffer.value().getTileOp();
1084+
if (!tile) {
1085+
return buffer.value()->emitOpError("failed to get AIE tile.");
1086+
}
1087+
1088+
// Search for an existing allocation for this tile and memcpyOp
1089+
for (auto &t : *allocs) {
1090+
if (t.foundAlloc(tile.getCol(), tile.getRow(), memcpyOp))
1091+
return t;
1092+
}
1093+
1094+
// No existing allocation found, create a new one
1095+
return air::CascadeAllocator::allocNewCascade(memcpyOp, tile);
1096+
}
1097+
1098+
// Creates a new cascade allocation entry when no matching allocation exists.
1099+
FailureOr<air::allocation_info_t>
1100+
air::CascadeAllocator::allocNewCascade(air::MemcpyInterface &memcpyOp,
1101+
AIE::TileOp tile) {
1102+
if (!tile) {
1103+
return memcpyOp.emitOpError("failed to get the AIE tile. This indicates a "
1104+
"potential error in the compilation flow.");
1105+
}
1106+
1107+
// Determine if this is a cascade put or get
1108+
auto isCascadePut = isTileOutbound(memcpyOp, DMAMemorySpaceAsInt);
1109+
if (failed(isCascadePut))
1110+
return failure();
1111+
auto allocs =
1112+
isCascadePut.value() ? &cascade_put_allocs : &cascade_get_allocs;
1113+
1114+
// Check if allocation already exists for this tile
1115+
for (auto &t : *allocs) {
1116+
if (t.foundAlloc(tile.getCol(), tile.getRow())) {
1117+
t.memcpyOps.push_back(memcpyOp.getOperation());
1118+
return t;
1119+
}
1120+
// Also check for an allocation tied to the channel declaration
1121+
if (t.foundAlloc(
1122+
tile.getCol(), tile.getRow(),
1123+
getChannelDeclarationThroughSymbol(
1124+
dyn_cast<air::ChannelInterface>(memcpyOp.getOperation())))) {
1125+
t.memcpyOps.push_back(memcpyOp.getOperation());
1126+
return t;
1127+
}
1128+
}
1129+
1130+
// Create a new allocation_info_t entry for this tile
1131+
air::allocation_info_t output = {tile,
1132+
/*col*/ -1,
1133+
/*row*/ -1,
1134+
/*aie_chan*/ AIE::DMAChannel(),
1135+
/*chan*/ -1,
1136+
/*dma_id*/ std::vector<int>{},
1137+
{memcpyOp.getOperation()}};
1138+
allocs->push_back(output);
1139+
return output;
1140+
}
1141+
1142+
// Retrieves the underlying AIE::BufferOp associated with the given memcpyOp.
1143+
FailureOr<AIE::BufferOp>
1144+
air::CascadeAllocator::getBuffer(uint64_t, int64_t col, int64_t row,
1145+
air::MemcpyInterface &memcpyOp) {
1146+
if (failed(isTileInbound(memcpyOp, DMAMemorySpaceAsInt)))
1147+
return failure();
1148+
1149+
// Select source or destination buffer depending on inbound/outbound
1150+
Value buffer = isTileInbound(memcpyOp, DMAMemorySpaceAsInt).value()
1151+
? (memcpyOp.getDstMemref())
1152+
: (memcpyOp.getSrcMemref());
1153+
1154+
// Resolve the actual underlying buffer op
1155+
return getUnderlyingBufferOp(buffer);
10491156
}
10501157

10511158
// MemcpyBundleAsFlow impl.
@@ -1061,6 +1168,7 @@ air::MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::DmaMemcpyNdOp memcpyOp) {
10611168
MM2S_memspace_as_int =
10621169
llvm::cast<BaseMemRefType>(memcpyOp.getSrcMemref().getType())
10631170
.getMemorySpaceAsInt();
1171+
memcpyResourceType = "dma_stream";
10641172
return success();
10651173
}
10661174

@@ -1095,6 +1203,7 @@ air::MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelGetOp memcpyOp) {
10951203
S2MM_memspace_as_int =
10961204
llvm::cast<BaseMemRefType>(memcpyOp.getMemref().getType())
10971205
.getMemorySpaceAsInt();
1206+
memcpyResourceType = chan.getChannelType().str();
10981207
return success();
10991208
}
11001209

@@ -1106,6 +1215,7 @@ air::MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelPutOp memcpyOp) {
11061215
MM2S_memspace_as_int =
11071216
llvm::cast<BaseMemRefType>(memcpyOp.getMemref().getType())
11081217
.getMemorySpaceAsInt();
1218+
memcpyResourceType = chan.getChannelType().str();
11091219
return success();
11101220
}
11111221

@@ -1128,6 +1238,7 @@ air::MemcpyBundleAsFlow::MemcpyBundleAsFlow(air::DmaMemcpyNdOp dmaMemcpyOp) {
11281238
std::vector<Operation *>());
11291239
S2MM = v1;
11301240
S2MM_alloc = std::vector<air::allocation_info_t>(numS2MMAllocs);
1241+
memcpyResourceType = "dma_stream";
11311242
}
11321243

11331244
air::MemcpyBundleAsFlow::MemcpyBundleAsFlow(air::ChannelOp chan) {
@@ -1146,6 +1257,7 @@ air::MemcpyBundleAsFlow::MemcpyBundleAsFlow(air::ChannelOp chan) {
11461257
std::vector<Operation *>());
11471258
S2MM = v1;
11481259
S2MM_alloc = std::vector<air::allocation_info_t>(numS2MMAllocs);
1260+
memcpyResourceType = chan.getChannelType().str();
11491261
}
11501262

11511263
} // namespace xilinx
@@ -1158,7 +1270,8 @@ LogicalResult air::simpleDMAChannelAllocation(
11581270
std::vector<air::MemcpyBundleAsFlow> &memcpy_flows,
11591271
air::ShimDMAAllocator &shim_dma_alloc,
11601272
air::MemTileDMAAllocator &memtile_dma_alloc,
1161-
TileDMAAllocator &tile_dma_alloc) {
1273+
TileDMAAllocator &tile_dma_alloc,
1274+
air::CascadeAllocator &core_cascade_alloc) {
11621275
for (auto &f : memcpy_flows) {
11631276
if (f.MM2S_memspace_as_int == (int)air::MemorySpace::L1) {
11641277
for (auto o : f.MM2S) {
@@ -1172,10 +1285,18 @@ LogicalResult air::simpleDMAChannelAllocation(
11721285
int x = tile.getCol();
11731286
int y = tile.getRow();
11741287

1175-
auto alloc_res = tile_dma_alloc.simpleDmaChannelAlloc(
1176-
memcpyOpIf, x, y, f.MM2S_alloc.dma_channel.channel);
1177-
if (failed(alloc_res))
1178-
return failure();
1288+
FailureOr<air::allocation_info_t> alloc_res;
1289+
if (f.memcpyResourceType == "dma_stream" ||
1290+
f.memcpyResourceType == "dma_packet") {
1291+
alloc_res = tile_dma_alloc.simpleDmaChannelAlloc(
1292+
memcpyOpIf, x, y, f.MM2S_alloc.dma_channel.channel);
1293+
if (failed(alloc_res))
1294+
return failure();
1295+
} else if (f.memcpyResourceType == "cascade") {
1296+
alloc_res = core_cascade_alloc.coreCascadeAlloc(memcpyOpIf);
1297+
if (failed(alloc_res))
1298+
return failure();
1299+
}
11791300

11801301
f.MM2S_alloc = alloc_res.value();
11811302
if (!f.MM2S_alloc.valid())
@@ -1195,10 +1316,19 @@ LogicalResult air::simpleDMAChannelAllocation(
11951316
int x = tile.getCol();
11961317
int y = tile.getRow();
11971318

1198-
auto alloc_res = tile_dma_alloc.simpleDmaChannelAlloc(
1199-
memcpyOpIf, x, y, f.S2MM_alloc[i].dma_channel.channel);
1200-
if (failed(alloc_res))
1201-
return failure();
1319+
FailureOr<air::allocation_info_t> alloc_res;
1320+
if (f.memcpyResourceType == "dma_stream" ||
1321+
f.memcpyResourceType == "dma_packet") {
1322+
alloc_res = tile_dma_alloc.simpleDmaChannelAlloc(
1323+
memcpyOpIf, x, y, f.S2MM_alloc[i].dma_channel.channel);
1324+
if (failed(alloc_res))
1325+
return failure();
1326+
} else if (f.memcpyResourceType == "cascade") {
1327+
alloc_res = core_cascade_alloc.coreCascadeAlloc(memcpyOpIf);
1328+
if (failed(alloc_res))
1329+
return failure();
1330+
}
1331+
12021332
f.S2MM_alloc[i] = alloc_res.value();
12031333
if (!f.S2MM_alloc[i].valid())
12041334
return failure();
@@ -1210,6 +1340,12 @@ LogicalResult air::simpleDMAChannelAllocation(
12101340
if (f.MM2S_memspace_as_int == (int)air::MemorySpace::L2) {
12111341
for (auto o : f.MM2S) {
12121342
auto memcpyOpIf = cast<air::MemcpyInterface>(o);
1343+
// Report error if the data movement lowers to neither dma stream
1344+
// (aie.flow) nor dma packet flow (aie.packet_flow).
1345+
if (f.memcpyResourceType != "dma_stream" &&
1346+
f.memcpyResourceType != "dma_packet")
1347+
return memcpyOpIf->emitOpError("only supports dma_stream or "
1348+
"dma_packet connections at L2 memory");
12131349
auto alloc_res = memtile_dma_alloc.simpleDmaChannelAlloc(memcpyOpIf);
12141350
if (failed(alloc_res) || !alloc_res->valid())
12151351
return failure();
@@ -1220,6 +1356,13 @@ LogicalResult air::simpleDMAChannelAllocation(
12201356
for (size_t i = 0; i < f.S2MM.size(); i++) {
12211357
for (auto o : f.S2MM[i]) {
12221358
auto memcpyOpIf = cast<air::MemcpyInterface>(o);
1359+
// Report error if the data movement lowers to neither dma stream
1360+
// (aie.flow) nor dma packet flow (aie.packet_flow).
1361+
if (f.memcpyResourceType != "dma_stream" &&
1362+
f.memcpyResourceType != "dma_packet")
1363+
return memcpyOpIf->emitOpError(
1364+
"only supports dma_stream or dma_packet connections at L2 "
1365+
"memory");
12231366
auto alloc_res = memtile_dma_alloc.simpleDmaChannelAlloc(memcpyOpIf);
12241367
if (failed(alloc_res) || !alloc_res->valid())
12251368
return failure();
@@ -1233,6 +1376,13 @@ LogicalResult air::simpleDMAChannelAllocation(
12331376
for (size_t i = 0; i < f.S2MM.size(); i++) {
12341377
for (auto o : f.MM2S) {
12351378
auto memcpyOpIf = cast<air::MemcpyInterface>(o);
1379+
// Report error if the data movement lowers to neither dma stream
1380+
// (aie.flow) nor dma packet flow (aie.packet_flow).
1381+
if (f.memcpyResourceType != "dma_stream" &&
1382+
f.memcpyResourceType != "dma_packet")
1383+
return memcpyOpIf->emitOpError(
1384+
"only supports dma_stream or dma_packet connections at L3 "
1385+
"memory");
12361386
auto alloc_res = shim_dma_alloc.allocNewDmaChannel(
12371387
memcpyOpIf, f.S2MM_alloc[i].getDmaTile().getCol(),
12381388
f.S2MM_alloc[i].getDmaTile().getRow(), f.S2MM[i]);
@@ -1251,6 +1401,12 @@ LogicalResult air::simpleDMAChannelAllocation(
12511401
}
12521402
for (auto o : f.S2MM.front()) {
12531403
auto memcpyOpIf = cast<air::MemcpyInterface>(o);
1404+
// Report error if the data movement lowers to neither dma stream
1405+
// (aie.flow) nor dma packet flow (aie.packet_flow).
1406+
if (f.memcpyResourceType != "dma_stream" &&
1407+
f.memcpyResourceType != "dma_packet")
1408+
return memcpyOpIf->emitOpError("only supports dma_stream or "
1409+
"dma_packet connections at L3 memory");
12541410
auto alloc_res = shim_dma_alloc.allocNewDmaChannel(
12551411
memcpyOpIf, f.MM2S_alloc.getDmaTile().getCol(),
12561412
f.MM2S_alloc.getDmaTile().getRow(), f.MM2S);

mlir/lib/Dialect/AIR/IR/AIRDialect.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,8 +1704,42 @@ template <typename OpT>
17041704
static LogicalResult ComposeMemrefOpOnChannelOp(OpT op,
17051705
PatternRewriter &rewriter) {
17061706

1707+
// Lambda version of `getChannelDeclarationThroughSymbol` method defined in
1708+
// `Util/Utils.cpp`. It is duplicated here because `Util/Utils.cpp` depends on
1709+
// this file, so direct inclusion is not possible.
1710+
auto getChannelDeclarationThroughSymbol = [](air::ChannelInterface op) {
1711+
if (!op)
1712+
// Return an empty ChannelOp if the input operation is invalid.
1713+
return air::ChannelOp();
1714+
1715+
// Traverse up through the operation's parents until a symbol table is
1716+
// found.
1717+
Operation *parent = op;
1718+
while ((parent = parent->getParentOp())) {
1719+
if (parent->hasTrait<OpTrait::SymbolTable>()) {
1720+
auto st = mlir::SymbolTable::lookupSymbolIn(parent, op.getChanName());
1721+
if (auto chanOp = dyn_cast_if_present<air::ChannelOp>(st))
1722+
return chanOp;
1723+
}
1724+
}
1725+
1726+
// No matching channel declaration found in any enclosing symbol tables.
1727+
return air::ChannelOp();
1728+
};
1729+
1730+
// Extract the memref operand from the operation.
17071731
Value memref = op.getMemref();
17081732
if (!memref)
1733+
// If there is no associated memref, signal a failure.
1734+
return failure();
1735+
// Resolve the channel declaration for the given channel interface operation.
1736+
air::ChannelOp chan = getChannelDeclarationThroughSymbol(op);
1737+
if (!chan)
1738+
// If the channel declaration cannot be resolved, signal a failure.
1739+
return failure();
1740+
// If the channel is of type "cascade", multi-dimensional affine map access
1741+
// pattern is not supported, so skip it.
1742+
if (chan.getChannelType() == "cascade")
17091743
return failure();
17101744

17111745
// Init. memref type and offsets from memref's defining op's input type

0 commit comments

Comments
 (0)