Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
142 commits
Select commit Hold shift + click to select a range
fa2c2d2
add num_kv_splits_indptr to mla for mtp<=4 case for now
valarLip Jun 26, 2025
15f6155
update
valarLip Jun 27, 2025
8dd5617
update new kernel
valarLip Jul 1, 2025
c871e8d
infrastructures
ruanjm Jul 14, 2025
3750b5f
1st version of split kernel
ruanjm Jul 16, 2025
7ca2598
Fix issues raised by Lingpeng and fix the issue on batch_size
ruanjm Jul 16, 2025
7c5891c
update mla
valarLip Jul 16, 2025
12def78
update mla_stage2
valarLip Jul 18, 2025
5dc5a6d
Merge branch 'main' into mla_splitkv_enhance
valarLip Jul 18, 2025
eae14ae
Merge branch 'main' into mla_splitkv_enhance
valarLip Jul 18, 2025
f244f11
Merge branch 'mla_splitkv_enhance' into jruan/mla_splitkv_enhance_spl…
ruanjm Jul 22, 2025
224f89f
1st draft of v1 split program
ruanjm Jul 22, 2025
ef442fd
add kv_offset
ruanjm Jul 28, 2025
f10235e
mla_splitkv_enhance_split_alg_inte
Zzz9990 Jul 29, 2025
600b5dd
splitkv debug
Zzz9990 Jul 29, 2025
5c58ae8
1st version of reduce kernel
ruanjm Jul 29, 2025
9700bc5
metadata & kernel finish
Zzz9990 Jul 30, 2025
4a86304
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
d49c0cd
add reduce
Zzz9990 Jul 30, 2025
e4bf891
final_lse is optional now.
ruanjm Jul 30, 2025
7bf6aa4
update kernel
Zzz9990 Jul 30, 2025
2411f1f
bug fix
ruanjm Jul 30, 2025
e21600d
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
ffcc113
bug fix 1
ruanjm Jul 30, 2025
07e4ed1
modify reduce api
Zzz9990 Jul 30, 2025
3f2bf25
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
7c877c4
update kernel
Zzz9990 Jul 30, 2025
d10cdab
fix max splits
Zzz9990 Jul 30, 2025
bac5750
bug fix 3
ruanjm Jul 30, 2025
f59a3e6
fix s80 early return
Zzz9990 Jul 30, 2025
1ae58d1
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
5680c26
udpate calculation of partial_indx
ruanjm Jul 30, 2025
fa87c91
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 31, 2025
0dad74c
add per split test
Zzz9990 Jul 31, 2025
a8fa0b1
make lse support by ref
ruanjm Jul 31, 2025
56e964f
test split
Zzz9990 Jul 31, 2025
a76610a
fix redundant calculation of head offset in reduce kernel
ruanjm Jul 31, 2025
4ffd393
add custom test
Zzz9990 Jul 31, 2025
b3747df
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 31, 2025
ba36541
Add support of 128 head size
ruanjm Jul 31, 2025
e5a1b17
update comments
ruanjm Aug 1, 2025
a68879c
1. Let large work be assigned first.
ruanjm Aug 1, 2025
7209c36
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
4494b36
Calculate kv_limit dynamically
ruanjm Aug 4, 2025
09c4ca8
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
1e5e71a
Fix bug about difference in split_kv(bool)
ruanjm Aug 4, 2025
f35cf04
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
f7cf2b9
add test
Zzz9990 Aug 5, 2025
5b91267
fix seed
Zzz9990 Aug 5, 2025
59af206
Add global tolerance 16 in kv seqlen because main kernel cannot handl…
ruanjm Aug 5, 2025
e1b9065
Fix warp=1 error
ruanjm Aug 8, 2025
2adf050
Add redundant mode to make the size of output of metadata be fixed ad…
ruanjm Aug 8, 2025
c0df46b
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 12, 2025
fbff664
fp8 setup
Zzz9990 Aug 12, 2025
1d36311
first version of device metadata
ruanjm Aug 12, 2025
4212a41
Add work_ptrs
ruanjm Aug 12, 2025
818229e
Compatibility to CUDA Graph
ruanjm Aug 13, 2025
704324a
Refactor code. Merge 2 iterations of generate work together.
ruanjm Aug 14, 2025
6be798a
Make sure that each batch of workload can never be splited to more th…
ruanjm Aug 14, 2025
1b0e26f
Adjust metadata. Get 1% perf gain.
ruanjm Aug 14, 2025
36e9b53
Paralize most of metadata kernel
ruanjm Aug 15, 2025
4403c82
add scale
Zzz9990 Aug 18, 2025
fcb36f0
1. Use warp-level bitonic sort to sort batch idx based on their cost …
ruanjm Aug 18, 2025
5dc1eb7
fp8 function pass
Zzz9990 Aug 19, 2025
b46a8e3
Fix issues:
ruanjm Aug 19, 2025
d8d92bc
fp8 ready
Zzz9990 Aug 19, 2025
ead163a
fix
Zzz9990 Aug 19, 2025
7fefc29
Merge remote-tracking branch 'origin/jruan/mla_splitkv_enhance_split_…
Zzz9990 Aug 19, 2025
cc7ffdc
persistent ready
Zzz9990 Aug 19, 2025
5e32d5d
add nv acc test
Zzz9990 Aug 21, 2025
a97fcf8
rename
Zzz9990 Sep 1, 2025
e0c72f8
updata metashape
Zzz9990 Sep 1, 2025
7220b04
update reduce cu num
Zzz9990 Sep 1, 2025
07bf6bb
update optest for mla
Zzz9990 Sep 1, 2025
3a7bd04
fix cu num
Zzz9990 Sep 1, 2025
88c8a0d
Update metadata and reduce kernels.
ruanjm Sep 1, 2025
7f86b0b
rename kernels
Zzz9990 Sep 1, 2025
f8451c1
triton mla ps no-causal ready
Zzz9990 Sep 4, 2025
caaf4b1
triton mla ps ready
Zzz9990 Sep 5, 2025
8bf9c8c
mla ps ready gluon setup
Zzz9990 Sep 5, 2025
980c627
batch 1 ready
Zzz9990 Sep 11, 2025
2e62d0c
fix kwidth
borontion Sep 11, 2025
5c37d7d
Merge pull request #992 from ROCm/zan_triton_mla_ps_fix
Zzz9990 Sep 12, 2025
e79ea19
fix spill
Zzz9990 Sep 12, 2025
6faec4b
fix perf
Zzz9990 Sep 12, 2025
1cbabda
fix vgpr spill
Zzz9990 Sep 15, 2025
eeb0702
move mla gluon
Zzz9990 Sep 15, 2025
649ee9b
prefetch
Zzz9990 Sep 15, 2025
3d02b32
update metrics
Zzz9990 Sep 15, 2025
a6ba090
update gluon mla without ps
Zzz9990 Sep 17, 2025
1936163
fix perf
Zzz9990 Sep 17, 2025
b2d4be7
function ready
Zzz9990 Sep 18, 2025
623523d
updata
Zzz9990 Sep 24, 2025
bc7fba8
fix perf
Zzz9990 Sep 26, 2025
80f0605
save temps
Zzz9990 Oct 13, 2025
5d8e343
update kernelg
Zzz9990 Oct 22, 2025
db2f5a8
fix offset
Zzz9990 Oct 22, 2025
add32f5
input_helper() used in test_mla_decode.py is not working as expected.
ruanjm Oct 23, 2025
48775e9
opt layout to reduce the number of v_perm
jayzlee147 Oct 23, 2025
d0e50cb
Fix regression in bench_mla_decode_without_rope.py
ruanjm Oct 23, 2025
30bb9dd
update prefetch k
Zzz9990 Oct 27, 2025
38c8d65
save temps
Zzz9990 Oct 28, 2025
26f1cfb
temps
Zzz9990 Nov 10, 2025
2e3da58
add paged mla fp8
Zzz9990 Nov 13, 2025
09a6b3d
mi355 setup
Nov 14, 2025
4e79147
functional complete
Zzz9990 Nov 14, 2025
b3b5398
Merge branch 'zan_triton_mla_ps' into zan_triton_mla_ps_mi355
Zzz9990 Nov 14, 2025
0b375a2
update mi355 bf16 mla
Zzz9990 Nov 14, 2025
5ba8fd9
temps
Nov 14, 2025
100da73
add aot to mla
Zzz9990 Nov 18, 2025
3795e5d
Merge branch 'main' into zan_triton_mla_ps_mi355
Zzz9990 Nov 19, 2025
668c839
update aot
Zzz9990 Nov 19, 2025
ab4375d
update
Zzz9990 Nov 19, 2025
36f7ac3
aot ready
Zzz9990 Nov 19, 2025
c20ea5e
remove cpu cdiv
Zzz9990 Nov 19, 2025
89e2f8e
update aot
Zzz9990 Nov 20, 2025
135b0c7
enable aot
Zzz9990 Nov 21, 2025
d1d782e
update aot dispatch
Zzz9990 Nov 21, 2025
58be9ce
update
Zzz9990 Nov 21, 2025
206e194
update mi355 bf16 kernel in aot
Zzz9990 Nov 21, 2025
3e9afe2
update reduce
Zzz9990 Nov 25, 2025
49adc76
update mi300 bf16 gluon mla
Zzz9990 Nov 25, 2025
963d32e
update varlen ut
Zzz9990 Nov 25, 2025
16cf0f6
fix bf16 accuracy
Zzz9990 Nov 25, 2025
c880886
update 355 bf16 kernle
Nov 25, 2025
0a7d6f7
remote gl.cdiv
Zzz9990 Nov 26, 2025
71cb63b
update bf16 kernel on the mi308
Zzz9990 Nov 26, 2025
48555f3
update async mi355
Zzz9990 Nov 26, 2025
028e563
update async ut
Zzz9990 Nov 27, 2025
cfec1cb
update async mla
Zzz9990 Nov 28, 2025
eb79052
test on mi355
Zzz9990 Nov 28, 2025
c3b1861
mi355 debug
Zzz9990 Nov 28, 2025
8342b06
update kernels
Zzz9990 Nov 28, 2025
1b70542
row tile to col tile
Zzz9990 Nov 28, 2025
e2ec063
update mi355 bf16 mla
Zzz9990 Nov 28, 2025
bef7f74
update
Zzz9990 Dec 2, 2025
3496b6e
update
Zzz9990 Dec 26, 2025
5457e28
temp
Jan 5, 2026
4ede5b9
update
Jan 6, 2026
75d8fda
remove mask
Jan 6, 2026
ba06510
Merge branch 'main' into zan_triton_mla_ps_mi355
Zzz9990 Jan 19, 2026
9d3f8b1
Merge branch 'main' into zan_triton_mla_ps_mi355
Zzz9990 Jan 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions aiter/ops/triton/_triton_kernels/mla_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

