8
8
#include < cuda_bf16.h>
9
9
#include < cuda_fp8.h>
10
10
#include < mma.h>
11
- #include < torch/types.h>
12
- #include < torch/extension.h>
13
11
using namespace nvcuda ;
14
12
15
13
#define WARP_SIZE 32
@@ -251,8 +249,8 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
251
249
int lane_smem_b_k = ((lane_id / 8 ) % 2 ) * 8 ; // 0,8
252
250
uint32_t lane_smem_b_ptr = (
253
251
smem_b_base_ptr + (stage_sel * s_b_stage_offset +
254
- lane_smem_b_n * (BK + B_PAD) +
255
- lane_smem_b_k) * sizeof (half)
252
+ lane_smem_b_n * (BK + B_PAD) +
253
+ lane_smem_b_k) * sizeof (half)
256
254
);
257
255
LDMATRIX_X2 (RB[j][0 ], RB[j][1 ], lane_smem_b_ptr);
258
256
}
@@ -309,7 +307,144 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
309
307
}
310
308
}
311
309
310
// build cpp binary
#ifndef NO_MMA_HGEMM_BIN

#include "utils.h"

// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem, TN
//
// Launch helper for the m16n8k16 TN kernel with (stages)-deep pipelining and
// block swizzling along N with period (stride).
//
// Expects the following names in the enclosing scope (token-pasted into the
// expansion): MMA_M/N/K, MMA_TILE_M/N, WARP_TILE_M/N, A_PAD, B_PAD,
// NUM_THREADS, BM, BN, BK, and the launch arguments a, b, c, M, N, K.
//
// smem_max_size = (stages) tiles of A (BM x (BK+A_PAD)) plus (stages) tiles of
// B (BN x (BK+B_PAD)), in halves. cudaFuncSetAttribute opts the kernel in to
// up to 98304 bytes (96 KB) of dynamic shared memory, which is required
// whenever the request exceeds the 48 KB default (see the s2..s5 size table at
// the call site). NOTE(review): the return values of cudaFuncSetAttribute and
// the launch are not checked; failures surface only at a later sync.
//
// Grid: x = ceil(N/BN) block-columns folded into N_SWIZZLE groups, y =
// ceil(M/BM), z = N_SWIZZLE — presumably a swizzle for L2 locality; confirm
// against the kernel's blockIdx decoding.
#define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(stages, stride) \
{                                                                                 \
  const int smem_max_size = (                                                     \
    (stages) * BM * (BK + A_PAD) * sizeof(half) +                                 \
    (stages) * BN * (BK + B_PAD) * sizeof(half));                                 \
  cudaFuncSetAttribute(                                                           \
    hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel<                     \
      MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N,                                \
      WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>,                    \
    cudaFuncAttributeMaxDynamicSharedMemorySize,                                  \
    98304);                                                                       \
  const int N_SWIZZLE = (N + (stride) - 1) / (stride);                            \
  dim3 block(NUM_THREADS);                                                        \
  dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE,                        \
            div_ceil(M, BM),                                                      \
            N_SWIZZLE);                                                           \
  hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel<                       \
    MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N,                                  \
    WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<<                    \
    grid, block, smem_max_size>>>(                                                \
    a, b, c,                                                                      \
    M, N, K                                                                       \
  );                                                                              \
}
340
+
341
// Host-side launcher: fixes the tile configuration for the 128x128 TN HGEMM
// (half in, half out) and dispatches via the LAUNCH_..._TN_KERNEL macro.
// a, b, c are device pointers; M, N, K are the GEMM dimensions.
//
// NOTE(review): "lanunch" is a typo for "launch", but the name is referenced
// by callers (e.g. main below), so renaming is a cross-file change.
// NOTE: the local names below are consumed by the macro via token pasting —
// do not rename them independently of the macro body.
template <const int K_STAGE = 2, const int BLOCK_SWIZZLE_STRIDE = 2048>
void lanunch_hgemm_mma_m16n8k16_tn(
  half* a, half* b, half* c, int M, int N, int K) {
  constexpr int MMA_M = 16;
  constexpr int MMA_N = 8;
  constexpr int MMA_K = 16;
  constexpr int MMA_TILE_M = 2;   // 2x4 MMA tiles per warp group
  constexpr int MMA_TILE_N = 4;
  constexpr int WARP_TILE_M = 4;  // 4x4 MMA ops per warp
  constexpr int WARP_TILE_N = 4;
  constexpr int A_PAD = 0;        // smem row padding for A (0,8,16)
  constexpr int B_PAD = 0;        // smem row padding for B (0,8,16)
  constexpr int NUM_THREADS= (
    MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
  constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;  // 128
  constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N;  // 128
  constexpr int BK = MMA_K;                             // 16
  // Dynamic shared-memory budget per stage count (A tile + B tile):
  // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~35KB
  // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB
  // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB
  // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB
  LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(
    K_STAGE, BLOCK_SWIZZLE_STRIDE);
}
365
+
366
+ #ifdef HGEMM_MMA_DEBUG
367
+ #include < iostream>
368
+ #endif
369
+
370
+
371
// Standalone benchmark / correctness driver for the TN HGEMM kernel.
// Debug build (HGEMM_MMA_DEBUG): a single problem size, overridable from the
// command line as: prog [M] [N] [K] [warmup] [inner_repeat].
// Release build: 64 square sizes (256..16384, step 256), with an error check
// against a reference GEMM on the first five sizes.
int main(int argc, char *argv[]) {
#ifdef HGEMM_MMA_DEBUG
  const int test_num = 1;
#else
  const int test_num = 64;
#endif
  int M_list[test_num];
  int N_list[test_num];
  int K_list[test_num];

  // Square problems: M = N = K = 256 * (i + 1).
  for (int i = 0; i < test_num; ++i) {
    const int dim = (i + 1) * 256;
    M_list[i] = dim;
    N_list[i] = dim;
    K_list[i] = dim;
  }

#ifdef HGEMM_MMA_DEBUG
  // Optional per-dimension overrides from argv.
  if (argc > 1) M_list[0] = std::stoi(argv[1]);
  if (argc > 2) N_list[0] = std::stoi(argv[2]);
  if (argc > 3) K_list[0] = std::stoi(argv[3]);
#endif

#ifdef HGEMM_MMA_DEBUG
  int outer_repeat = 1, inner_repeat = 1, warmup = 1;
  if (argc > 4) warmup = std::stoi(argv[4]);
  if (argc > 5) inner_repeat = std::stoi(argv[5]);
#else
  int outer_repeat = 10, inner_repeat = 1, warmup = 1;
#endif

  printf("ALGO = MMA16816 HGEMM TN MMA=2x4 WARP=4x4 STAGES=2 BLOCK SWIZZLE=2048\n");
#ifndef HGEMM_MMA_DEBUG
  // Spot-check correctness on the first five sizes before timing.
  for (int t = 0; t < 5; ++t) {
    const int M = M_list[t], N = N_list[t], K = K_list[t];
    const float max_error = gemm_error_check_tn<half>(
        lanunch_hgemm_mma_m16n8k16_tn<2, 2048>,
        M, N, K);
    printf("M N K = %6d %6d %6d, ", M, N, K);
    printf("Max Error = %f\n", max_error);
  }
#endif

  // Timing sweep: outer_repeat independent measurements per size; report
  // min / avg / max seconds and average throughput.
  for (int t = 0; t < test_num; ++t) {
    const int M = M_list[t], N = N_list[t], K = K_list[t];

    double max_sec = 0.0;
    double min_sec = DBL_MAX;
    double total_sec = 0.0;

    for (int rep = 0; rep < outer_repeat; ++rep) {
      const double elapsed = perf_gemm<half>(
          lanunch_hgemm_mma_m16n8k16_tn<2, 2048>,
          M, N, K, inner_repeat, warmup);
      max_sec = max(max_sec, elapsed);
      min_sec = min(min_sec, elapsed);
      total_sec += elapsed;
    }

    // 1 TFLOPS = 10^12 FLOPS
    // ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
    const double avg_sec = total_sec / outer_repeat;
    const double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;

    printf("M N K = %6d %6d %6d, W = %1d, R = %2d ", M, N, K, warmup, inner_repeat);
    printf("Time = %12.8lf %12.8lf %12.8lf s, ", min_sec, avg_sec, max_sec);
    printf("AVG Performance = %10.4lf Tflops\n", avg_Tflops);
  }

  return 0;
}
441
+
442
+
443
+ #else
444
+
312
445
// --------------------- PyTorch bindings for custom kernel -----------------------
#include <torch/types.h>
#include <torch/extension.h>

// Stringify a token (e.g. a function name) for use as a pybind11 op name.
#define STRINGFY(str) #str
// Register `func` on the pybind11 module `m`, using the function's own
// identifier as both the exported name and the docstring.
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));
@@ -398,7 +533,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
398
533
constexpr int WARP_TILE_M = 4 ;
399
534
constexpr int WARP_TILE_N = 4 ;
400
535
constexpr int A_PAD = 0 ; // 0,8,16
401
- constexpr int B_PAD = 0 ; // 0,8,16
536
+ constexpr int B_PAD = 8 ; // 0,8,16
402
537
constexpr int NUM_THREADS= (
403
538
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
404
539
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -446,3 +581,5 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
446
581
}
447
582
}
448
583
}
584
+
585
+ #endif
0 commit comments