Skip to content

Commit 0e2d12f

Browse files
authored
Convert memref copy to linalg copy (#1098)
1 parent 757164a commit 0e2d12f

File tree

7 files changed

+376
-28
lines changed

7 files changed

+376
-28
lines changed

mlir/include/air/Dialect/AIR/AIRTransformOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,32 @@ def EliminateCascadeMemcpyOp : Op<Transform_Dialect, "air.eliminate_cascade_memc
375375
let assemblyFormat = "$target attr-dict";
376376
}
377377

378+
def ConvertMemrefCopyToLinalgCopyOp : Op<Transform_Dialect, "air.convert_memref_copy_to_linalg_copy",
379+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
380+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
381+
let summary = "Convert memref.copy operations to linalg.copy operations";
382+
let description = [{
383+
This transform converts `memref.copy` operations to `linalg.copy` operations.
384+
This can be useful for enabling further linalg-based optimizations and transformations.
385+
386+
The transformation replaces:
387+
```mlir
388+
memref.copy %source, %dest : memref<...> to memref<...>
389+
```
390+
391+
With:
392+
```mlir
393+
linalg.copy ins(%source : memref<...>) outs(%dest : memref<...>)
394+
```
395+
396+
Returns a handle to the modified operation containing the transformed copies.
397+
}];
398+
399+
let arguments = (ins PDL_Operation:$target);
400+
let results = (outs PDL_Operation:$result);
401+
let assemblyFormat = "$target attr-dict";
402+
}
403+
378404
// Ops implemented in mlir/lib/Transform/AIRLinalgBufferize.cpp
379405
//
380406

mlir/include/air/Util/Util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,8 @@ Operation *cloneOpAndOperands(
316316
return !isa<LoopLikeOpInterface>(o) && !isa<air::HierarchyInterface>(o);
317317
});
318318

319+
bool opOrAncestorIsDominantOver(Operation *a, Operation *b);
320+
319321
} // namespace air
320322
} // namespace xilinx
321323

mlir/lib/Transform/AIRDependency.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -790,18 +790,8 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
790790
operand.getDefiningOp()->emitOpError(
791791
"operand being traced is not a memref");
792792
}
793-
auto opOrAncestorIsDominantOver = [](Operation *a, Operation *b) {
794-
Region *commonRegion = air::findCommonRegionContainingAllAncestors(
795-
SmallVector<Operation *>{a, b}, nullptr);
796-
auto aAncestor = commonRegion->findAncestorOpInRegion(*a);
797-
auto bAncestor = commonRegion->findAncestorOpInRegion(*b);
798-
if (!aAncestor || !bAncestor)
799-
return false;
800-
DominanceInfo domInfo(aAncestor);
801-
return domInfo.properlyDominates(aAncestor, bAncestor);
802-
};
803793
for (auto &u : operand.getUses()) {
804-
if (!opOrAncestorIsDominantOver(u.getOwner(), op))
794+
if (!air::opOrAncestorIsDominantOver(u.getOwner(), op))
805795
continue;
806796
// If used in MemcpyInterface Op
807797
if (auto memcpy = dyn_cast<air::MemcpyInterface>(u.getOwner())) {

mlir/lib/Transform/AIRLinalgCodegen.cpp

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,22 @@ struct RemoveAllocCopyLinalgOpCopyPattern
567567
// #map1> memref.dealloc %11 : memref<1x32x16x16xf32, 2>
568568
//}
569569

570+
struct ConvertMemrefCopyToLinalgCopyPattern
571+
: public OpRewritePattern<memref::CopyOp> {
572+
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
573+
574+
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
575+
PatternRewriter &rewriter) const override {
576+
Value source = copyOp.getSource();
577+
Value target = copyOp.getTarget();
578+
579+
// Create linalg.copy operation
580+
rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, source, target);
581+
582+
return success();
583+
}
584+
};
585+
570586
// Eliminate intermediate memref in cascaded DMA operations
571587
// Replace a pattern like this:
572588
// air.dma_memcpy_nd (%intermediate[] [] [], %source[] [] []) : (memref<...>,
@@ -590,17 +606,6 @@ struct EliminateIntermediateMemrefPattern
590606
if (std::distance(intermediate.use_begin(), intermediate.use_end()) != 2)
591607
return failure();
592608

