
Commit 8a91a19

[HGEMM] HGEMM TN A&B SMEM Swizzle✔️ (#199)
* refactor hgemm mma
1 parent 6c811c9 commit 8a91a19

14 files changed: +1437 −97 lines changed
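For orientation, "TN" in this commit refers to the HGEMM variant whose B operand is stored K-major, i.e. B is an N×K row-major matrix and the kernel computes C = A·Bᵀ; this reading follows the B-tile indexing (`lane_smem_b_n * (BK + B_PAD) + lane_smem_b_k`) visible in the hgemm_mma_stage_tn.cu diff below. A minimal host-side reference of that layout, illustrative only and not code from this commit:

```cpp
// fp32 reference for the TN layout assumed above:
// A is M x K (row-major), B is N x K (row-major, i.e. K-major per output column),
// C is M x N (row-major): C[m][n] = sum_k A[m][k] * B[n][k].
#include <cstdio>
#include <vector>

static void gemm_tn_reference(const std::vector<float>& A,
                              const std::vector<float>& B,
                              std::vector<float>& C,
                              int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[n * K + k];  // both operands are read contiguously in K
      }
      C[m * N + n] = acc;
    }
  }
}

int main() {
  const int M = 4, N = 4, K = 8;
  std::vector<float> A(M * K, 1.0f), B(N * K, 2.0f), C(M * N, 0.0f);
  gemm_tn_reference(A, B, C, M, N, K);
  std::printf("C[0][0] = %.1f\n", C[0]);  // 8 * (1 * 2) = 16.0
  return 0;
}
```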

Diff for: README.md (+6 −5)

@@ -306,11 +306,12 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | ✔️ [hgemm_wmma_m32n8k16....dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...stages*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...swizzle*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_m16n8k16...naive*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_m16n8k16...mma2x4*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_m16n8k16...stages*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_m16n8k16...swizzle*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_m16n8k16...swizzle{smem}*](./kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...naive*](./kernels/hgemm/mma/basic/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...mma2x4*](./kernels/hgemm/mma/basic/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...stages*](./kernels/hgemm/mma/basic/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...swizzle*](./kernels/hgemm/mma/basic/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...swizzle{smem}*](./kernels/hgemm/mma/swizzle/hgemm_mma_stage_swizzle.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...swizzle{tn}{smem}*](./kernels/hgemm/mma/swizzle/hgemm_mma_stage_tn_swizzle.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_stages_swizzle{smem}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_cublas*](./kernels/hgemm/cublas/hgemm_cublas.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|
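The new `swizzle{smem}` / `swizzle{tn}{smem}` rows refer to shared-memory swizzling: instead of padding the smem tiles (the `A_PAD`/`B_PAD` constants seen in the diffs below), the column index is permuted, typically by XOR-ing it with bits of the row index, so that the rows touched by one `ldmatrix` land in distinct bank groups. The actual permutation lives in hgemm_mma_stage_swizzle.cu / hgemm_mma_stage_tn_swizzle.cu and is not part of this diff excerpt; the sketch below only illustrates the generic XOR trick and should not be read as the commit's exact formula.

```cpp
// Illustrative XOR-based shared-memory swizzle (not the exact mapping used by
// the kernels in this commit). A tile row is split into 8 vectors of 8 halves
// (16 bytes each); permuting the vector index per row spreads the 8 rows of an
// ldmatrix fragment across different bank groups without any padding columns.
#include <cstdio>

constexpr int kVecsPerRow = 8;  // assumption: 8 x 16B = 128B per smem row

inline int swizzled_vec(int row, int vec) {
  return vec ^ (row % kVecsPerRow);  // per-row permutation of the vector index
}

int main() {
  // Each logical vector column maps to a different physical column in every row,
  // so a column-wise ldmatrix access no longer hits a single bank group.
  for (int row = 0; row < 8; ++row) {
    for (int vec = 0; vec < kVecsPerRow; ++vec) {
      std::printf("%d ", swizzled_vec(row, vec));
    }
    std::printf("\n");
  }
  return 0;
}
```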

Diff for: kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu (+3 −3)

@@ -24,7 +24,7 @@ template <
   typename S2GCopyC,
   const bool BlockSwizzle>
 __global__ void hgemm_mma_stages_block_swizzle_tn_cute_kernel(
-    const T *Aptr, const T *Bptr, T *Dptr, int m, int n, int k) {
+    T *Aptr, T *Bptr, T *Dptr, int m, int n, int k) {
   using namespace cute;
   // Initilize shared memory
   extern __shared__ T shm_data[];
@@ -206,8 +206,8 @@ __global__ void hgemm_mma_stages_block_swizzle_tn_cute_kernel(

 // For torch binding, need dynamic block swizzle stride
 template <typename T, const int Stages = 2, const bool BlockSwizzle = false>
-void launch_hgemm_mma_stages_block_swizzle_tn_cute(const T *a,
-                                                   const T *b,
+void launch_hgemm_mma_stages_block_swizzle_tn_cute(T *a,
+                                                   T *b,
                                                    T *c,
                                                    int M,
                                                    int N,

Diff for: kernels/hgemm/makefile (+58 −6)

@@ -1,24 +1,76 @@
 INCLUDE_DIRS=-I ./utils -I ../../third-party/cutlass/include -I ../../third-party/cutlass/tools/util/include
 ARCHS=-gencode arch=compute_80,code=sm_80 -gencode arch=compute_89,code=sm_89
+ARCHS_80=-gencode arch=compute_80,code=sm_80
 ARCHS_89=-gencode arch=compute_89,code=sm_89
 DEFAULT_FLAGS=-O2 $(ARCHS) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
 DEFAULT_FLAGS_89=-O2 $(ARCHS_89) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
+DEFAULT_FLAGS_80=-O2 $(ARCHS_80) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
+
+# Default
 default:
 	nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.bin $(DEFAULT_FLAGS)
 	nvcc cublas/hgemm_cublas.cu -o hgemm_cublas.bin $(DEFAULT_FLAGS)
-	nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.bin $(DEFAULT_FLAGS)
-	nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.bin $(DEFAULT_FLAGS)
+	nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.bin $(DEFAULT_FLAGS)
+	nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_stage_tn.bin $(DEFAULT_FLAGS)
+	nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.bin $(DEFAULT_FLAGS)
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.bin $(DEFAULT_FLAGS)
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.bin $(DEFAULT_FLAGS)
+
+# SM 89
 cute_89:
 	nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.bin $(DEFAULT_FLAGS_89)
 cute_89_debug:
 	nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.debug.bin $(DEFAULT_FLAGS_89) -DCUTE_HGEMM_DEBUG -Xcompiler "-Wno-format"
+# SM 89 NN debug
 mma_89:
-	nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.89.bin $(DEFAULT_FLAGS_89)
+	nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.89.bin $(DEFAULT_FLAGS_89)
 mma_89_debug:
-	nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
+	nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
 mma_89_swizzle:
-	nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.bin $(DEFAULT_FLAGS_89)
+	nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.bin $(DEFAULT_FLAGS_89)
 mma_89_swizzle_debug:
-	nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
+	nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
+# SM 89 TN debug
+mma_tn_89:
+	nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_tn_stage.89.bin $(DEFAULT_FLAGS_89)
+mma_tn_89_debug:
+	nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_tn_stage.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
+mma_tn_89_swizzle:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.89.bin $(DEFAULT_FLAGS_89)
+mma_tn_89_swizzle_debug:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
+mma_tn_89_swizzle_x4:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.89.bin $(DEFAULT_FLAGS_89)
+mma_tn_89_swizzle_x4_debug:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
+
+# SM 80
+cute_80:
+	nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.80.bin $(DEFAULT_FLAGS_80)
+cute_80_debug:
+	nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.80.debug.bin $(DEFAULT_FLAGS_80) -DCUTE_HGEMM_DEBUG -Xcompiler "-Wno-format"
+# SM 80 NN debug
+mma_80:
+	nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.80.bin $(DEFAULT_FLAGS_80)
+mma_80_debug:
+	nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG
+mma_80_swizzle:
+	nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.80.bin $(DEFAULT_FLAGS_80)
+mma_80_swizzle_debug:
+	nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG
+# SM 80 TN debug
+mma_tn_80:
+	nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_tn_stage.80.bin $(DEFAULT_FLAGS_80)
+mma_tn_80_debug:
+	nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_tn_stage.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG
+mma_tn_80_swizzle:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.80.bin $(DEFAULT_FLAGS_80)
+mma_tn_80_swizzle_debug:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG
+mma_tn_80_swizzle_x4:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.80.bin $(DEFAULT_FLAGS_80)
+mma_tn_80_swizzle_x4_debug:
+	nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG
+
 clean:
 	rm -rf *.bin
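Usage note (assuming nvcc, an sm_89 GPU, and the CUTLASS submodule are available): `make mma_tn_89` compiles mma/basic/hgemm_mma_stage_tn.cu into hgemm_mma_tn_stage.89.bin, whose standalone benchmark main() appears in the diff further down; the `*_debug` targets additionally define HGEMM_MMA_DEBUG, which that main() uses to take M, N, K (and warmup/repeat counts) from the command line instead of sweeping the built-in size list.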

Diff for: kernels/hgemm/mma/basic/.gitignore (new file, +32)

@@ -0,0 +1,32 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+__pycache__
+*.onnx
+*.engine
+*.pt
+*.pth
+*.nsys*
+*.ncu*
+*.sqlite*
+*.engine
+*.bin
+*.out
+*bin
+bin
+output
+*.egg-info
+*.whl
+dist
+*.pdf
+*.tex
+*.log
+*.md5
+*.aux*
+*.dpth
File renamed without changes.

Diff for: kernels/hgemm/mma/hgemm_mma_stage.cu renamed to kernels/hgemm/mma/basic/hgemm_mma_stage.cu (+7 −7)

@@ -1965,7 +1965,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr_kernel(
 // 128x128, mma2x4, warp4x4x2(64,32,32), stages, block swizzle, dsmem, reg double buffers
 template <const int K_STAGE = 2, const int BLOCK_SWIZZLE_STRIDE = 2048>
 void lanunch_hgemm_mma_m16n8k16_nn(
-  const half* a, const half* b, half* c, int M, int N, int K) {
+  half* a, half* b, half* c, int M, int N, int K) {
   constexpr int MMA_M = 16;
   constexpr int MMA_N = 8;
   constexpr int MMA_K = 16;
@@ -2167,9 +2167,9 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(
     case 4: // ~34KB
       LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(4, swizzle_stride);
       break;
-    case 5: // ~43KB
-      LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(5, swizzle_stride);
-      break;
+    // case 5: // ~43KB
+    //   LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(5, swizzle_stride);
+    //   break;
     default:
       LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(2, swizzle_stride);
       break;
@@ -2186,9 +2186,9 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(
     case 4:
       LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(4);
       break;
-    case 5:
-      LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(5);
-      break;
+    // case 5:
+    //   LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(5);
+    //   break;
     default:
       LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(2);
       break;

Diff for: kernels/hgemm/mma/hgemm_mma_stage_tn.cu renamed to kernels/hgemm/mma/basic/hgemm_mma_stage_tn.cu (+142 −5)

@@ -8,8 +8,6 @@
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <mma.h>
-#include <torch/types.h>
-#include <torch/extension.h>
 using namespace nvcuda;

 #define WARP_SIZE 32
@@ -251,8 +249,8 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
     int lane_smem_b_k = ((lane_id / 8) % 2) * 8; // 0,8
     uint32_t lane_smem_b_ptr = (
       smem_b_base_ptr + (stage_sel * s_b_stage_offset +
-                        lane_smem_b_n * (BK + B_PAD) +
-                        lane_smem_b_k) * sizeof(half)
+                         lane_smem_b_n * (BK + B_PAD) +
+                         lane_smem_b_k) * sizeof(half)
     );
     LDMATRIX_X2(RB[j][0], RB[j][1], lane_smem_b_ptr);
   }
@@ -309,7 +307,144 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
   }
 }

+// build cpp binary
+#ifndef NO_MMA_HGEMM_BIN
+
+#include "utils.h"
+
+// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem, TN
+#define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(stages, stride) \
+{                                                                 \
+  const int smem_max_size = (                                     \
+    (stages) * BM * (BK + A_PAD) * sizeof(half) +                 \
+    (stages) * BN * (BK + B_PAD) * sizeof(half));                 \
+  cudaFuncSetAttribute(                                           \
+    hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel<     \
+      MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N,                \
+      WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>,    \
+    cudaFuncAttributeMaxDynamicSharedMemorySize,                  \
+    98304);                                                       \
+  const int N_SWIZZLE = (N + (stride) - 1) / (stride);            \
+  dim3 block(NUM_THREADS);                                        \
+  dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE,        \
+            div_ceil(M, BM),                                      \
+            N_SWIZZLE);                                           \
+  hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel<       \
+    MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N,                  \
+    WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<<    \
+    grid, block, smem_max_size>>>(                                \
+    a, b, c,                                                      \
+    M, N, K                                                       \
+  );                                                              \
+}
+
+template <const int K_STAGE = 2, const int BLOCK_SWIZZLE_STRIDE = 2048>
+void lanunch_hgemm_mma_m16n8k16_tn(
+  half* a, half* b, half* c, int M, int N, int K) {
+  constexpr int MMA_M = 16;
+  constexpr int MMA_N = 8;
+  constexpr int MMA_K = 16;
+  constexpr int MMA_TILE_M = 2;
+  constexpr int MMA_TILE_N = 4;
+  constexpr int WARP_TILE_M = 4;
+  constexpr int WARP_TILE_N = 4;
+  constexpr int A_PAD = 0;
+  constexpr int B_PAD = 0;
+  constexpr int NUM_THREADS= (
+    MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
+  constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
+  constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N;
+  constexpr int BK = MMA_K;
+  // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~35KB
+  // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB
+  // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB
+  // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB
+  LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(
+    K_STAGE, BLOCK_SWIZZLE_STRIDE);
+}
+
+#ifdef HGEMM_MMA_DEBUG
+#include <iostream>
+#endif
+
+
+int main(int argc, char *argv[]) {
+#ifdef HGEMM_MMA_DEBUG
+  const int test_num = 1;
+#else
+  const int test_num = 64;
+#endif
+  int M_list[test_num];
+  int N_list[test_num];
+  int K_list[test_num];
+
+  for (int i = 0; i < test_num; i++) {
+    M_list[i] = (i + 1) * 256;
+    N_list[i] = (i + 1) * 256;
+    K_list[i] = (i + 1) * 256;
+  }
+
+#ifdef HGEMM_MMA_DEBUG
+  if (argc > 1) M_list[0] = std::stoi(argv[1]);
+  if (argc > 2) N_list[0] = std::stoi(argv[2]);
+  if (argc > 3) K_list[0] = std::stoi(argv[3]);
+#endif
+
+#ifdef HGEMM_MMA_DEBUG
+  int outer_repeat = 1, inner_repeat = 1, warmup = 1;
+  if (argc > 4) warmup = std::stoi(argv[4]);
+  if (argc > 5) inner_repeat = std::stoi(argv[5]);
+#else
+  int outer_repeat = 10, inner_repeat = 1, warmup = 1;
+#endif
+
+  printf("ALGO = MMA16816 HGEMM TN MMA=2x4 WARP=4x4 STAGES=2 BLOCK SWIZZLE=2048\n");
+#ifndef HGEMM_MMA_DEBUG
+  for (int j = 0; j < 5; j++) {
+    int M = M_list[j], N = N_list[j], K = K_list[j];
+    float max_error = gemm_error_check_tn<half>(
+      lanunch_hgemm_mma_m16n8k16_tn<2, 2048>,
+      M, N, K);
+    printf("M N K = %6d %6d %6d, ", M, N, K);
+    printf("Max Error = %f\n", max_error);
+  }
+#endif
+
+  for (int j = 0; j < test_num; j++) {
+    int M = M_list[j], N = N_list[j], K = K_list[j];
+
+    double max_sec = 0.0;
+    double min_sec = DBL_MAX;
+    double total_sec = 0.0;
+
+    for (int k = 0; k < outer_repeat; k++) {
+      double this_sec = perf_gemm<half>(
+        lanunch_hgemm_mma_m16n8k16_tn<2, 2048>,
+        M, N, K, inner_repeat, warmup);
+      max_sec = max(max_sec, this_sec);
+      min_sec = min(min_sec, this_sec);
+      total_sec += this_sec;
+    }
+
+    // 1 TFLOPS = 10^12 FLOPS
+    // ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
+    double avg_sec = total_sec / outer_repeat;
+    double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
+
+    printf("M N K = %6d %6d %6d, W = %1d, R = %2d ", M, N, K, warmup, inner_repeat);
+    printf("Time = %12.8lf %12.8lf %12.8lf s, ", min_sec, avg_sec, max_sec);
+    printf("AVG Performance = %10.4lf Tflops\n", avg_Tflops);
+  }
+
+  return 0;
+}
+
+
+#else
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
+#include <torch/types.h>
+#include <torch/extension.h>
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));
@@ -398,7 +533,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
   constexpr int WARP_TILE_M = 4;
   constexpr int WARP_TILE_N = 4;
   constexpr int A_PAD = 0; // 0,8,16
-  constexpr int B_PAD = 0; // 0,8,16
+  constexpr int B_PAD = 8; // 0,8,16
   constexpr int NUM_THREADS= (
     MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
   constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -446,3 +581,5 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
   }
 }
 }
+
+#endif
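To make the launcher and benchmark arithmetic above concrete, the snippet below reruns it on the host for an assumed M = N = K = 4096 problem: the 128×128×16 block tile implied by the constexpr values, the stage-2 dynamic smem size computed exactly as in the LAUNCH macro (applied literally with BK = 16 and zero padding this gives 16 KB per block; the s2–s5 size comments in the launcher read as if carried over from a different tile shape), the block-swizzled grid for stride 2048, and the TFLOPS formula from main(). The 1 ms timing is a placeholder, not a measured result.

```cpp
// Host-side rerun of the launch/benchmark arithmetic from hgemm_mma_stage_tn.cu.
// Problem size and timing below are assumptions for illustration only.
#include <cstdio>

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Tile shape from lanunch_hgemm_mma_m16n8k16_tn:
  // BM = MMA_M * MMA_TILE_M * WARP_TILE_M = 16 * 2 * 4 = 128
  // BN = MMA_N * MMA_TILE_N * WARP_TILE_N =  8 * 4 * 4 = 128
  // BK = MMA_K = 16
  const int BM = 128, BN = 128, BK = 16;
  const int A_PAD = 0, B_PAD = 0, stages = 2;
  const int M = 4096, N = 4096, K = 4096;   // assumed problem size
  const int stride = 2048;                  // BLOCK_SWIZZLE_STRIDE

  // Dynamic smem per block, as in the LAUNCH macro (sizeof(half) == 2):
  const int smem_bytes = stages * BM * (BK + A_PAD) * 2
                       + stages * BN * (BK + B_PAD) * 2;  // 8192 + 8192 = 16384 B

  // Block-swizzled grid, as in the LAUNCH macro:
  const int n_swizzle = (N + stride - 1) / stride;                   // 2
  const int grid_x = (div_ceil(N, BN) + n_swizzle - 1) / n_swizzle;  // 16
  const int grid_y = div_ceil(M, BM);                                // 32
  const int grid_z = n_swizzle;                                      // 2

  // TFLOPS formula from main(): 2*M*N*K FLOPs, 1 TFLOPS = 1e12 FLOPS.
  const double elapsed_sec = 1.0e-3;  // placeholder timing, not a measurement
  const double tflops = 2.0 * M * N * K * 1e-12 / elapsed_sec;  // ~137.4

  std::printf("smem/block = %d B, grid = (%d, %d, %d), TFLOPS @ 1 ms = %.1f\n",
              smem_bytes, grid_x, grid_y, grid_z, tflops);
  return 0;
}
```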