# Copyright (C) 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size = 1.
"""

# Adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py

from typing import Optional
import functools
Comment on lines +26 to +27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
typing.Optional imported but unused

Suggested change
from typing import Optional
import functools
import functools

import json
Comment on lines +27 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
functools imported but unused

Suggested change
import functools
import json
import json

import triton
Comment on lines +28 to +29
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
json imported but unused

Suggested change
import json
import triton
import triton

import triton.language as tl
import torch
import aiter.ops.triton.utils._triton.arch_info as arch_info
Comment on lines +31 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
torch imported but unused

Suggested change
import torch
import aiter.ops.triton.utils._triton.arch_info as arch_info
import aiter.ops.triton.utils._triton.arch_info as arch_info

from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
Comment on lines +32 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
aiter.ops.triton.utils._triton.arch_info imported but unused

Suggested change
import aiter.ops.triton.utils._triton.arch_info as arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH

from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd
Comment on lines +33 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
aiter.ops.triton.utils.core.AITER_TRITON_CONFIGS_PATH imported but unused

Suggested change
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd
from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd

from aiter import dtypes
Comment on lines +34 to +35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
aiter.ops.triton.utils._triton.pid_preprocessing.remap_xcd imported but unused

Suggested change
from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd
from aiter import dtypes
from aiter import dtypes


Comment on lines +35 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
aiter.dtypes imported but unused

Suggested change
from aiter import dtypes

@triton.jit
def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
Q, # Holds [Q_NOPE; Q_PE], b x h x (d+r)
K_Buffer, # Holds [KV; K_PE], b*s x (c+r)
V_buffer, # Holds [KV], b*s x (c)
sm_scale,
kv_indptr,
kv_indices,
Att_Out, # b x h x NUM_KV_SPLITS x (kv_lora_rank)
Att_Lse, # b x h x NUM_KV_SPLITS x (1)
stride_qb,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_mid_lse_b,
stride_mid_lse_h,
stride_mid_lse_s,
stride_b_block_table,
dummyPointerArg,
kv_lora_rank: tl.constexpr,
qk_rope_head_dim: tl.constexpr,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
batch: tl.constexpr,
logit_cap: tl.constexpr,
max_qo_len: tl.constexpr,
BLOCK_C: tl.constexpr,
BLOCK_R: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_BLOCK_SIZE: tl.constexpr,
):
pass


@triton.jit
def _fwd_grouped_kernel_stage1_n16x2_prefetch_k_paged_64(
Q, # Holds [Q_NOPE; Q_PE], b x h x (d+r)
K_Buffer, # Holds [KV; K_PE], b*s x (c+r)
V_buffer, # Holds [KV], b*s x (c)
sm_scale,
kv_indptr,
kv_indices,
Att_Out, # b x h x NUM_KV_SPLITS x (kv_lora_rank)
Att_Lse, # b x h x NUM_KV_SPLITS x (1)
stride_qb,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_mid_lse_b,
stride_mid_lse_h,
stride_mid_lse_s,
stride_b_block_table,
dummyPointerArg,
kv_lora_rank: tl.constexpr,
qk_rope_head_dim: tl.constexpr,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
batch: tl.constexpr,
logit_cap: tl.constexpr,
max_qo_len: tl.constexpr,
BLOCK_C: tl.constexpr,
BLOCK_R: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_BLOCK_SIZE: tl.constexpr,
):
pass
21 changes: 19 additions & 2 deletions aiter/ops/triton/configs/gfx942-MLA_DECODE_ROPE-DEFAULT.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
{
"fwd_grouped_kernel_stage1_rope_ps": {
"BLOCK_N": 16,
"BLOCK_H": 16,
"num_stages": 1,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 2
},
"fwd_grouped_kernel_stage1_rope": {
"BLOCK_N": 32,
"BLOCK_H": 16,
"num_stages": 1,
"waves_per_eu": 1,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"fwd_grouped_kernel_stage1_rope_fp8": {
"BLOCK_N": 64,
"BLOCK_H": 16,
"num_stages": 1,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
Expand All @@ -13,4 +30,4 @@
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
}
17 changes: 12 additions & 5 deletions aiter/ops/triton/configs/gfx950-MLA_DECODE_ROPE-DEFAULT.json
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
{
"fwd_grouped_kernel_stage1_rope": {
"BLOCK_N": 32,
"BLOCK_N": 64,
"BLOCK_H": 16,
"num_warps": 8,
"num_stages": 3,
"waves_per_eu": 0,
"num_stages": 1,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"fwd_grouped_kernel_stage1_rope_fp8": {
"BLOCK_N": 64,
"BLOCK_H": 8,
"num_stages": 1,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"fwd_kernel_stage2": {
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"child_paths": {"_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.source": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.source", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.ttgir": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.ttgir", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.llir": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.llir", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.amdgcn": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.amdgcn", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.hsaco": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.hsaco", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.json": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.json"}}
Loading
Loading