Commit f384223
Fixup issue where AIR only optimizes SHIM DMA BD for wrap < 1024 (Xilinx#606)

* Enable constraint where highest wrap is no greater than 64
* Fixup issue where the outermost wrap-and-stride dim gets lost when npu dma op gets tiled at outermost wrap dimension
* Fixup issue on for loop step size when stride = 0
* Fixup an issue where dim with stride = 0 was converted into for loop with step = 1, which then got passed into offset
* Add test; change existing tests to reflect optimized bd allocation
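For example, the new test added below (func21 in airrt_to_npu.mlir) has an outermost wrap of 152, which the old single 1024 limit left untouched; under the per-dimension bounds it is now tiled into an inner wrap of 38 and an outer factor of 4, and because the op already uses four dimensions that outer factor is peeled into a loop, visible in the test as four unrolled aiex.npu.dma_memcpy_nd ops.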
1 parent b2df4d7 commit f384223

File tree

3 files changed, +119 -28 lines changed


mlir/lib/Conversion/AIRRtToNpuPass.cpp

Lines changed: 69 additions & 18 deletions
@@ -513,21 +513,21 @@ void isolateAIRRtDmaLoopNests(ModuleOp module) {
 }
 
 // AIE2 hardware constraints.
-const int AIE2_WRAP_UPPER_BOUND = 1024;
+const std::vector<int> AIE2_WRAP_UPPER_BOUNDS = {64, 1024, 1024, 1024};
+const int AIE2_STRIDE_UPPER_BOUND = 1048576;
 const int AIE2_DIM_COUNT = 4;
 
 bool violatesAIE2WrapLimit(airrt::DmaMemcpyNdOp dma) {
   SmallVector<Value> wrap_list;
-  wrap_list.push_back(dma.getLength0());
-  wrap_list.push_back(dma.getLength1());
-  wrap_list.push_back(dma.getLength2());
   wrap_list.push_back(dma.getLength3());
-  for (auto wrap : wrap_list) {
-    if (auto const_val = getConstantIntValue(wrap)) {
+  wrap_list.push_back(dma.getLength2());
+  wrap_list.push_back(dma.getLength1());
+  wrap_list.push_back(dma.getLength0());
+  for (unsigned i = 0; i < wrap_list.size(); i++) {
+    if (auto const_val = getConstantIntValue(wrap_list[i])) {
       // Detected wrap that goes beyond the AIE2 hardware limit.
-      if (*const_val >= AIE2_WRAP_UPPER_BOUND) {
+      if (*const_val >= AIE2_WRAP_UPPER_BOUNDS[i])
         return true;
-      }
     } else
       assert(false && "has non-static wrap");
   }
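A minimal standalone sketch of the new check, under the assumption (matching the reordered wrap_list above) that wraps are listed outermost-first; the function and array names here are illustrative, not part of the pass:

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Per-dimension AIE2 shim DMA BD wrap limits, outermost dimension first.
bool violatesAIE2WrapLimitSketch(const std::array<int64_t, 4> &wraps) {
  constexpr std::array<int64_t, 4> kBounds = {64, 1024, 1024, 1024};
  for (std::size_t i = 0; i < wraps.size(); ++i)
    if (wraps[i] >= kBounds[i])
      return true; // this dimension's wrap exceeds the hardware limit
  return false;
}

int main() {
  // Wraps taken from the new func21 test below: the outermost wrap 152
  // trips the check, while the tiled result (38) does not.
  assert(violatesAIE2WrapLimitSketch({152, 2, 64, 64}));
  assert(!violatesAIE2WrapLimitSketch({38, 2, 64, 64}));
  return 0;
}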
@@ -567,6 +567,7 @@ int findLargestFactor(int num, int max) {
 
 void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
   auto loc = memcpy_op->getLoc();
+  auto ctx = memcpy_op->getContext();
   auto oper_begin = memcpy_op.getOperands().begin();
   SmallVector<Value> offsets(oper_begin + 4, oper_begin + 8);
   SmallVector<Value> wraps(oper_begin + 8, oper_begin + 12);
@@ -579,10 +580,20 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
   for (int i = wraps.size() - 1; i >= 0; i--) {
     auto const_wrap = *getConstantIntValue(wraps[i]);
     auto const_stride = *getConstantIntValue(strides[i]);
-    if (const_wrap >= AIE2_WRAP_UPPER_BOUND) {
-      // Found dimension with illegal wrap. Tiling.
-      int outer_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1);
-      int inner_wrap = mlir::ceilDiv(const_wrap, outer_wrap);
+    if (const_wrap >= AIE2_WRAP_UPPER_BOUNDS[i]) {
+      // Found dimension with illegal wrap. Tiling. (Prefers smaller outer wrap
+      // values, as long as stride fits)
+      int a_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUNDS[i] - 1);
+      int b_wrap = mlir::ceilDiv(const_wrap, a_wrap);
+      int new_a_stride =
+          (const_stride * a_wrap) % air::getTensorVolume(llvm::cast<MemRefType>(
+                                        memcpy_op.getMemref().getType()));
+      int inner_wrap = (new_a_stride > AIE2_STRIDE_UPPER_BOUND && i != 0)
+                           ? (b_wrap)
+                           : (a_wrap);
+      int outer_wrap = (new_a_stride > AIE2_STRIDE_UPPER_BOUND && i != 0)
+                           ? (a_wrap)
+                           : (b_wrap);
       wraps[i] = builder.create<arith::ConstantOp>(
           loc, builder.getI64Type(),
           IntegerAttr::get(builder.getI64Type(), inner_wrap));
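The effect of the new selection is that the larger factor stays on the inner (BD) wrap whenever the resulting stride still fits AIE2_STRIDE_UPPER_BOUND. Taking the numbers implied by the buffer_memref_to_args.mlir test below as an inferred example: a dimension with wrap 2048 and stride 1024 in a 2,097,152-element i32 buffer gives a_wrap = findLargestFactor(2048, 1023) = 512 and b_wrap = 4; the candidate stride 512 * 1024 = 524288 is within the 1048576 bound, so inner_wrap = 512 and outer_wrap = 4. The previous code always kept the largest factor as the outer wrap, which is why the expected wraps/strides in that test change from [4, 512, 4, 64][64, 4096, 1024] to [4, 4, 512, 64][64, 524288, 1024].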
@@ -609,29 +620,52 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
   // Unroll highest dimensions of wrap and stride, if the new dimension count
   // goes beyond 4.
   SmallVector<affine::AffineForOp> for_loop_nest;
+  Value inner_affine_for_iv = nullptr;
   if (wraps.size() > AIE2_DIM_COUNT) {
     affine::AffineForOp inner_affine_for = nullptr;
     while (wraps.size() > AIE2_DIM_COUNT) {
       auto const_offset = *getConstantIntValue(offsets[0]);
+      auto const_lowest_offset = *getConstantIntValue(offsets.back());
       auto const_wrap = *getConstantIntValue(wraps[0]);
       auto const_stride = *getConstantIntValue(strides[0]);
 
       // Convert the outer dimension into an affine.for loop.
-      auto const_upper_bound = const_offset + const_wrap * const_stride;
+      int const_lower_bound =
+          const_stride ? (const_offset * const_stride + const_lowest_offset)
+                       : 0;
+      auto const_upper_bound =
+          const_stride ? (const_offset * const_stride +
+                          const_wrap * const_stride + const_lowest_offset)
+                       : const_wrap;
+      int const_step = const_stride ? const_stride : 1;
       auto new_for_op =
-          (const_stride)
+          (inner_affine_for_iv)
               ? (builder.create<affine::AffineForOp>(
-                    loc, const_offset, const_upper_bound, const_stride))
-              : (builder.create<affine::AffineForOp>(loc, 0, const_wrap));
+                    loc,
+                    SmallVector<Value>{builder.create<arith::AddIOp>(
+                        loc, inner_affine_for_iv,
+                        builder.create<arith::ConstantIndexOp>(
+                            loc, const_lower_bound))},
+                    AffineMap::get(ctx),
+                    SmallVector<Value>{builder.create<arith::AddIOp>(
+                        loc, inner_affine_for_iv,
+                        builder.create<arith::ConstantIndexOp>(
+                            loc, const_upper_bound))},
+                    AffineMap::get(ctx), const_step))
+              : (builder.create<affine::AffineForOp>(
+                    loc, const_lower_bound, const_upper_bound, const_step));
       for_loop_nest.push_back(new_for_op);
       inner_affine_for = new_for_op;
 
       // Pop front.
       offsets.erase(offsets.begin());
       wraps.erase(wraps.begin());
       strides.erase(strides.begin());
+
+      builder.setInsertionPointToStart(inner_affine_for.getBody());
+      if (const_stride)
+        inner_affine_for_iv = inner_affine_for.getInductionVar();
     }
-    builder.setInsertionPointToStart(inner_affine_for.getBody());
   }
 
   // Stride field implicit last element one, pop.
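Worked through with hypothetical numbers: a peeled dimension with offset 2, wrap 4, stride 1024 and lowest offset 64 now becomes an affine.for from 2 * 1024 + 64 = 2112 to 2112 + 4 * 1024 = 6208 with step 1024, while a dimension with stride 0 becomes an affine.for from 0 to wrap with step 1. Because the `if (const_stride)` guard records the induction variable only for non-zero strides, a stride-0 loop's IV can no longer be folded into the offset, which is the step/offset bug called out in the commit message.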
@@ -641,8 +675,20 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
   SmallVector<Value> new_opers;
   SmallVector<Type> tys;
   auto old_opers = memcpy_op.getOperands();
+  // Insert
   new_opers.insert(new_opers.end(), old_opers.begin(), old_opers.begin() + 4);
-  new_opers.insert(new_opers.end(), offsets.begin(), offsets.end());
+  if (inner_affine_for_iv) {
+    // Innermost tiled affine.for loop induction variable as lowest offset, if
+    // original rank exceeds hw limit.
+    new_opers.insert(new_opers.end(), offsets.begin(), offsets.end() - 1);
+    auto new_inner_offset = builder.create<arith::AddIOp>(
+        loc,
+        builder.create<arith::IndexCastOp>(loc, IntegerType::get(ctx, 64),
+                                           inner_affine_for_iv),
+        offsets.back());
+    new_opers.push_back(new_inner_offset);
+  } else
+    new_opers.insert(new_opers.end(), offsets.begin(), offsets.end());
   new_opers.insert(new_opers.end(), wraps.begin(), wraps.end());
   new_opers.insert(new_opers.end(), strides.begin(), strides.end());
   builder.create<airrt::DmaMemcpyNdOp>(loc, tys, new_opers,
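As a hypothetical illustration: if the original lowest offset operand is %c0_i64 and the innermost peeled loop's induction variable is %iv, the rebuilt airrt.dma_memcpy_nd now receives arith.addi(arith.index_cast %iv to i64, %c0_i64) as its lowest offset, so each peeled iteration issues the same four-dimensional copy at a shifted base address instead of losing the peeled wrap-and-stride pair, which is the "outermost wrap-and-stride dim gets lost" issue from the commit message.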
@@ -909,6 +955,11 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
     // Enforce AIE2 hardware constraint: wrap size limit within [0, 1023].
     enforceAIE2WrapLimit(module);
 
+    // Simplify arith ops (from airrt)
+    RewritePatternSet canoPatterns_3(ctx);
+    arith::IndexCastOp::getCanonicalizationPatterns(canoPatterns_3, ctx);
+    (void)applyPatternsAndFoldGreedily(module, std::move(canoPatterns_3));
+
     ConversionTarget target(getContext());
     target.addIllegalDialect<AIRRtDialect>();
     target.addLegalDialect<arith::ArithDialect, AIEX::AIEXDialect>();
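The extra canonicalization folds the arith.index_cast chains introduced by the offset rewrite above; once the peeled loops are unrolled later in the pipeline (the four constant-offset copies in the new test suggest this happens), the per-iteration offsets fold back to plain constants before the AIRRt-to-NPU dialect conversion runs.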

mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir

Lines changed: 46 additions & 6 deletions
@@ -455,10 +455,10 @@ module {
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 64, 0][4, 8, 64, 256][0, 256, 2048]) {id = 1 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 128, 0][4, 8, 64, 256][0, 256, 2048]) {id = 2 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 192, 0][4, 8, 64, 256][0, 256, 2048]) {id = 3 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG2]][0, 0, 0, 0][4, 4, 64, 64][131072, 64, 2048]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2048x2048xi32>
 
 #map = affine_map<()[s0] -> (s0 * 64)>
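The rewritten id = 4..7 patterns address the same elements as before: 512 x 4 with strides (8192, 2048) and 4 x 512 with strides (1048576, 2048) both enumerate offsets m * 2048 for m < 2048, since 8192 = 4 * 2048 and 1048576 = 512 * 2048. The new form simply reflects the tiling preference above, keeping the larger factor on the inner dimension because the resulting 1048576 stride still fits the stride bound.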
@@ -701,8 +701,8 @@ module {
 // CHECK-SAME: %[[VAL_0:.*]]: memref<262144xi32>, %[[VAL_1:.*]]: memref<262144xi32>, %[[VAL_2:.*]]: memref<131072xi32>) {
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][2, 4, 256, 128][0, 128, 512]) {id = 0 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][2, 4, 256, 128][0, 128, 512]) {id = 1 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][2, 2, 64, 128][65536, 128, 256]) {id = 4 : i64, metadata = @airMemcpyId45} : memref<131072xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 16384][2, 2, 64, 128][65536, 128, 256]) {id = 5 : i64, metadata = @airMemcpyId46} : memref<131072xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 32768][2, 2, 64, 128][65536, 128, 256]) {id = 0 : i64, metadata = @airMemcpyId47} : memref<131072xi32>
@@ -930,3 +930,43 @@ module {
     return
   }
 }
+
+// -----
+
+// Outermost wrap must be in range [1:64] for AIE2.
+
+// CHECK-LABEL: func21
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][38, 2, 64, 32][77824, 32, 1216]) {id = 0 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 2957312][38, 2, 64, 32][77824, 32, 1216]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 5914624][38, 2, 64, 32][77824, 32, 1216]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 8871936][38, 2, 64, 32][77824, 32, 1216]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<11829248xi32>
+// CHECK: return
+
+#map = affine_map<()[s0] -> (s0 * 128)>
+module {
+  aie.device(npu1_4col) {
+    aie.shim_dma_allocation @airMemcpyId10(MM2S, 1, 0)
+    memref.global "public" @airMemcpyId10 : memref<1x2x64x64xbf16, 1 : i32>
+  } {sym_name = "matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0"}
+  airrt.module_metadata{
+  }
+  func.func @func21(%arg0: memref<9728x2432xbf16>) {
+    %c2_i64 = arith.constant 2 : i64
+    %c2432_i64 = arith.constant 2432 : i64
+    %c155648_i64 = arith.constant 155648 : i64
+    %c152_i64 = arith.constant 152 : i64
+    %c64_i64 = arith.constant 64 : i64
+    %c10_i32 = arith.constant 10 : i32
+    %c0_i64 = arith.constant 0 : i64
+    affine.for %arg3 = 0 to 1 {
+      affine.for %arg4 = 0 to 1 {
+        %0 = affine.apply #map()[%arg4]
+        %1 = arith.index_cast %arg3 : index to i64
+        %2 = arith.index_cast %arg4 : index to i64
+        %3 = arith.index_cast %0 : index to i64
+        %4 = airrt.dma_memcpy_nd(%c10_i32, %1, %2, %arg0[%c0_i64, %c0_i64, %c0_i64, %3], [%c152_i64, %c2_i64, %c64_i64, %c64_i64], [%c155648_i64, %c64_i64, %c2432_i64]) {metadata = @airMemcpyId10} : (i32, i64, i64, memref<9728x2432xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
+      }
+    }
+    return
+  }
+}
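The CHECK lines follow from the airrt op above: the outermost wrap 152 (bound 64) factors as 4 * 38, findLargestFactor(152, 63) = 38 stays as the BD wrap, and the factor of 4 is peeled off, so successive copies start 38 * 155648 = 5914624 bf16 elements apart. The npu ops evidently address the bf16 buffer as i32 words (memref<11829248xi32>), halving the strides and the innermost wrap: [152, 2, 64, 64] / [155648, 64, 2432] becomes [38, 2, 64, 32] / [77824, 32, 1216], and the four unrolled copies land at i32 offsets 0, 2957312, 5914624 and 8871936.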

mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir

Lines changed: 4 additions & 4 deletions
@@ -122,10 +122,10 @@ module {
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
-// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
+// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
 // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32>
 
 module {
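Same reshuffle as in airrt_to_npu.mlir above: 512 x 4 with strides (4096, 1024) and 4 x 512 with strides (524288, 1024) cover the same offsets (4096 = 4 * 1024, 524288 = 512 * 1024); the larger factor moves to the inner dimension because the resulting 524288 stride is still within the 1048576 limit.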
