@@ -184,6 +184,88 @@ struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

+template <typename input_t>
+__global__ inline void _get_FP8_qparam_cuda_kernel_vectorized(
+    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
+    const int64_t nrows,
+    const int64_t ncols,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
+    const bool forward) {
+  // Assert if index is out of bound
+  CUDA_KERNEL_ASSERT(nrows * ncols >= 0);
+  const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
+
+  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const int64_t output_columns = ncols_aligned + 2 * sizeof(float);
+
+  const float max_pos = forward ? 0.9375f : 0.875f;
+
+  // starting values for future reductions
+  constexpr float kEpsilon = 1e-20f;
+  float maximum_element = kEpsilon;
+  // always a power of 2 up to size 32. Multiple rows can share the same warp
+  // when smaller than 32.
+  const auto lane_width = blockDim.x;
+
+  // Using vectorized loads to improve memory throughput
+  static constexpr int vec_size = VectorSizeTraits<input_t>::value;
+  using vec_t = aligned_vector<input_t, vec_size>;
+
+  // March warp-wise through the row, doing thread local reductions
+  if (row < nrows) {
+    const input_t* input_row = &input[row * ncols];
+
+    // 1. Process chunks of 4 elements at a time where possible
+    const int64_t vec_blocks = ncols / vec_size;
+    for (int64_t vec_idx = threadIdx.x; vec_idx < vec_blocks;
+         vec_idx += lane_width) {
+      const int64_t col_idx = vec_idx * vec_size;
+
+      // Don't access beyond valid input data
+      if (col_idx + vec_size - 1 < ncols) {
+        // Load 4 elements at once using vectorized memory access
+        const vec_t* vec_input =
+            reinterpret_cast<const vec_t*>(&input_row[col_idx]);
+
+        // Find max absolute value among the vec_size elements
+#pragma unroll
+        for (int i = 0; i < vec_size; ++i) {
+          maximum_element =
+              fmaxf(maximum_element, fabs(to_float(vec_input->val[i])));
+        }
+      }
+    }
+
+    // 2. Process any remaining elements with scalar operations
+    const int64_t remaining_start = vec_blocks * vec_size;
+    for (int64_t col = remaining_start + threadIdx.x; col < ncols;
+         col += lane_width) {
+      maximum_element = fmaxf(maximum_element, fabs(to_float(input_row[col])));
+    }
+  }
+
+  // Perform warp-wide reduction. All threads in the warp
+  // participate, even if they aren't assigned to a row, since we can't assume
+  // the existence of the `*_sync` warp primitives with support for masking.
+  for (int offset = lane_width >> 1; offset > 0; offset >>= 1) {
+    maximum_element =
+        fmaxf(maximum_element, shfl_xor(maximum_element, offset, lane_width));
+  }
+
+  // only the leading thread in the warp is needed to return the final result in
+  // output. Additionally, threads mapped to non-existent rows do not write to
+  // the output array.
+  if (threadIdx.x != 0 || row >= nrows) {
+    return;
+  }
+  float* const output_row_qparams =
+      reinterpret_cast<float*>(&output[row * output_columns + ncols_aligned]);
+
+  output_row_qparams[0] = max_pos / (kEpsilon + maximum_element);
+  // Initialize it to make the output deterministic for PT2 compliance
+  output_row_qparams[1] = 0.0;
+}
+
template <typename input_t>
__global__ inline void _compute_FP8_quantize_cuda_vectorized_kernel(
    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
@@ -356,111 +438,101 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else {
-    // range_tensor is used to store the range for each embedding row.
-    // We save max_pos/max_val(rowwise) as row scale to quantize
-    // unlike INT8, FP8 does not have zero shift
-    // This will guarantee the numerical match but bring some perf
-    // regression.
-    auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat));
-
-    {
-      // we need a blockDim.x that is a power of 2 no larger than the warp size
-      // of 32
-
-      int blockDim_x = 1;
-      if (ncols > 16) {
-        // max warp size
-        blockDim_x = 32;
-      } else {
-        while (blockDim_x < ncols) {
-          blockDim_x <<= 1;
-        }
+    int blockDim_x = 32; // Max warp size for optimal performance
+    if (ncols <= 16) {
+      // For very small column counts, reduce block size to avoid wasted
+      // threads
+      blockDim_x = 16;
+      while (blockDim_x > ncols && blockDim_x > 1) {
+        blockDim_x >>= 1;
      }
+    }
+    const auto rows_per_block = threads_per_block / blockDim_x;
+
+    int num_sms = 80;
+    int max_blocks_per_sm = 32;
+
+    // Target enough blocks to saturate the GPU while avoiding excessive
+    // overhead
+    const int target_blocks = num_sms * max_blocks_per_sm;
+    const int num_blocks_warp = std::min(
+        static_cast<int>(cuda_calc_xblock_count(nrows, rows_per_block)),
+        target_blocks);
+
+    FBGEMM_DISPATCH_FLOATING_TYPES(
+        input.scalar_type(), "_get_FP8_qparam_cuda_kernel_vectorized", [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+          const auto func_name = "_get_FP8_qparam_cuda_kernel_vectorized";
+#endif
+          _get_FP8_qparam_cuda_kernel_vectorized<scalar_t>
+              <<<num_blocks_warp,
+                 dim3(blockDim_x, rows_per_block),
+                 0,
+                 at::cuda::getCurrentCUDAStream()>>>(
+                  MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
+                  nrows,
+                  ncols,
+                  MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
+                  forward);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });

-      const auto rows_per_block = threads_per_block / blockDim_x;
-      const auto num_blocks_warp =
-          cuda_calc_xblock_count(nrows, rows_per_block);
+    if ((ncols % VectorSizeTraits<input_t>::value) != 0) {
+      const int blockDim_x =
+          std::min(ncols, static_cast<int64_t>(threads_per_block));
+      dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+      const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
+      const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
+      dim3 gridDim(gridDim_x, gridDim_y);

      FBGEMM_DISPATCH_FLOATING_TYPES(
-          input.scalar_type(), "_get_FP8_qparam_cuda_kernel", [&] {
+          input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
-            const auto func_name = "_get_FP8_qparam_cuda_kernel";
+            const auto func_name = "_compute_FP8_quantize_cuda_kernel";
#endif
-            _get_FP8_qparam_cuda_kernel<scalar_t>
-                <<<num_blocks_warp,
-                   dim3(blockDim_x, rows_per_block),
-                   0,
-                   at::cuda::getCurrentCUDAStream()>>>(
+            _compute_FP8_quantize_cuda_kernel<scalar_t>
+                <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
                    MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
                    nrows,
                    ncols,
                    MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
                    forward);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
-    }
+    } else {
+      // Simple thread block configuration for the scalar kernel
+      // Use 256 threads per block with 32 threads in X dimension (for warp
+      // alignment)
+      const int BLOCK_DIM_X = 32; // Keep warp-aligned for best performance
+      const int BLOCK_DIM_Y = 8; // Balance between Y coverage and registers
+      dim3 blockSize(BLOCK_DIM_X, BLOCK_DIM_Y);
+
+      int num_sms = 80;
+      int max_blocks_per_sm = 16;
+      int target_blocks_x = num_sms * max_blocks_per_sm;
+
+      int gridX = std::min(
+          (int)((ncols + blockSize.x - 1) / blockSize.x), target_blocks_x);
+      int gridY = (nrows + blockSize.y - 1) / blockSize.y;
+
+      gridX = std::min(gridX, 65535);
+      gridY = std::min(gridY, 65535);
+      dim3 gridSize(gridX, gridY);

-    {
-      if ((ncols % VectorSizeTraits<input_t>::value) != 0) {
-        const int blockDim_x =
-            std::min(ncols, static_cast<int64_t>(threads_per_block));
-        dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
-        const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
-        const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
-        dim3 gridDim(gridDim_x, gridDim_y);
-
-        FBGEMM_DISPATCH_FLOATING_TYPES(
-            input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name = "_compute_FP8_quantize_cuda_kernel";
-#endif
-              _compute_FP8_quantize_cuda_kernel<scalar_t>
-                  <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
-                      nrows,
-                      ncols,
-                      MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
-                      forward);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
-            });
-      } else {
-        // Simple thread block configuration for the scalar kernel
-        // Use 256 threads per block with 32 threads in X dimension (for warp
-        // alignment)
-        const int BLOCK_DIM_X = 32; // Keep warp-aligned for best performance
-        const int BLOCK_DIM_Y = 8; // Balance between Y coverage and registers
-        dim3 blockSize(BLOCK_DIM_X, BLOCK_DIM_Y);
-
-        int num_sms = 80;
-        int max_blocks_per_sm = 16;
-        int target_blocks_x = num_sms * max_blocks_per_sm;
-
-        int gridX = std::min(
-            (int)((ncols + blockSize.x - 1) / blockSize.x), target_blocks_x);
-        int gridY = (nrows + blockSize.y - 1) / blockSize.y;
-
-        gridX = std::min(gridX, 65535);
-        gridY = std::min(gridY, 65535);
-        dim3 gridSize(gridX, gridY);
-
-        FBGEMM_DISPATCH_FLOATING_TYPES(
-            input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
+      FBGEMM_DISPATCH_FLOATING_TYPES(
+          input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name = "_compute_FP8_quantize_cuda_kernel";
+            const auto func_name = "_compute_FP8_quantize_cuda_kernel";
#endif
-              _compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>
-                  <<<gridSize,
-                     blockSize,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
-                      nrows,
-                      ncols,
-                      MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
-                      forward);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
-            });
-      }
+            _compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>
+                <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+                    MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
+                    nrows,
+                    ncols,
+                    MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
+                    forward);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          });
    }
  }
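
For reference, below is a minimal standalone sketch (not part of the diff above) of the pattern the new _get_FP8_qparam_cuda_kernel_vectorized kernel follows: one warp strides over a row with vectorized loads, a warp-wide XOR-shuffle reduction yields the row's absolute maximum, and lane 0 writes scale = max_pos / (kEpsilon + absmax). The kernel name rowwise_absmax_scale_kernel and the raw-pointer interface are hypothetical; the sketch assumes a float input whose rows are 16-byte aligned, ncols divisible by 4, blockDim.x == 32, and it uses __shfl_xor_sync directly instead of the shfl_xor wrapper and PackedTensorAccessor64 used in the PR.

#include <cuda_runtime.h>
#include <math.h>
#include <stdint.h>

__global__ void rowwise_absmax_scale_kernel(
    const float* __restrict__ input, // [nrows * ncols], rows 16-byte aligned
    float* __restrict__ scales,      // [nrows], one scale per row
    int64_t nrows,
    int64_t ncols,   // assumed divisible by 4
    float max_pos) { // e.g. 0.9375f for the forward FP8 rowwise path
  constexpr float kEpsilon = 1e-20f;
  const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
  float maximum_element = kEpsilon;

  if (row < nrows) {
    // Reinterpret the row as float4 chunks so each load moves 16 bytes.
    const float4* row_vec =
        reinterpret_cast<const float4*>(input + row * ncols);
    for (int64_t v = threadIdx.x; v < ncols / 4; v += blockDim.x) {
      const float4 x = row_vec[v];
      maximum_element = fmaxf(maximum_element, fabsf(x.x));
      maximum_element = fmaxf(maximum_element, fabsf(x.y));
      maximum_element = fmaxf(maximum_element, fabsf(x.z));
      maximum_element = fmaxf(maximum_element, fabsf(x.w));
    }
  }

  // Butterfly (XOR) reduction across the 32 lanes of the warp; lanes mapped
  // past the last row participate with kEpsilon so the shuffle stays uniform.
  for (int offset = 16; offset > 0; offset >>= 1) {
    maximum_element = fmaxf(
        maximum_element, __shfl_xor_sync(0xffffffff, maximum_element, offset));
  }

  // Lane 0 of the row's warp writes the quantization scale.
  if (threadIdx.x == 0 && row < nrows) {
    scales[row] = max_pos / (kEpsilon + maximum_element);
  }
}

A matching launch under these assumptions would use, for example, dim3 block(32, 8) with gridDim.x covering ceil(nrows / block.y) row groups, mirroring the one-warp-per-row configuration chosen in the diff.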