593-
auto opOrAncestorIsDominantOver = [](Operation *a, Operation *b) {
594-
Region *commonRegion = air::findCommonRegionContainingAllAncestors(
595-
SmallVector<Operation *>{a, b}, nullptr);
596-
auto aAncestor = commonRegion->findAncestorOpInRegion(*a);
597-
auto bAncestor = commonRegion->findAncestorOpInRegion(*b);
598-
if (!aAncestor || !bAncestor)
599-
return false;
600-
DominanceInfo domInfo(aAncestor);
601-
return domInfo.properlyDominates(aAncestor, bAncestor);
602-
};
603-
604609
// Find the second memcpy that uses the intermediate buffer as source
605610
air::DmaMemcpyNdOp secondMemcpy = nullptr;
606611
for (auto user : intermediate.getUsers()) {
@@ -609,7 +614,7 @@ struct EliminateIntermediateMemrefPattern
609614
continue;
610615
if (memcpyOp.getSrcMemref() != intermediate)
611616
continue;
612-
if (opOrAncestorIsDominantOver(memcpyOp, firstMemcpy))
617+
if (air::opOrAncestorIsDominantOver(memcpyOp, firstMemcpy))
613618
continue;
614619
secondMemcpy = memcpyOp;
615620
break;
@@ -2438,13 +2443,10 @@ static bool hasWritesBetween(memref::AllocOp allocOp, Operation *beforeOp) {
24382443

24392444
// Only consider operations that are dominated by the allocation
24402445
// and that dominate the beforeOp
2441-
if (!domInfo.properlyDominates(allocOp.getOperation(), op)) {
2446+
if (!xilinx::air::opOrAncestorIsDominantOver(allocOp.getOperation(), op))
24422447
return;
2443-
}
2444-
2445-
if (!domInfo.properlyDominates(op, beforeOp)) {
2448+
if (!xilinx::air::opOrAncestorIsDominantOver(op, beforeOp))
24462449
return;
2447-
}
24482450

24492451
// Check if this operation writes to our allocation
24502452
if (hasWriteEffectOn(op, allocResult)) {
@@ -2546,6 +2548,41 @@ DiagnosedSilenceableFailure transform::EliminateCascadeMemcpyOp::apply(
25462548
return DiagnosedSilenceableFailure::success();
25472549
}
25482550

2551+
//===----------------------------------------------------------------------===//
2552+
// ConvertMemrefCopyToLinalgCopyOp
2553+
//===----------------------------------------------------------------------===//
2554+
2555+
DiagnosedSilenceableFailure transform::ConvertMemrefCopyToLinalgCopyOp::apply(
2556+
transform::TransformRewriter &rewriter,
2557+
transform::TransformResults &results, transform::TransformState &state) {
2558+
2559+
SmallVector<Operation *> targets =
2560+
llvm::to_vector(state.getPayloadOps(getTarget()));
2561+
2562+
if (targets.empty()) {
2563+
results.set(llvm::cast<OpResult>(getResult()), ArrayRef<Operation *>());
2564+
return DiagnosedSilenceableFailure::success();
2565+
}
2566+
2567+
SmallVector<Operation *> transformedOps;
2568+
2569+
for (Operation *target : targets) {
2570+
MLIRContext *ctx = target->getContext();
2571+
RewritePatternSet patterns(ctx);
2572+
2573+
// Use the ConvertMemrefCopyToLinalgCopyPattern
2574+
patterns.insert<xilinx::air::ConvertMemrefCopyToLinalgCopyPattern>(ctx);
2575+
2576+
// Apply the pattern to convert memref.copy to linalg.copy operations
2577+
(void)applyPatternsGreedily(target, std::move(patterns));
2578+
2579+
transformedOps.push_back(target);
2580+
}
2581+
2582+
results.set(llvm::cast<OpResult>(getResult()), transformedOps);
2583+
return DiagnosedSilenceableFailure::success();
2584+
}
2585+
25492586
namespace xilinx {
25502587
namespace air {
25512588

mlir/lib/Util/Util.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2141,4 +2141,15 @@ Operation *air::cloneOpAndOperands(RewriterBase &rewriter, IRMapping &remap,
21412141
return new_op;
21422142
}
21432143

2144+
bool air::opOrAncestorIsDominantOver(Operation *a, Operation *b) {
2145+
Region *commonRegion = air::findCommonRegionContainingAllAncestors(
2146+
SmallVector<Operation *>{a, b}, nullptr);
2147+
auto aAncestor = commonRegion->findAncestorOpInRegion(*a);
2148+
auto bAncestor = commonRegion->findAncestorOpInRegion(*b);
2149+
if (!aAncestor || !bAncestor)
2150+
return false;
2151+
DominanceInfo domInfo(aAncestor);
2152+
return domInfo.properlyDominates(aAncestor, bAncestor);
2153+
}
2154+
21442155
} // namespace xilinx
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- air_transform.mlir --------------------------------------*- MLIR -*-===//
2+
//
3+
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
// RUN: air-opt %s | FileCheck %s
9+
10+
// CHECK: transform.air.convert_memref_copy_to_linalg_copy
11+
12+
transform.with_pdl_patterns {
13+
^bb0(%arg0: !pdl.operation):
14+
transform.sequence %arg0 : !pdl.operation failures(propagate) {
15+
^bb1(%arg1: !pdl.operation):
16+
// Convert memref.copy to linalg.copy
17+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
18+
%func_op_updated = transform.air.convert_memref_copy_to_linalg_copy %func_op
19+
}
20+
}

0 commit comments

Comments
 (0)