
Commit 57d5da3

flaviotruzzi authored and facebook-github-bot committed
Vectorize load/store for get_FP8_qparam_cuda_kernel (#4263)
Summary: Similar to D75563906, which vectorized the _compute_FP8_quantize_cuda_kernel, this diff applies the same vectorization technique to the _get_FP8_qparam_cuda_kernel. This optimization improves memory throughput by using the aligned_vector struct to load 4 elements at a time, which should generate vectorized load instructions on the GPU. The original non-vectorized kernel is preserved as a reference, and a new _get_FP8_qparam_cuda_kernel_vectorized is introduced.

Differential Revision: D75639684
1 parent e10ce8b commit 57d5da3
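
For readers skimming the diff, the sketch below illustrates the vectorized-load pattern the summary describes: an alignas'd wrapper struct so the compiler can emit one 128-bit load per four floats, a scalar tail for leftover columns, and a warp-level max reduction. This is not the FBGEMM kernel; kVecSize, row_absmax_kernel, the one-warp-per-row launch, and the 16-byte input alignment are illustrative assumptions, and the committed kernel shown further down additionally derives its vector width from VectorSizeTraits, packs several rows per block, and writes the FP8 qparams.

// Hedged sketch only -- not the FBGEMM implementation.
#include <cuda_runtime.h>

constexpr int kVecSize = 4; // assumed vector width for float

// Alignment lets the compiler issue a single 16-byte load for 4 floats.
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

// One 32-thread warp per row (blockIdx.x == row); each lane strides over
// 4-element chunks, keeps a running max of |x|, then the warp reduces.
__global__ void row_absmax_kernel(
    const float* __restrict__ input,
    int64_t ncols,
    float* __restrict__ row_max) {
  using vec_t = aligned_vector<float, kVecSize>;
  const float* row = input + static_cast<int64_t>(blockIdx.x) * ncols;
  float local_max = 1e-20f;

  // 1. Vectorized body: load 4 consecutive floats as one aligned chunk.
  const int64_t vec_blocks = ncols / kVecSize;
  for (int64_t v = threadIdx.x; v < vec_blocks; v += blockDim.x) {
    const vec_t chunk = *reinterpret_cast<const vec_t*>(row + v * kVecSize);
#pragma unroll
    for (int i = 0; i < kVecSize; ++i) {
      local_max = fmaxf(local_max, fabsf(chunk.val[i]));
    }
  }
  // 2. Scalar tail for column counts not divisible by 4.
  for (int64_t c = vec_blocks * kVecSize + threadIdx.x; c < ncols;
       c += blockDim.x) {
    local_max = fmaxf(local_max, fabsf(row[c]));
  }
  // 3. Warp-wide max reduction via shuffles (assumes blockDim.x == 32).
  for (int offset = 16; offset > 0; offset >>= 1) {
    local_max =
        fmaxf(local_max, __shfl_xor_sync(0xffffffffu, local_max, offset));
  }
  if (threadIdx.x == 0) {
    row_max[blockIdx.x] = local_max;
  }
}

// Example launch: row_absmax_kernel<<<nrows, 32>>>(input, ncols, row_max);

The key design point mirrored from the diff is that the aligned_vector reinterpret_cast only pays off when the per-row base pointer is suitably aligned and the vectorized loop leaves any misaligned remainder to the scalar tail.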

File tree

1 file changed: +161, -89 lines changed


fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu

Lines changed: 161 additions & 89 deletions
@@ -184,6 +184,88 @@ struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
   scalar_t val[vec_size];
 };
 
+template <typename input_t>
+__global__ inline void _get_FP8_qparam_cuda_kernel_vectorized(
+    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
+    const int64_t nrows,
+    const int64_t ncols,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
+    const bool forward) {
+  // Assert if index is out of bound
+  CUDA_KERNEL_ASSERT(nrows * ncols >= 0);
+  const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
+
+  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const int64_t output_columns = ncols_aligned + 2 * sizeof(float);
+
+  const float max_pos = forward ? 0.9375f : 0.875f;
+
+  // starting values for future reductions
+  constexpr float kEpsilon = 1e-20f;
+  float maximum_element = kEpsilon;
+  // always a power of 2 up to size 32. Multiple rows can share the same warp
+  // when smaller than 32.
+  const auto lane_width = blockDim.x;
+
+  // Using vectorized loads to improve memory throughput
+  static constexpr int vec_size = VectorSizeTraits<input_t>::value;
+  using vec_t = aligned_vector<input_t, vec_size>;
+
+  // March warp-wise through the row, doing thread local reductions
+  if (row < nrows) {
+    const input_t* input_row = &input[row * ncols];
+
+    // 1. Process chunks of 4 elements at a time where possible
+    const int64_t vec_blocks = ncols / vec_size;
+    for (int64_t vec_idx = threadIdx.x; vec_idx < vec_blocks;
+         vec_idx += lane_width) {
+      const int64_t col_idx = vec_idx * vec_size;
+
+      // Don't access beyond valid input data
+      if (col_idx + vec_size - 1 < ncols) {
+        // Load 4 elements at once using vectorized memory access
+        const vec_t* vec_input =
+            reinterpret_cast<const vec_t*>(&input_row[col_idx]);
+
+        // Find max absolute value among the vec_size elements
+#pragma unroll
+        for (int i = 0; i < vec_size; ++i) {
+          maximum_element =
+              fmaxf(maximum_element, fabs(to_float(vec_input->val[i])));
+        }
+      }
+    }
+
+    // 2. Process any remaining elements with scalar operations
+    const int64_t remaining_start = vec_blocks * vec_size;
+    for (int64_t col = remaining_start + threadIdx.x; col < ncols;
+         col += lane_width) {
+      maximum_element = fmaxf(maximum_element, fabs(to_float(input_row[col])));
+    }
+  }
+
+  // Perform warp-wide reduction. All threads in the warp
+  // participate, even if they aren't assigned to a row, since we can't assume
+  // the existence of the `*_sync` warp primitives with support for masking.
+  for (int offset = lane_width >> 1; offset > 0; offset >>= 1) {
+    maximum_element =
+        fmaxf(maximum_element, shfl_xor(maximum_element, offset, lane_width));
+  }
+
+  // only the leading thread in the warp is needed to return the final result
+  // in output. Additionally, threads mapped to non-existent rows do not write
+  // to the output array.
+  if (threadIdx.x != 0 || row >= nrows) {
+    return;
+  }
+  float* const output_row_qparams =
+      reinterpret_cast<float*>(&output[row * output_columns + ncols_aligned]);
+
+  output_row_qparams[0] = max_pos / (kEpsilon + maximum_element);
+  // Initialize it to make the output deterministic for PT2 compliance
+  output_row_qparams[1] = 0.0;
+}
+
 template <typename input_t>
 __global__ inline void _compute_FP8_quantize_cuda_vectorized_kernel(
     const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
@@ -356,111 +438,101 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
   } else {
-    // range_tensor is used to store the range for each embedding row.
-    // We save max_pos/max_val(rowwise) as row scale to quantize
-    // unlike INT8, FP8 does not have zero shift
-    // This will guarantee the numerical match but bring some perf
-    // regression.
-    auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat));
-
-    {
-      // we need a blockDim.x that is a power of 2 no larger than the warp size
-      // of 32
-
-      int blockDim_x = 1;
-      if (ncols > 16) {
-        // max warp size
-        blockDim_x = 32;
-      } else {
-        while (blockDim_x < ncols) {
-          blockDim_x <<= 1;
-        }
+    int blockDim_x = 32; // Max warp size for optimal performance
+    if (ncols <= 16) {
+      // For very small column counts, reduce block size to avoid wasted
+      // threads
+      blockDim_x = 16;
+      while (blockDim_x > ncols && blockDim_x > 1) {
+        blockDim_x >>= 1;
       }
+    }
+    const auto rows_per_block = threads_per_block / blockDim_x;
+
+    int num_sms = 80;
+    int max_blocks_per_sm = 32;
+
+    // Target enough blocks to saturate the GPU while avoiding excessive
+    // overhead
+    const int target_blocks = num_sms * max_blocks_per_sm;
+    const int num_blocks_warp = std::min(
+        static_cast<int>(cuda_calc_xblock_count(nrows, rows_per_block)),
+        target_blocks);
+
+    FBGEMM_DISPATCH_FLOATING_TYPES(
+        input.scalar_type(), "_get_FP8_qparam_cuda_kernel_vectorized", [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+          const auto func_name = "_get_FP8_qparam_cuda_kernel_vectorized";
+#endif
+          _get_FP8_qparam_cuda_kernel_vectorized<scalar_t>
+              <<<num_blocks_warp,
+                 dim3(blockDim_x, rows_per_block),
+                 0,
+                 at::cuda::getCurrentCUDAStream()>>>(
+                  MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
+                  nrows,
+                  ncols,
+                  MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
+                  forward);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
 
-    const auto rows_per_block = threads_per_block / blockDim_x;
-    const auto num_blocks_warp =
-        cuda_calc_xblock_count(nrows, rows_per_block);
+    if ((ncols % VectorSizeTraits<input_t>::value) != 0) {
+      const int blockDim_x =
+          std::min(ncols, static_cast<int64_t>(threads_per_block));
+      dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+      const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
+      const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
+      dim3 gridDim(gridDim_x, gridDim_y);
 
       FBGEMM_DISPATCH_FLOATING_TYPES(
-          input.scalar_type(), "_get_FP8_qparam_cuda_kernel", [&] {
+          input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-            const auto func_name = "_get_FP8_qparam_cuda_kernel";
+            const auto func_name = "_compute_FP8_quantize_cuda_kernel";
 #endif
-            _get_FP8_qparam_cuda_kernel<scalar_t>
-                <<<num_blocks_warp,
-                   dim3(blockDim_x, rows_per_block),
-                   0,
-                   at::cuda::getCurrentCUDAStream()>>>(
+            _compute_FP8_quantize_cuda_kernel<scalar_t>
+                <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
                     MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
                     nrows,
                     ncols,
                     MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
                     forward);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
-    }
+    } else {
+      // Simple thread block configuration for the scalar kernel
+      // Use 256 threads per block with 32 threads in X dimension (for warp
+      // alignment)
+      const int BLOCK_DIM_X = 32; // Keep warp-aligned for best performance
+      const int BLOCK_DIM_Y = 8; // Balance between Y coverage and registers
+      dim3 blockSize(BLOCK_DIM_X, BLOCK_DIM_Y);
+
+      int num_sms = 80;
+      int max_blocks_per_sm = 16;
+      int target_blocks_x = num_sms * max_blocks_per_sm;
+
+      int gridX = std::min(
+          (int)((ncols + blockSize.x - 1) / blockSize.x), target_blocks_x);
+      int gridY = (nrows + blockSize.y - 1) / blockSize.y;
+
+      gridX = std::min(gridX, 65535);
+      gridY = std::min(gridY, 65535);
+      dim3 gridSize(gridX, gridY);
 
-    {
-      if ((ncols % VectorSizeTraits<input_t>::value) != 0) {
-        const int blockDim_x =
-            std::min(ncols, static_cast<int64_t>(threads_per_block));
-        dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
-        const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
-        const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
-        dim3 gridDim(gridDim_x, gridDim_y);
-
-        FBGEMM_DISPATCH_FLOATING_TYPES(
-            input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name = "_compute_FP8_quantize_cuda_kernel";
-#endif
-              _compute_FP8_quantize_cuda_kernel<scalar_t>
-                  <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
-                      nrows,
-                      ncols,
-                      MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
-                      forward);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
-            });
-      } else {
-        // Simple thread block configuration for the scalar kernel
-        // Use 256 threads per block with 32 threads in X dimension (for warp
-        // alignment)
-        const int BLOCK_DIM_X = 32; // Keep warp-aligned for best performance
-        const int BLOCK_DIM_Y = 8; // Balance between Y coverage and registers
-        dim3 blockSize(BLOCK_DIM_X, BLOCK_DIM_Y);
-
-        int num_sms = 80;
-        int max_blocks_per_sm = 16;
-        int target_blocks_x = num_sms * max_blocks_per_sm;
-
-        int gridX = std::min(
-            (int)((ncols + blockSize.x - 1) / blockSize.x), target_blocks_x);
-        int gridY = (nrows + blockSize.y - 1) / blockSize.y;
-
-        gridX = std::min(gridX, 65535);
-        gridY = std::min(gridY, 65535);
-        dim3 gridSize(gridX, gridY);
-
-        FBGEMM_DISPATCH_FLOATING_TYPES(
-            input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
+      FBGEMM_DISPATCH_FLOATING_TYPES(
+          input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name = "_compute_FP8_quantize_cuda_kernel";
+            const auto func_name = "_compute_FP8_quantize_cuda_kernel";
 #endif
-              _compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>
-                  <<<gridSize,
-                     blockSize,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
-                      nrows,
-                      ncols,
-                      MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
-                      forward);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
-            });
-      }
+            _compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>
+                <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+                    MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
+                    nrows,
+                    ncols,
+                    MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
+                    forward);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          });
     }
   }
 

0 commit comments
