Skip to content

Commit cde5ac5

Browse files
committed
Transposed 2d load.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 133ed4d commit cde5ac5

File tree

1 file changed

+59
-19
lines changed

1 file changed

+59
-19
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2528,17 +2528,6 @@ struct LoadOpToBlockIOConversion
25282528
if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
25292529
vBlocks = 1;
25302530

2531-
// TODO: use the axis info to general the handling for both regular pointer
2532-
// and block pointer.
2533-
const bool memoryRowMajor = isMemoryRowMajor(op);
2534-
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2535-
const bool isTransposeRequired = contiguousDim != colDim;
2536-
2537-
if (isTransposeRequired) {
2538-
// TODO: support load column major data.
2539-
return failure();
2540-
}
2541-
25422531
Location loc = op.getLoc();
25432532
auto b = TritonLLVMOpBuilder(loc, rewriter);
25442533
MLIRContext *ctx = rewriter.getContext();
@@ -2667,10 +2656,59 @@ struct LoadOpToBlockIOConversion
26672656
}
26682657
}
26692658

2659+
// TODO: use the axis info to generalize the handling for both regular pointer
2660+
// and block pointer.
2661+
const bool memoryRowMajor = isMemoryRowMajor(op);
2662+
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2663+
const bool isTransposeRequired = contiguousDim != colDim;
2664+
2665+
if (isTransposeRequired) {
2666+
if (numPackedVals > 1)
2667+
return failure();
2668+
if (elemSizeInBits > 32)
2669+
return failure();
2670+
if (tileWidth > 32)
2671+
return failure(); // tileWidth is limited to 32 for transpose 2d load.
2672+
2673+
vBlocks = 1;
2674+
2675+
// Use the d32 (32-bit packed element) size for the transposed 2D load.
2676+
packedElemSizeInBits = 32;
2677+
numPackedVals = packedElemSizeInBits / elemSizeInBits;
2678+
if (numPackedVals > 1 && tileWidth != threadsPerWarp)
2679+
return failure(); // Couldn't use the transpose 2d load for un-packable
2680+
// along tile height dim.
2681+
tileHeight = std::min(tileHeight / numPackedVals, 8);
2682+
2683+
if (tileHeight * tileWidth < threadsPerWarp)
2684+
return failure(); // The tile size is not large enough for IGC scalar
2685+
// backend vectorization.
2686+
// transpose the width and height of the tile
2687+
std::swap(tileHeight, tileWidth);
2688+
// if (oneMatrixPerLoadForBT) {
2689+
// // Only load 1 operand per inst on row.
2690+
// numOperandsPer2DLoadM = 1;
2691+
// tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2692+
// } else {
2693+
// // We can decompose the matrix returned by transposed large 2d load
2694+
// // when threads per warp < column size. Otherwise we have to load one
2695+
// // operand per inst.
2696+
// // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2697+
// // now.
2698+
// numOperandsPer2DLoadM =
2699+
// (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2700+
// }
2701+
// // The transpose 2d load only support 1 operand per inst on column.
2702+
// // (vBlocks = 1)
2703+
// numOperandsPer2DloadN = 1;
2704+
// // TODO: support load column major data.
2705+
// return failure();
2706+
}
2707+
26702708
baseWidth = b.i32_val(
26712709
std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
26722710
// If the stride is 0, we want to load only the first row.
2673-
int stride = getStride(ptr, 0);
2711+
int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
26742712
baseHeightInt = (stride == 0 ? 1 : tileHeight);
26752713
baseHeight = b.i32_val(baseHeightInt);
26762714
pitch = getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1);
@@ -2739,17 +2777,19 @@ struct LoadOpToBlockIOConversion
27392777
}
27402778
} break;
27412779
case DpasEncodingAttr::OpIdx::OperandB: {
2742-
assert(numPackedVals == 1 &&
2743-
"invalid number of packed values for DPAS operand B.");
2780+
// assert(numPackedVals == 1 &&
2781+
// "invalid number of packed values for DPAS operand B.");
27442782
unsigned elemsPerLanePerDPASInst =
27452783
product<unsigned>(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
27462784
// The 2D block contains at least one DotOp B.
27472785
if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
27482786
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
27492787
unsigned sysDepth = dpasLayout.getSystolicDepth();
2750-
if (tileHeight >= (opsPerChannel * sysDepth) &&
2751-
((opsPerChannel == 4 && elemSizeInBits == 8) ||
2752-
(opsPerChannel == 2 && elemSizeInBits == 16))) {
2788+
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
2789+
(opsPerChannel == 2 && elemSizeInBits == 16)) {
2790+
assert(!isTransposeRequired ||
2791+
opsPerChannel == numPackedVals &&
2792+
"invalid opsPerChannel for transposed DotOp B");
27532793
// Use the VNNI packing format for DotOp B layout.
27542794
numValuesPerLoad = numElemsPerLoad / opsPerChannel;
27552795
packedType = i32_ty;
@@ -2816,8 +2856,8 @@ struct LoadOpToBlockIOConversion
28162856
/*tile_width*/ tileWidth,
28172857
/*tile_height*/ tileHeight,
28182858
/*v_blocks*/ vBlocks,
2819-
/*transpose*/ false,
2820-
/*vnni_transform*/ useVNNIFormat);
2859+
/*transpose*/ isTransposeRequired,
2860+
/*vnni_transform*/ !isTransposeRequired && useVNNIFormat);
28212861

28222862
// When strides[0] is 0, we only want to load the first row, so we
28232863
// set the base height to be 1. If tile height is bigger than 1,

0 commit comments

Comments
 (0)