Conversation

@fsx950223 commented Jan 14, 2026

Motivation

KV_BLOCK_SIZE=1024 is supported by the cache layout, but the PS (partitioned-softmax) decode path previously assumed smaller KV block sizes and could:

  • Produce incorrect results / NaNs for block_size=1024
  • Hit GPU memory access faults when sliding_window>0
  • Fail to compile the PS reduce kernel for large context_partition_num due to Triton tensor size limits

This PR makes PS decode robust for KV_BLOCK_SIZE=1024 and fixes PS reduction compilation/resource issues.

Technical Details

1) paged_attention_decode_sliding_window: add KV_BLOCK_SIZE=1024 support

  • Allow KV_BLOCK_SIZE in [16, 64, 1024].
  • For KV_BLOCK_SIZE==1024, treat the KV page as 4 tiles of 256 tokens (see the sketch after this list):
    • KV_COMPUTE_BLOCK_SIZE = CONTEXT_PARTITION_SIZE (256)
    • Compute a per-partition page_offset ∈ {0, 256, 512, 768} and apply it to:
      • key/value loads
      • per-token KV scale loads
  • Use the runtime stride_key_block_elem when stepping through KV elements, so addressing matches the actual key cache layout.
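
The tiling boils down to a per-partition offset into the 1024-token page. A minimal sketch in plain Python, assuming hypothetical names (tile_offsets, the stride value); the real kernel expresses this with Triton index arithmetic:

```python
KV_BLOCK_SIZE = 1024
CONTEXT_PARTITION_SIZE = 256  # == KV_COMPUTE_BLOCK_SIZE for 1024-token pages
TILES_PER_PAGE = KV_BLOCK_SIZE // CONTEXT_PARTITION_SIZE  # 4

def tile_offsets(partition_idx: int, stride_key_block_elem: int):
    """Map a context partition to its 256-token tile inside the KV page."""
    page_offset = (partition_idx % TILES_PER_PAGE) * CONTEXT_PARTITION_SIZE
    # key/value loads and per-token KV scale loads all start page_offset
    # tokens into the page; element addressing uses the runtime stride so
    # it matches the actual key cache layout.
    return page_offset, page_offset * stride_key_block_elem

for p in range(4):
    print(tile_offsets(p, stride_key_block_elem=128))
# (0, 0), (256, 32768), (512, 65536), (768, 98304)
```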

2) PS wrapper fixes

  • Correctly set one-shot mode for PS decode:
    • pass ONE_SHOT=(num_splits <= 1) into paged_attention_decode_sliding_window
    • fixes crashes and incorrect behavior when only one split is used.
  • Tune launch parameters for stability and performance (sketched below):
    • KV_BLOCK_SIZE==1024: waves_per_eu=1
    • otherwise: waves_per_eu=4
    • use num_stages=1
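
A hedged sketch of the wrapper-side selection; the helper and its signature are hypothetical, only the parameter values come from this PR:

```python
def ps_launch_params(num_splits: int, kv_block_size: int) -> dict:
    """Hypothetical helper mirroring the PS wrapper's launch-time choices."""
    return dict(
        ONE_SHOT=(num_splits <= 1),  # a single split runs in one-shot mode
        waves_per_eu=1 if kv_block_size == 1024 else 4,
        num_stages=1,
    )
```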

3) PS reduce kernel: avoid Triton numel limit and shared memory overflow

  • paged_attention_decode_ps_reduce_kernel now reduces partitions in chunks (a two-pass reduction) instead of materializing tensors sized by next_power_of_2(context_partition_num); a NumPy sketch of the chunked merge follows this list.
  • Cap the chunk size at 8 partitions:
    • avoids ValueError('numel (...) exceeds triton maximum tensor numel (1048576)')
    • avoids shared-memory overflow for common configs (e.g. qg=64, head=128).
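
The chunked merge is the standard online-softmax combination applied MAX_CHUNK partitions at a time. An illustrative NumPy sketch, with hypothetical names and assuming unnormalized partials; the real kernel does this in Triton:

```python
import numpy as np

MAX_CHUNK = 8  # chunk-size cap from this PR

def reduce_partitions(m, l, acc):
    """m: (P,) per-partition maxima, l: (P,) sums of exp,
    acc: (P, D) unnormalized partial outputs."""
    g_m, g_l, g_acc = -np.inf, 0.0, np.zeros(acc.shape[1])
    for start in range(0, len(m), MAX_CHUNK):
        cm = m[start:start + MAX_CHUNK]
        cl = l[start:start + MAX_CHUNK]
        ca = acc[start:start + MAX_CHUNK]
        new_m = max(g_m, cm.max())     # running max across chunks
        r = np.exp(cm - new_m)         # rescale this chunk's partials
        g_scale = np.exp(g_m - new_m)  # rescale the running result
        g_acc = g_acc * g_scale + (r[:, None] * ca).sum(axis=0)
        g_l = g_l * g_scale + (r * cl).sum()
        g_m = new_m
    return g_acc / g_l

# Every intermediate tensor is sized by MAX_CHUNK, not by
# next_power_of_2(context_partition_num), regardless of context length.
```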

Test Plan

  • op_tests/triton_tests/test_pa_decode_gluon.py:
    • block_size=1024, context_partition_size=256, kv_varlen=True, trans_v=False
    • verify sliding_window=0 and sliding_window=128
    • verify batch_size=1 and batch_size=128
  • Regression sanity check:
    • spot-check the PS path with block_size=16 using the same harness.

Test Result

  • All of the above tests passed locally.

Submission Checklist

Copilot AI left a comment
Pull request overview

This pull request fixes sliding window attention with Multi-Token Processing (MTP) in the paged attention decode implementation, adding support for KV_BLOCK_SIZE=1024 and improving the sliding window causal masking logic.

Changes:

  • Added support for KV_BLOCK_SIZE=1024 in sliding window kernels with appropriate page offset calculations and windowing masks
  • Fixed causal masking for sliding window to correctly handle per-query-position windows (see the sketch after this list)
  • Reorganized kernel code for better performance by moving initialization earlier and consolidating the PS path
  • Reduced MAX_CONTEXT_PARTITION_NUM from 16 to 8 to avoid exceeding shared memory limits
  • Expanded test coverage for sliding window scenarios
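
A minimal sketch of per-query-position sliding-window masking, in plain Python with illustrative names; the kernel evaluates this vectorized over a tile:

```python
def visible(q_pos: int, kv_pos: int, sliding_window: int) -> bool:
    """A KV token is attended iff it is causal and inside the window of
    this specific query position, not one window shared by the tile."""
    causal = kv_pos <= q_pos
    in_window = sliding_window <= 0 or kv_pos > q_pos - sliding_window
    return causal and in_window
```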

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

  • op_tests/triton_tests/test_pa_decode_gluon.py: tightened the diff tolerance from 8e-2 to 5e-2 and expanded test coverage with additional head dimensions, quantization modes, and configurations.
  • aiter/ops/triton/gluon/pa_decode_gluon.py: added KV_BLOCK_SIZE=1024 support with page offset handling, fixed sliding window causal masking, reorganized initialization code, reduced MAX_CONTEXT_PARTITION_NUM to 8, and moved the PS kernel path to the top of the wrapper.
