Skip to content

Commit 942ca37

Browse files
committed
[LoadStoreOpToLLVM] Transposed 2d load.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 0aa1b3c commit 942ca37

File tree

2 files changed

+70
-21
lines changed

2 files changed

+70
-21
lines changed

python/test/unit/intel/test_block_io.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,9 @@ def warps_per_cta(layout):
120120
@pytest.mark.parametrize("layout", layouts)
121121
@pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
122122
(False, True)])
123+
@pytest.mark.parametrize("transpose", [True, False])
123124
@pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
124-
def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
125+
def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path):
125126

126127
warps = warps_per_cta(layout)
127128
num_warps = int(np.prod(warps))
@@ -132,16 +133,18 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
132133

133134
support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
134135

136+
block_io = "\"column_major\"" if transpose else "\"row_major\""
137+
135138
if load_block_ptr:
136139
load_ops = f"""
137-
%src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
138-
%store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
140+
%src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {"[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
141+
%store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
139142
"""
140143
else:
141144
load_ops = f"""
142145
%src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
143-
%src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
144-
%store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
146+
%src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
147+
%store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
145148
"""
146149
if store_block_ptr:
147150
store_ops = f"""
@@ -175,6 +178,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
175178
%7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
176179
%row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
177180
181+
%stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout>
182+
%col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout>
183+
%8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
184+
%9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
185+
%col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout>
186+
178187
{load_ops}
179188
{store_ops}
180189
@@ -195,6 +204,8 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
195204
temp_file.write_text(ir)
196205
kernel = triton.compile(str(temp_file))
197206

207+
a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
208+
198209
kernel[(1, 1, 1)](a, x)
199210
assert torch.equal(a, x)
200211

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,18 +2521,6 @@ struct LoadOpToBlockIOConversion
25212521
if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
25222522
vBlocks = 1;
25232523

2524-
// TODO: use the axis info to general the handling for both regular pointer
2525-
// and block pointer.
2526-
const bool memoryRowMajor = isMemoryRowMajor(op);
2527-
// FIXME: Add support of column major.
2528-
if (!memoryRowMajor)
2529-
return failure();
2530-
2531-
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2532-
const bool isTransposeRequired = contiguousDim != colDim;
2533-
if (isTransposeRequired)
2534-
return matchAndRewriteTranspose(op, adaptor, rewriter);
2535-
25362524
Location loc = op.getLoc();
25372525
auto b = TritonLLVMOpBuilder(loc, rewriter);
25382526
MLIRContext *ctx = rewriter.getContext();
@@ -2661,6 +2649,55 @@ struct LoadOpToBlockIOConversion
26612649
}
26622650
}
26632651

2652+
// TODO: use the axis info to general the handling for both regular pointer
2653+
// and block pointer.
2654+
const bool memoryRowMajor = isMemoryRowMajor(op);
2655+
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2656+
const bool isTransposeRequired = contiguousDim != colDim;
2657+
2658+
if (isTransposeRequired) {
2659+
if (numPackedVals > 1)
2660+
return failure();
2661+
if (elemSizeInBits > 32)
2662+
return failure();
2663+
if (tileWidth > 32)
2664+
return failure(); // tileWidth is limited to 32 for transpose 2d load.
2665+
2666+
vBlocks = 1;
2667+
2668+
// use the d32 for transpose 2d load.
2669+
packedElemSizeInBits = 32;
2670+
numPackedVals = packedElemSizeInBits / elemSizeInBits;
2671+
if (numPackedVals > 1 && tileWidth != threadsPerWarp)
2672+
return failure(); // Couldn't use the transpose 2d load for un-packable
2673+
// along tile height dim.
2674+
tileHeight = std::min(tileHeight / numPackedVals, 8);
2675+
2676+
if (tileHeight * tileWidth < threadsPerWarp)
2677+
return failure(); // The tile size is not large enough for IGC scalar
2678+
// backend vectorization.
2679+
// transpose the width and height of the tile
2680+
std::swap(tileHeight, tileWidth);
2681+
// if (oneMatrixPerLoadForBT) {
2682+
// // Only load 1 operand per inst on row.
2683+
// numOperandsPer2DLoadM = 1;
2684+
// tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2685+
// } else {
2686+
// // We can decompose the matrix returned by transposed large 2d load
2687+
// // when threads per warp < column size. Otherwise we have to load one
2688+
// // operand per inst.
2689+
// // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2690+
// // now.
2691+
// numOperandsPer2DLoadM =
2692+
// (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2693+
// }
2694+
// // The transpose 2d load only support 1 operand per inst on column.
2695+
// // (vBlocks = 1)
2696+
// numOperandsPer2DloadN = 1;
2697+
// // TODO: support load column major data.
2698+
// return failure();
2699+
}
2700+
26642701
int64_t numElemsPerLoad = mlir::ceil(
26652702
tileHeight * tileWidth * numPackedVals * vBlocks, (int)threadsPerWarp);
26662703
unsigned numValuesPerLoad = mlir::ceil((int)numElemsPerLoad, numPackedVals);
@@ -2740,8 +2777,6 @@ struct LoadOpToBlockIOConversion
27402777
}
27412778
} break;
27422779
case DpasEncodingAttr::OpIdx::OperandB: {
2743-
assert(numPackedVals == 1 &&
2744-
"invalid number of packed values for DPAS operand B.");
27452780
unsigned elemsPerLanePerDPASInst =
27462781
product<unsigned>(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
27472782
// Block 2D contain at least one DotOp B.
@@ -2751,6 +2786,9 @@ struct LoadOpToBlockIOConversion
27512786
if (tileHeight >= (opsPerChannel * sysDepth) &&
27522787
((opsPerChannel == 4 && elemSizeInBits == 8) ||
27532788
(opsPerChannel == 2 && elemSizeInBits == 16))) {
2789+
assert(!isTransposeRequired ||
2790+
opsPerChannel == numPackedVals &&
2791+
"invalid opsPerChannel for transposed DotOp B");
27542792
// Use the VNNI packing format for DotOp B layout.
27552793
numValuesPerLoad = numElemsPerLoad / opsPerChannel;
27562794
packedType = i32_ty;
@@ -2814,8 +2852,8 @@ struct LoadOpToBlockIOConversion
28142852
/*tile_width*/ tileWidth,
28152853
/*tile_height*/ tileHeight,
28162854
/*v_blocks*/ vBlocks,
2817-
/*transpose*/ false,
2818-
/*vnni_transform*/ useVNNIFormat);
2855+
/*transpose*/ isTransposeRequired,
2856+
/*vnni_transform*/ !isTransposeRequired && useVNNIFormat);
28192857

28202858
// When strides[0] is 0, we only want to load the first row, so we
28212859
// set the base height to be 1. If tile height is bigger than 1,

0 commit comments

Comments
 (0)