Fix: Handle non-aligned K dimension for scale loading in gemm_afp4wfp4 kernel #1864
Motivation
The gemm_afp4wfp4 Triton kernel currently assumes the K dimension is aligned such that the FP4 scale groups for A/B can be loaded without masking. When K is not aligned to the expected scale grouping, the kernel may read out-of-bounds scale values, leading to incorrect dequantization (and potentially invalid memory accesses depending on layout).
This PR adds masked scale loads for the non-aligned K case to ensure correctness while preserving the fast path when K is aligned.
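The masking idea can be sketched in plain Python, without Triton, to show which scale offsets fall out of bounds in the final K-block. This is an illustrative emulation only, not the kernel code from this PR. `SCALE_GROUP_SIZE = 32`, the helper name `masked_scale_load`, and the fill value `other=0` are assumptions made for the example. In a Triton kernel the equivalent would be passing `mask=` and `other=` to `tl.load`.

```python
# Assumption: one scale value covers 32 FP4 elements along K.
SCALE_GROUP_SIZE = 32


def masked_scale_load(scales, k_block, block_size_k, K, other=0):
    """Emulate a masked load of one K-block of scales for a single row.

    `scales` is a flat list holding the scales of one row along K.
    Returns the loaded values and the mask that was applied.
    """
    scales_per_block = block_size_k // SCALE_GROUP_SIZE
    # Number of valid scale entries along K (rounded up).
    num_scales = (K + SCALE_GROUP_SIZE - 1) // SCALE_GROUP_SIZE
    offs = [k_block * scales_per_block + i for i in range(scales_per_block)]
    # Lanes whose offset lies past the end of the row are masked off.
    mask = [o < num_scales for o in offs]
    values = [scales[o] if m else other for o, m in zip(offs, mask)]
    return values, mask


def needs_mask(K, block_size_k):
    """Hypothetical fast-path check: masking is only needed when K is not
    a multiple of the K block size."""
    return K % block_size_k != 0


# K = 2304 with BLOCK_SIZE_K = 512 gives 72 scales per row and 16 per block,
# so the fifth block (index 4) has only 8 in-bounds scales.
row_scales = list(range(72))
last_values, last_mask = masked_scale_load(row_scales, 4, 512, 2304)
```

For the configuration named in the test plan below (K = 2304, `"BLOCK_SIZE_K": 512`), K is 4.5 blocks long, so an unmasked load of the last block would read 8 scale entries past the end of the row. A shape with K divisible by the block size makes `needs_mask` return `False`, which corresponds to the aligned fast path.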
Technical Details
What changed
In aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp4wfp4.py, the kernel now:
Why this is safe / performant
Test Plan
With this modification, the previously failing configuration ("BLOCK_SIZE_K": 512 in aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4-N=7168-K=2304.json) now passes at conc = 64 (IL = 1024 / OL = 1024) without segmentation faults in the last 8 batches.
Test Result
Submission Checklist