## Motivation

`KV_BLOCK_SIZE=1024` is supported by the cache layout, but the PS (partitioned-softmax) decode path previously assumed smaller KV block sizes and could:

- break with `block_size=1024` and `sliding_window > 0`
- fail for large `context_partition_num` due to Triton tensor size limits

This PR makes PS decode robust for `KV_BLOCK_SIZE=1024` and fixes PS reduction compilation/resource issues.

## Technical Details
1) `paged_attention_decode_sliding_window`: add `KV_BLOCK_SIZE=1024` support

- `KV_BLOCK_SIZE` is now accepted in `[16, 64, 1024]`.
- For `KV_BLOCK_SIZE == 1024`, treat the KV page as 4 tiles of 256 tokens: `KV_COMPUTE_BLOCK_SIZE = CONTEXT_PARTITION_SIZE (256)`, with `page_offset ∈ {0, 256, 512, 768}` applied to `stride_key_block_elem` when stepping through KV elements, so addressing matches the actual key cache layout (see the sketch after this list).
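As a rough illustration of the tiling above, here is a small Python sketch of the offset arithmetic, assuming elements within a page are addressed by a single element stride; the helper name and its `block_base_elem` argument are hypothetical, while `KV_COMPUTE_BLOCK_SIZE`, `page_offset`, and `stride_key_block_elem` follow the names used in the PR:

```python
# Hypothetical illustration of the KV_BLOCK_SIZE=1024 tiling: one 1024-token KV
# page is walked as 4 compute tiles of 256 tokens, and the per-tile page_offset
# is folded into the element stride when addressing keys.

KV_BLOCK_SIZE = 1024
KV_COMPUTE_BLOCK_SIZE = 256  # == CONTEXT_PARTITION_SIZE

def key_tile_offsets(block_base_elem: int, stride_key_block_elem: int):
    """Yield (tile_index, page_offset, first_element_offset) for each 256-token
    tile of a 1024-token KV page. `block_base_elem` is the element offset of the
    page in the key cache and is an illustrative name, not the kernel's."""
    num_tiles = KV_BLOCK_SIZE // KV_COMPUTE_BLOCK_SIZE  # 4 tiles per page
    for tile in range(num_tiles):
        page_offset = tile * KV_COMPUTE_BLOCK_SIZE       # 0, 256, 512, 768
        yield tile, page_offset, block_base_elem + page_offset * stride_key_block_elem

if __name__ == "__main__":
    for tile, page_offset, elem in key_tile_offsets(block_base_elem=0, stride_key_block_elem=1):
        print(tile, page_offset, elem)
```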
2) PS wrapper fixes

- Pass `ONE_SHOT=(num_splits <= 1)` into `paged_attention_decode_sliding_window`.
- For `KV_BLOCK_SIZE == 1024`, launch with `waves_per_eu=1` instead of `waves_per_eu=4`, and `num_stages=1` (a launch-option sketch follows this list).
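A minimal sketch of the wrapper-side selection, assuming the options are collected into a dict before launch; only `ONE_SHOT`, `waves_per_eu`, and `num_stages` come from this PR, and the helper itself is hypothetical:

```python
# Hypothetical sketch of how the PS wrapper could pick kernel launch options.
# ONE_SHOT, waves_per_eu and num_stages are the knobs named in the PR; the
# helper is illustrative, not the actual wrapper code.

def ps_decode_launch_options(num_splits: int, kv_block_size: int) -> dict:
    one_shot = num_splits <= 1  # passed into paged_attention_decode_sliding_window
    if kv_block_size == 1024:
        # Larger KV pages need more registers/LDS per wave, so back off occupancy.
        waves_per_eu, num_stages = 1, 1
    else:
        # Assumption: smaller pages keep the previous waves_per_eu=4 setting.
        waves_per_eu, num_stages = 4, 1
    return {"ONE_SHOT": one_shot, "waves_per_eu": waves_per_eu, "num_stages": num_stages}
```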
3) PS reduce kernel: avoid Triton `numel` limit and shared memory overflow

- `paged_attention_decode_ps_reduce_kernel` now reduces partitions in chunks (two-pass reduction) instead of materializing tensors sized by `next_power_of_2(context_partition_num)`; the single-pass path is kept for `<= 8` partitions.
- This avoids `ValueError('numel (...) exceeds triton maximum tensor numel (1048576)')` and shared memory overflow on large configs (e.g. `qg=64, head=128`). A sketch of the chunked combine follows this list.
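To make the chunked two-pass reduction concrete, here is a minimal NumPy sketch of the combine step for a single head, assuming each partition produced a running max `m[i]`, an exp-sum `l[i]`, and an un-normalized accumulator `acc[i]`; array names and the chunk size of 8 are illustrative, and the actual kernel operates on Triton tensors rather than NumPy arrays:

```python
# Illustrative chunked (two-pass) combine of PS partial results.
# m: (P,) running max of logits per partition
# l: (P,) sum of exp(logits - m[i]) per partition
# acc: (P, head_dim) un-normalized output per partition
import numpy as np

def reduce_partitions_chunked(m, l, acc, chunk=8):
    """Combine per-partition softmax partials chunk by chunk, so no intermediate
    tensor scales with next_power_of_2(num_partitions)."""
    g_m = -np.inf                    # global running max
    g_l = 0.0                        # global running sum of exponentials
    g_acc = np.zeros_like(acc[0])    # global un-normalized output
    for s in range(0, len(m), chunk):
        c_m, c_l, c_acc = m[s:s + chunk], l[s:s + chunk], acc[s:s + chunk]
        new_m = max(g_m, c_m.max())
        # Rescale the accumulated state and the current chunk to the new max.
        g_scale = np.exp(g_m - new_m) if np.isfinite(g_m) else 0.0
        c_scale = np.exp(c_m - new_m)                       # shape (chunk,)
        g_l = g_l * g_scale + (c_l * c_scale).sum()
        g_acc = g_acc * g_scale + (c_acc * c_scale[:, None]).sum(axis=0)
        g_m = new_m
    return g_acc / g_l               # final softmax-normalized output
```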
## Test Plan

`op_tests/triton_tests/test_pa_decode_gluon.py`:

- `block_size=1024`, `context_partition_size=256`, `kv_varlen=True`, `trans_v=False`
- `sliding_window=0` and `sliding_window=128`
- `batch_size=1` and `batch_size=128`
- `block_size=16` using the same harness (a hypothetical parametrization is sketched below)
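For reference, a hypothetical pytest-style parametrization covering the combinations above; the real harness in `op_tests/triton_tests/test_pa_decode_gluon.py` may expose these knobs differently:

```python
# Hypothetical sweep over the test-plan configurations; parameter names match
# the test plan, but the fixture/harness wiring is illustrative only.
import itertools
import pytest

CASES = list(itertools.product(
    [1024],     # block_size (block_size=16 is run separately as a regression check)
    [256],      # context_partition_size
    [True],     # kv_varlen
    [False],    # trans_v
    [0, 128],   # sliding_window
    [1, 128],   # batch_size
))

@pytest.mark.parametrize(
    "block_size,context_partition_size,kv_varlen,trans_v,sliding_window,batch_size", CASES)
def test_pa_decode_ps_block_1024(block_size, context_partition_size, kv_varlen,
                                 trans_v, sliding_window, batch_size):
    ...  # call into the gluon PA decode test harness with these settings
```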
## Test Result

## Submission Checklist