@@ -567,6 +567,22 @@ struct RemoveAllocCopyLinalgOpCopyPattern
567567// #map1> memref.dealloc %11 : memref<1x32x16x16xf32, 2>
568568// }
569569
570+ struct ConvertMemrefCopyToLinalgCopyPattern
571+ : public OpRewritePattern<memref::CopyOp> {
572+ using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
573+
574+ LogicalResult matchAndRewrite (memref::CopyOp copyOp,
575+ PatternRewriter &rewriter) const override {
576+ Value source = copyOp.getSource ();
577+ Value target = copyOp.getTarget ();
578+
579+ // Create linalg.copy operation
580+ rewriter.replaceOpWithNewOp <linalg::CopyOp>(copyOp, source, target);
581+
582+ return success ();
583+ }
584+ };
585+
570586// Eliminate intermediate memref in cascaded DMA operations
571587// Replace a pattern like this:
572588// air.dma_memcpy_nd (%intermediate[] [] [], %source[] [] []) : (memref<...>,
@@ -590,17 +606,6 @@ struct EliminateIntermediateMemrefPattern
590606 if (std::distance (intermediate.use_begin (), intermediate.use_end ()) != 2 )
591607 return failure ();
592608
593- auto opOrAncestorIsDominantOver = [](Operation *a, Operation *b) {
594- Region *commonRegion = air::findCommonRegionContainingAllAncestors (
595- SmallVector<Operation *>{a, b}, nullptr );
596- auto aAncestor = commonRegion->findAncestorOpInRegion (*a);
597- auto bAncestor = commonRegion->findAncestorOpInRegion (*b);
598- if (!aAncestor || !bAncestor)
599- return false ;
600- DominanceInfo domInfo (aAncestor);
601- return domInfo.properlyDominates (aAncestor, bAncestor);
602- };
603-
604609 // Find the second memcpy that uses the intermediate buffer as source
605610 air::DmaMemcpyNdOp secondMemcpy = nullptr ;
606611 for (auto user : intermediate.getUsers ()) {
@@ -609,7 +614,7 @@ struct EliminateIntermediateMemrefPattern
609614 continue ;
610615 if (memcpyOp.getSrcMemref () != intermediate)
611616 continue ;
612- if (opOrAncestorIsDominantOver (memcpyOp, firstMemcpy))
617+ if (air:: opOrAncestorIsDominantOver (memcpyOp, firstMemcpy))
613618 continue ;
614619 secondMemcpy = memcpyOp;
615620 break ;
@@ -2438,13 +2443,10 @@ static bool hasWritesBetween(memref::AllocOp allocOp, Operation *beforeOp) {
24382443
24392444 // Only consider operations that are dominated by the allocation
24402445 // and that dominate the beforeOp
2441- if (!domInfo. properlyDominates (allocOp.getOperation (), op)) {
2446+ if (!xilinx::air::opOrAncestorIsDominantOver (allocOp.getOperation (), op))
24422447 return ;
2443- }
2444-
2445- if (!domInfo.properlyDominates (op, beforeOp)) {
2448+ if (!xilinx::air::opOrAncestorIsDominantOver (op, beforeOp))
24462449 return ;
2447- }
24482450
24492451 // Check if this operation writes to our allocation
24502452 if (hasWriteEffectOn (op, allocResult)) {
@@ -2546,6 +2548,41 @@ DiagnosedSilenceableFailure transform::EliminateCascadeMemcpyOp::apply(
25462548 return DiagnosedSilenceableFailure::success ();
25472549}
25482550
2551+ // ===----------------------------------------------------------------------===//
2552+ // ConvertMemrefCopyToLinalgCopyOp
2553+ // ===----------------------------------------------------------------------===//
2554+
2555+ DiagnosedSilenceableFailure transform::ConvertMemrefCopyToLinalgCopyOp::apply (
2556+ transform::TransformRewriter &rewriter,
2557+ transform::TransformResults &results, transform::TransformState &state) {
2558+
2559+ SmallVector<Operation *> targets =
2560+ llvm::to_vector (state.getPayloadOps (getTarget ()));
2561+
2562+ if (targets.empty ()) {
2563+ results.set (llvm::cast<OpResult>(getResult ()), ArrayRef<Operation *>());
2564+ return DiagnosedSilenceableFailure::success ();
2565+ }
2566+
2567+ SmallVector<Operation *> transformedOps;
2568+
2569+ for (Operation *target : targets) {
2570+ MLIRContext *ctx = target->getContext ();
2571+ RewritePatternSet patterns (ctx);
2572+
2573+ // Use the ConvertMemrefCopyToLinalgCopyPattern
2574+ patterns.insert <xilinx::air::ConvertMemrefCopyToLinalgCopyPattern>(ctx);
2575+
2576+ // Apply the pattern to convert memref.copy to linalg.copy operations
2577+ (void )applyPatternsGreedily (target, std::move (patterns));
2578+
2579+ transformedOps.push_back (target);
2580+ }
2581+
2582+ results.set (llvm::cast<OpResult>(getResult ()), transformedOps);
2583+ return DiagnosedSilenceableFailure::success ();
2584+ }
2585+
25492586namespace xilinx {
25502587namespace air {
25512588
0 commit comments