#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <mma.h>
11- #include < torch/types.h>
12- #include < torch/extension.h>
using namespace nvcuda;

#define WARP_SIZE 32
@@ -251,8 +249,8 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
251249 int lane_smem_b_k = ((lane_id / 8 ) % 2 ) * 8 ; // 0,8
252250 uint32_t lane_smem_b_ptr = (
253251 smem_b_base_ptr + (stage_sel * s_b_stage_offset +
254- lane_smem_b_n * (BK + B_PAD) +
255- lane_smem_b_k) * sizeof (half)
252+ lane_smem_b_n * (BK + B_PAD) +
253+ lane_smem_b_k) * sizeof (half)
256254 );
257255 LDMATRIX_X2 (RB[j][0 ], RB[j][1 ], lane_smem_b_ptr);
258256 }
@@ -309,7 +307,144 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
309307 }
310308}
311309
// build cpp binary
#ifndef NO_MMA_HGEMM_BIN

#include "utils.h"

315+ // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem, TN
316+ #define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL (stages, stride ) \
317+ { \
318+ const int smem_max_size = ( \
319+ (stages) * BM * (BK + A_PAD) * sizeof (half) + \
320+ (stages) * BN * (BK + B_PAD) * sizeof (half)); \
321+ cudaFuncSetAttribute ( \
322+ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
323+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
324+ WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true >, \
325+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
326+ 98304 ); \
327+ const int N_SWIZZLE = (N + (stride) - 1 ) / (stride); \
328+ dim3 block (NUM_THREADS); \
329+ dim3 grid ((div_ceil (N, BN) + N_SWIZZLE - 1 ) / N_SWIZZLE, \
330+ div_ceil (M, BM), \
331+ N_SWIZZLE); \
332+ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
333+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
334+ WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true ><<< \
335+ grid, block, smem_max_size>>> ( \
336+ a, b, c, \
337+ M, N, K \
338+ ); \
339+ }
340+
341+ template <const int K_STAGE = 2 , const int BLOCK_SWIZZLE_STRIDE = 2048 >
342+ void lanunch_hgemm_mma_m16n8k16_tn (
343+ half* a, half* b, half* c, int M, int N, int K) {
344+ constexpr int MMA_M = 16 ;
345+ constexpr int MMA_N = 8 ;
346+ constexpr int MMA_K = 16 ;
347+ constexpr int MMA_TILE_M = 2 ;
348+ constexpr int MMA_TILE_N = 4 ;
349+ constexpr int WARP_TILE_M = 4 ;
350+ constexpr int WARP_TILE_N = 4 ;
351+ constexpr int A_PAD = 0 ;
352+ constexpr int B_PAD = 0 ;
353+ constexpr int NUM_THREADS= (
354+ MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
355+ constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
356+ constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N;
357+ constexpr int BK = MMA_K;
358+ // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~35KB
359+ // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB
360+ // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB
361+ // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB
362+ LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL (
363+ K_STAGE, BLOCK_SWIZZLE_STRIDE);
364+ }
365+
#ifdef HGEMM_MMA_DEBUG
#include <iostream>
#endif

371+ int main (int argc, char *argv[]) {
372+ #ifdef HGEMM_MMA_DEBUG
373+ const int test_num = 1 ;
374+ #else
375+ const int test_num = 64 ;
376+ #endif
377+ int M_list[test_num];
378+ int N_list[test_num];
379+ int K_list[test_num];
380+
381+ for (int i = 0 ; i < test_num; i++) {
382+ M_list[i] = (i + 1 ) * 256 ;
383+ N_list[i] = (i + 1 ) * 256 ;
384+ K_list[i] = (i + 1 ) * 256 ;
385+ }
386+
387+ #ifdef HGEMM_MMA_DEBUG
388+ if (argc > 1 ) M_list[0 ] = std::stoi (argv[1 ]);
389+ if (argc > 2 ) N_list[0 ] = std::stoi (argv[2 ]);
390+ if (argc > 3 ) K_list[0 ] = std::stoi (argv[3 ]);
391+ #endif
392+
393+ #ifdef HGEMM_MMA_DEBUG
394+ int outer_repeat = 1 , inner_repeat = 1 , warmup = 1 ;
395+ if (argc > 4 ) warmup = std::stoi (argv[4 ]);
396+ if (argc > 5 ) inner_repeat = std::stoi (argv[5 ]);
397+ #else
398+ int outer_repeat = 10 , inner_repeat = 1 , warmup = 1 ;
399+ #endif
400+
401+ printf (" ALGO = MMA16816 HGEMM TN MMA=2x4 WARP=4x4 STAGES=2 BLOCK SWIZZLE=2048\n " );
402+ #ifndef HGEMM_MMA_DEBUG
403+ for (int j = 0 ; j < 5 ; j++) {
404+ int M = M_list[j], N = N_list[j], K = K_list[j];
405+ float max_error = gemm_error_check_tn<half>(
406+ lanunch_hgemm_mma_m16n8k16_tn<2 , 2048 >,
407+ M, N, K);
408+ printf (" M N K = %6d %6d %6d, " , M, N, K);
409+ printf (" Max Error = %f\n " , max_error);
410+ }
411+ #endif
412+
413+ for (int j = 0 ; j < test_num; j++) {
414+ int M = M_list[j], N = N_list[j], K = K_list[j];
415+
416+ double max_sec = 0.0 ;
417+ double min_sec = DBL_MAX;
418+ double total_sec = 0.0 ;
419+
420+ for (int k = 0 ; k < outer_repeat; k++) {
421+ double this_sec = perf_gemm<half>(
422+ lanunch_hgemm_mma_m16n8k16_tn<2 , 2048 >,
423+ M, N, K, inner_repeat, warmup);
424+ max_sec = max (max_sec, this_sec);
425+ min_sec = min (min_sec, this_sec);
426+ total_sec += this_sec;
427+ }
428+
429+ // 1 TFLOPS = 10^12 FLOPS
430+ // ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
431+ double avg_sec = total_sec / outer_repeat;
432+ double avg_Tflops = ((double )M) * N * K * 2 * 1e-12 / avg_sec;
433+
434+ printf (" M N K = %6d %6d %6d, W = %1d, R = %2d " , M, N, K, warmup, inner_repeat);
435+ printf (" Time = %12.8lf %12.8lf %12.8lf s, " , min_sec, avg_sec, max_sec);
436+ printf (" AVG Performance = %10.4lf Tflops\n " , avg_Tflops);
437+ }
438+
439+ return 0 ;
440+ }


#else

// --------------------- PyTorch bindings for custom kernel -----------------------
#include <torch/types.h>
#include <torch/extension.h>
313448#define STRINGFY (str ) #str
314449#define TORCH_BINDING_COMMON_EXTENSION (func ) \
315450 m.def(STRINGFY(func), &func, STRINGFY(func));
@@ -398,7 +533,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
398533 constexpr int WARP_TILE_M = 4 ;
399534 constexpr int WARP_TILE_N = 4 ;
400535 constexpr int A_PAD = 0 ; // 0,8,16
401- constexpr int B_PAD = 0 ; // 0,8,16
536+ constexpr int B_PAD = 8 ; // 0,8,16
402537 constexpr int NUM_THREADS= (
403538 MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
404539 constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -446,3 +581,5 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
446581 }
447582 }
448583}

#endif