88
99#include < numeric>
1010
11+ #include " iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
1112#include " llvm/ADT/DenseMap.h"
1213#include " llvm/Support/Debug.h"
1314#include " mlir/Dialect/Affine/IR/AffineOps.h"
2122
2223namespace mlir ::iree_compiler::AMDAIE {
2324
25+ // / Utility to create a new logical objectfifo based on shape defined by
26+ // / `newSizesOpFoldResultArr`.
27+ static AMDAIE::LogicalObjectFifoFromMemrefOp createNewLogicalObjectFifo (
28+ IRRewriter &rewriter,
29+ AMDAIE::LogicalObjectFifoFromMemrefOp &oldLogicalObjectFifo,
30+ SmallVectorImpl<OpFoldResult> &newSizesOpFoldResultArr) {
31+ OpBuilder::InsertionGuard guard (rewriter);
32+ SmallVector<int64_t > newSizes = llvm::map_to_vector (
33+ newSizesOpFoldResultArr,
34+ [](OpFoldResult sizeVal) { return getConstantIndexOrAssert (sizeVal); });
35+ Value oldAllocOp = oldLogicalObjectFifo.getMemref ();
36+ auto oldMemRefType = cast<MemRefType>(oldAllocOp.getType ());
37+ MemRefType newAllocType = MemRefType::get (
38+ newSizes, oldMemRefType.getElementType (), MemRefLayoutAttrInterface{},
39+ oldMemRefType.getMemorySpace ());
40+ assert (oldAllocOp.getDefiningOp () && " expected a defining op for the value" );
41+ rewriter.setInsertionPoint (oldAllocOp.getDefiningOp ());
42+ auto newAllocOp =
43+ rewriter.create <memref::AllocOp>(rewriter.getUnknownLoc (), newAllocType);
44+ auto newDeallocOp =
45+ rewriter.create <memref::DeallocOp>(rewriter.getUnknownLoc (), newAllocOp);
46+ newDeallocOp->moveBefore (&newAllocOp->getBlock ()->back ());
47+ auto type = cast<MemRefType>(newAllocOp.getType ());
48+ // Create new logical objectfifo.
49+ rewriter.setInsertionPoint (oldLogicalObjectFifo);
50+ auto newLogicalObjectFifo =
51+ rewriter.create <AMDAIE::LogicalObjectFifoFromMemrefOp>(
52+ rewriter.getUnknownLoc (), LogicalObjectFifoType::get (type),
53+ newAllocOp.getResult (), oldLogicalObjectFifo.getTiles ());
54+ return newLogicalObjectFifo;
55+ }
56+
57+ // / Utility to help fetch those input DmaCpyNd Ops which needs to be split.
58+ SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplitOrCombine (
59+ Operation *op) {
60+ SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps;
61+ // We are currently walking through CoreOps gathering 3rd Input DmaOp (if
62+ // applicable) from them.
63+ // TODO(avarma): We will generalize this later.
64+ op->walk ([&](AMDAIE::CoreOp coreOp) {
65+ SmallVector<Value> inputDmas = coreOp.getInputDmas ();
66+ if (inputDmas.size () != 3 ) return WalkResult::skip ();
67+ auto dmaCpyNdOp = inputDmas[2 ].getDefiningOp <AMDAIE::DmaCpyNdOp>();
68+ assert (dmaCpyNdOp && " expected an amdaie.dma_cpy_nd op" );
69+ l2ToL1DmaOps.push_back (dmaCpyNdOp);
70+ return WalkResult::advance ();
71+ });
72+ return l2ToL1DmaOps;
73+ }
74+
2475// / Utility to verify that the split dimensions for L2 are contiguous.
2576static LogicalResult checkIsRangeFromZero (
2677 SmallVector<size_t > &splitDimsSetForL2) {
@@ -124,6 +175,33 @@ static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
124175 return newL3AsSourceOffset;
125176}
126177
178+ // / Given a L2->L1 DmaCpyNd op, find the unique L3->L2 DmaCpyNd op.
179+ static FailureOr<AMDAIE::DmaCpyNdOp> fetchL3ToL2DmaCpyNdOp (
180+ AMDAIE::DmaCpyNdOp l2ToL1DmaOp) {
181+ LogicalObjectFifoFromMemrefOp sourceObjectFifo =
182+ l2ToL1DmaOp.getSourceObjectFifo ();
183+ SmallVector<AMDAIE::DmaCpyNdOp> l3ToL2DmaOps;
184+ AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
185+ for (Operation *objFifoUserOp : sourceObjectFifo->getUsers ()) {
186+ if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
187+ dmaOp.getTargetObjectFifo () == sourceObjectFifo) {
188+ l3ToL2DmaOps.push_back (dmaOp);
189+ }
190+ }
191+ if (l3ToL2DmaOps.size () == 0 ) {
192+ LLVM_DEBUG (llvm::dbgs () << " no corresponding L3->L2 dma op found for "
193+ << sourceObjectFifo << " \n " );
194+ return failure ();
195+ }
196+ if (l3ToL2DmaOps.size () > 1 ) {
197+ LLVM_DEBUG (llvm::dbgs () << " found more than one L3->L2 dma ops for "
198+ << sourceObjectFifo << " \n " );
199+ return failure ();
200+ }
201+ l3ToL2DmaOp = l3ToL2DmaOps[0 ];
202+ return l3ToL2DmaOp;
203+ }
204+
127205// / A struct utility to encapsulate all the data required to perform splitting
128206// / of logicalobjectfifos.
129207struct SplittingLogicalObjectFifoData {
@@ -186,25 +264,10 @@ static LogicalResult checkWhetherSplitIsPossible(
186264 }
187265
188266 // Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
189- SmallVector<AMDAIE::DmaCpyNdOp> l3ToL2DmaOps;
190- AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
191- for (Operation *objFifoUserOp : sourceObjectFifo->getUsers ()) {
192- if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
193- dmaOp.getTargetObjectFifo () == sourceObjectFifo) {
194- l3ToL2DmaOps.push_back (dmaOp);
195- }
196- }
197- if (l3ToL2DmaOps.size () == 0 ) {
198- LLVM_DEBUG (llvm::dbgs () << " no corresponding L3->L2 dma op found for "
199- << sourceObjectFifo << " \n " );
200- return failure ();
201- }
202- if (l3ToL2DmaOps.size () > 1 ) {
203- LLVM_DEBUG (llvm::dbgs () << " found more than one L3->L2 dma ops for "
204- << sourceObjectFifo << " \n " );
205- return failure ();
206- }
207- l3ToL2DmaOp = l3ToL2DmaOps[0 ];
267+ FailureOr<AMDAIE::DmaCpyNdOp> maybeL3ToL2DmaOp =
268+ fetchL3ToL2DmaCpyNdOp (l2ToL1DmaOps[0 ]);
269+ if (failed (maybeL3ToL2DmaOp)) return failure ();
270+ AMDAIE::DmaCpyNdOp l3ToL2DmaOp = maybeL3ToL2DmaOp.value ();
208271 if ((l3ToL2DmaOp.getTargetMixedOffsets ().size () !=
209272 l3ToL2DmaOp.getSourceMixedOffsets ().size ()) ||
210273 (l3ToL2DmaOp.getTargetMixedSizes ().size () !=
@@ -293,9 +356,6 @@ LogicalResult splitLogicalObjectFifos(
293356 l3ToL2DmaOp.getTargetMixedOffsets ();
294357 SmallVector<OpFoldResult, 4 > staticL2AsTargetSizes =
295358 l3ToL2DmaOp.getTargetMixedSizes ();
296- SmallVector<int64_t , 4 > l2ShapeAsTarget = llvm::to_vector (
297- cast<MemRefType>(l3ToL2DmaOp.getTargetObjectFifo ().getMemref ().getType ())
298- .getShape ());
299359 SmallVector<OpFoldResult, 4 > staticL3AsSourceOffsets =
300360 l3ToL2DmaOp.getSourceMixedOffsets ();
301361 SmallVector<OpFoldResult, 4 > staticL3AsSourceSizes =
@@ -310,7 +370,6 @@ LogicalResult splitLogicalObjectFifos(
310370 staticL2AsTargetSizes[dim] = oneVal;
311371 staticL3AsSourceOffsets[dim] = zeroVal;
312372 staticL3AsSourceSizes[dim] = oneVal;
313- l2ShapeAsTarget[dim] = 1 ;
314373 }
315374
316375 // Traverse each L2->L1 DmaCpyNd op and split them.
@@ -321,34 +380,18 @@ LogicalResult splitLogicalObjectFifos(
321380 l2ToL1DmaOp.getSourceMixedSizes ();
322381
323382 // Now we'll create a new L2 buffer based on the new shape inferred earlier
324- // via `l2ShapeAsTarget`.
325- rewriter.setInsertionPoint (sourceAllocOp);
326- LogicalObjectFifoFromMemrefOp targetObjectFifo =
327- l2ToL1DmaOp.getTargetObjectFifo ();
328- Value targetAllocOp = targetObjectFifo.getMemref ();
329- auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType ());
330- auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType ());
331- MemRefType newAllocType = MemRefType::get (
332- l2ShapeAsTarget, targetMemRefType.getElementType (),
333- MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace ());
334- auto newAllocOp = rewriter.create <memref::AllocOp>(rewriter.getUnknownLoc (),
335- newAllocType);
336- auto newDeallocOp = rewriter.create <memref::DeallocOp>(
337- rewriter.getUnknownLoc (), newAllocOp);
338- newDeallocOp->moveBefore (&newAllocOp->getBlock ()->back ());
339- auto type = cast<MemRefType>(newAllocOp.getType ());
340- // Create new logicalobjectfifo.from_memref for the newly created L2 buffer.
341- rewriter.setInsertionPoint (l2ToL1DmaOp.getSourceObjectFifo ());
342- auto source = rewriter.create <AMDAIE::LogicalObjectFifoFromMemrefOp>(
343- rewriter.getUnknownLoc (), LogicalObjectFifoType::get (type),
344- newAllocOp.getResult (), sourceObjectFifo.getTiles ());
383+ // via `staticL2AsTargetSizes`.
384+ LogicalObjectFifoFromMemrefOp oldL2ObjectFifo =
385+ l2ToL1DmaOp.getSourceObjectFifo ();
386+ AMDAIE::LogicalObjectFifoFromMemrefOp source = createNewLogicalObjectFifo (
387+ rewriter, oldL2ObjectFifo, staticL2AsTargetSizes);
345388
346389 // --------------------------------------------
347390 // ---------- L3 -> L2 splitting --------------
348391 // --------------------------------------------
349392 // Update L3 source offsets for non-split dimensions. Refer doc comment of
350393 // `updateL3SourceOffset` for the computation rationale involved.
351- SmallVector<OpFoldResult, 4 > staticL3AsSourceOffsets =
394+ SmallVector<OpFoldResult> staticL3AsSourceOffsets =
352395 l3ToL2DmaOp.getSourceMixedOffsets ();
353396 for (auto &&[splitDim, nonSplitdim] :
354397 llvm::zip_equal (splitDimsForL2, nonSplitDimsForL2)) {
0 commit comments