

@yichiche (Contributor) commented on Jan 18, 2026

Motivation

The gemm_afp4wfp4 Triton kernel currently assumes the K dimension is aligned such that the FP4 scale groups for A/B can be loaded without masking. When K is not aligned to the expected scale grouping, the kernel may read out-of-bounds scale values, leading to incorrect dequantization (and potentially invalid memory accesses depending on layout).

This PR adds masked scale loads for the non-aligned K case to ensure correctness while preserving the fast path when K is aligned.

Technical Details

What changed
In aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp4wfp4.py, the kernel now does the following (a hedged sketch of the new path follows this list):

  • Computes the total number of scale groups:
    • num_scale_groups = K * 2 // SCALE_GROUP_SIZE (the factor of 2 is because each uint8 packs two FP4 values)
  • Defines a local scale-group offset per K iteration:
    • offs_ks_local = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
  • Adds a correctness path for EVEN_K == False:
    • Computes the scale-group offset already consumed by previous K iterations:
      • current_scale_offset = k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE)
    • Builds a mask from the remaining valid scale-group count to prevent out-of-bounds reads:
      • scale_k_mask = offs_ks_local < (num_scale_groups - current_scale_offset)
    • Uses tl.load(..., mask=scale_k_mask, other=0) for both the A and B scale loads
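
A minimal sketch of the new logic, assuming placeholder names: `a_scales_ptr`/`b_scales_ptr` stand in for pointers already advanced to the current iteration's first scale group, and the real kernel loads 2D scale tiles rather than this 1D fragment, so this is illustrative only, not the actual kernel code:

```python
import triton
import triton.language as tl


@triton.jit
def _scale_load_sketch(
    a_scales_ptr, b_scales_ptr,   # placeholder pointers into the A/B scale tensors
    K, k,                         # packed K extent and current K-loop iteration index
    BLOCK_SIZE_K: tl.constexpr,
    SCALE_GROUP_SIZE: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Per the PR description, K counts packed uint8 elements, each holding
    # two FP4 values, hence the factor of 2.
    num_scale_groups = K * 2 // SCALE_GROUP_SIZE
    # Local scale-group offsets for this K iteration.
    offs_ks_local = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE)

    if EVEN_K:
        # Fast path: K is aligned, so the unmasked loads are safe (unchanged).
        a_scales = tl.load(a_scales_ptr + offs_ks_local)
        b_scales = tl.load(b_scales_ptr + offs_ks_local)
    else:
        # Scale groups consumed by earlier iterations of the K loop.
        current_scale_offset = k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE)
        # Mask off groups that fall past the end of the scale tensors.
        scale_k_mask = offs_ks_local < (num_scale_groups - current_scale_offset)
        a_scales = tl.load(a_scales_ptr + offs_ks_local, mask=scale_k_mask, other=0)
        b_scales = tl.load(b_scales_ptr + offs_ks_local, mask=scale_k_mask, other=0)
    return a_scales, b_scales
```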

Why this is safe / performant

  • The aligned-K case (EVEN_K) is unchanged and keeps its unmasked loads, so no performance regression is expected there.
  • The non-aligned-K case adds masking only to the scale loads; out-of-range scales load as 0, a neutral value, because the corresponding invalid region of the K dimension is already zero-masked by the main data load, so the padded terms contribute nothing to the accumulator (see the toy check below).
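
A toy illustration in plain NumPy (not the kernel code): when the tail of a K block is zero-masked in the data load and the matching scale groups load as 0 via other=0, the padded positions drop out of the dot product exactly.

```python
import numpy as np

# Toy check: the last two K positions model the padded tail of the block.
a_vals  = np.array([1.5, -2.0, 0.0, 0.0])  # main data load zero-masks the tail
a_scale = np.array([0.5,  0.5, 0.0, 0.0])  # scale load uses other=0 for the tail
b_vals  = np.array([4.0,  1.0, 7.0, 9.0])  # B values past K don't matter: A is 0 there

acc = np.dot(a_vals * a_scale, b_vals)
print(acc)  # 2.0 == 0.75*4.0 + (-1.0)*1.0; the padded terms add exactly 0
```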

Test Plan

Re-run the previously failing configuration ("BLOCK_SIZE_K": 512 in aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4-N=7168-K=2304.json) at conc = 64 (IL = 1024 / OL = 1024).

Test Result

With this modification the configuration passes: the segmentation faults previously seen in the last 8 batches no longer occur.


@yichiche changed the base branch from mla_fix_mtp to main on Jan 18, 2026 at 14:15
@yichiche requested a review from a team on Jan 18, 2026 at 14:15
@yichiche marked this pull request as draft on Jan 18, 2026 at 14:52
@yichiche marked this pull request as ready for review on Jan 18, 2026 at 15:40