-
Notifications
You must be signed in to change notification settings - Fork 180
[MLA] triton mla ps mi355 #1884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fa2c2d2
15f6155
8dd5617
c871e8d
3750b5f
7ca2598
7c5891c
12def78
5dc5a6d
eae14ae
f244f11
224f89f
ef442fd
f10235e
600b5dd
5c58ae8
9700bc5
4a86304
d49c0cd
e4bf891
7bf6aa4
2411f1f
e21600d
ffcc113
07e4ed1
3f2bf25
7c877c4
d10cdab
bac5750
f59a3e6
1ae58d1
5680c26
fa87c91
0dad74c
a8fa0b1
56e964f
a76610a
4ffd393
b3747df
ba36541
e5a1b17
a68879c
7209c36
4494b36
09c4ca8
1e5e71a
f35cf04
f7cf2b9
5b91267
59af206
e1b9065
2adf050
c0df46b
fbff664
1d36311
4212a41
818229e
704324a
6be798a
1b0e26f
36e9b53
4403c82
fcb36f0
5dc1eb7
b46a8e3
d8d92bc
ead163a
7fefc29
cc7ffdc
5e32d5d
a97fcf8
e0c72f8
7220b04
07bf6bb
3a7bd04
88c8a0d
7f86b0b
f8451c1
caaf4b1
8bf9c8c
980c627
2e62d0c
5c37d7d
e79ea19
6faec4b
1cbabda
eeb0702
649ee9b
3d02b32
a6ba090
1936163
b2d4be7
623523d
bc7fba8
80f0605
5d8e343
db2f5a8
add32f5
48775e9
d0e50cb
30bb9dd
38c8d65
26f1cfb
2e3da58
09a6b3d
4e79147
b3b5398
0b375a2
5ba8fd9
100da73
3795e5d
668c839
ab4375d
36f7ac3
c20ea5e
89e2f8e
135b0c7
d1d782e
58be9ce
206e194
3e9afe2
49adc76
963d32e
16cf0f6
c880886
0a7d6f7
71cb63b
48555f3
028e563
cfec1cb
eb79052
c3b1861
8342b06
1b70542
e2ec063
bef7f74
3496b6e
5457e28
4ede5b9
75d8fda
ba06510
9d3f8b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,112 @@ | ||||||||
| # SPDX-License-Identifier: MIT | ||||||||
| # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. | ||||||||
|
|
||||||||
| # Copyright (C) 2023-2025 SGLang Team | ||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||
| # you may not use this file except in compliance with the License. | ||||||||
| # You may obtain a copy of the License at | ||||||||
| # | ||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
| # | ||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
| # See the License for the specific language governing permissions and | ||||||||
| # limitations under the License. | ||||||||
| # ============================================================================== | ||||||||
| """ | ||||||||
| Memory-efficient attention for decoding. | ||||||||
| It supports page size = 1. | ||||||||
| """ | ||||||||
|
|
||||||||
| # Adapted from | ||||||||
| # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py | ||||||||
| # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py | ||||||||
|
|
||||||||
| from typing import Optional | ||||||||
| import functools | ||||||||
| import json | ||||||||
|
Comment on lines
+27
to
+28
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||||
| import triton | ||||||||
|
Comment on lines
+28
to
+29
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||||
| import triton.language as tl | ||||||||
| import torch | ||||||||
| import aiter.ops.triton.utils._triton.arch_info as arch_info | ||||||||
|
Comment on lines
+31
to
+32
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||||
| from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH | ||||||||
|
Comment on lines
+32
to
+33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd | ||||||||
|
Comment on lines
+33
to
+34
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| from aiter import dtypes | ||||||||
|
Comment on lines
+34
to
+35
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||||
|
|
||||||||
|
Comment on lines
+35
to
+36
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||||
| @triton.jit | ||||||||
| def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64( | ||||||||
| Q, # Holds [Q_NOPE; Q_PE], b x h x (d+r) | ||||||||
| K_Buffer, # Holds [KV; K_PE], b*s x (c+r) | ||||||||
| V_buffer, # Holds [KV], b*s x (c) | ||||||||
| sm_scale, | ||||||||
| kv_indptr, | ||||||||
| kv_indices, | ||||||||
| Att_Out, # b x h x NUM_KV_SPLITS x (kv_lora_rank) | ||||||||
| Att_Lse, # b x h x NUM_KV_SPLITS x (1) | ||||||||
| stride_qb, | ||||||||
| stride_qh, | ||||||||
| stride_buf_kbs, | ||||||||
| stride_buf_kh, | ||||||||
| stride_mid_ob, | ||||||||
| stride_mid_oh, | ||||||||
| stride_mid_os, | ||||||||
| stride_mid_lse_b, | ||||||||
| stride_mid_lse_h, | ||||||||
| stride_mid_lse_s, | ||||||||
| stride_b_block_table, | ||||||||
| dummyPointerArg, | ||||||||
| kv_lora_rank: tl.constexpr, | ||||||||
| qk_rope_head_dim: tl.constexpr, | ||||||||
| kv_group_num: tl.constexpr, | ||||||||
| q_head_num: tl.constexpr, | ||||||||
| batch: tl.constexpr, | ||||||||
| logit_cap: tl.constexpr, | ||||||||
| max_qo_len: tl.constexpr, | ||||||||
| BLOCK_C: tl.constexpr, | ||||||||
| BLOCK_R: tl.constexpr, | ||||||||
| BLOCK_N: tl.constexpr, | ||||||||
| BLOCK_H: tl.constexpr, | ||||||||
| NUM_KV_SPLITS: tl.constexpr, | ||||||||
| PAGE_BLOCK_SIZE: tl.constexpr, | ||||||||
| ): | ||||||||
| pass | ||||||||
|
|
||||||||
|
|
||||||||
| @triton.jit | ||||||||
| def _fwd_grouped_kernel_stage1_n16x2_prefetch_k_paged_64( | ||||||||
| Q, # Holds [Q_NOPE; Q_PE], b x h x (d+r) | ||||||||
| K_Buffer, # Holds [KV; K_PE], b*s x (c+r) | ||||||||
| V_buffer, # Holds [KV], b*s x (c) | ||||||||
| sm_scale, | ||||||||
| kv_indptr, | ||||||||
| kv_indices, | ||||||||
| Att_Out, # b x h x NUM_KV_SPLITS x (kv_lora_rank) | ||||||||
| Att_Lse, # b x h x NUM_KV_SPLITS x (1) | ||||||||
| stride_qb, | ||||||||
| stride_qh, | ||||||||
| stride_buf_kbs, | ||||||||
| stride_buf_kh, | ||||||||
| stride_mid_ob, | ||||||||
| stride_mid_oh, | ||||||||
| stride_mid_os, | ||||||||
| stride_mid_lse_b, | ||||||||
| stride_mid_lse_h, | ||||||||
| stride_mid_lse_s, | ||||||||
| stride_b_block_table, | ||||||||
| dummyPointerArg, | ||||||||
| kv_lora_rank: tl.constexpr, | ||||||||
| qk_rope_head_dim: tl.constexpr, | ||||||||
| kv_group_num: tl.constexpr, | ||||||||
| q_head_num: tl.constexpr, | ||||||||
| batch: tl.constexpr, | ||||||||
| logit_cap: tl.constexpr, | ||||||||
| max_qo_len: tl.constexpr, | ||||||||
| BLOCK_C: tl.constexpr, | ||||||||
| BLOCK_R: tl.constexpr, | ||||||||
| BLOCK_N: tl.constexpr, | ||||||||
| BLOCK_H: tl.constexpr, | ||||||||
| NUM_KV_SPLITS: tl.constexpr, | ||||||||
| PAGE_BLOCK_SIZE: tl.constexpr, | ||||||||
| ): | ||||||||
| pass | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,24 @@ | ||
| { | ||
| "fwd_grouped_kernel_stage1_rope": { | ||
| "BLOCK_N": 32, | ||
| "BLOCK_N": 64, | ||
| "BLOCK_H": 16, | ||
| "num_warps": 8, | ||
| "num_stages": 3, | ||
| "waves_per_eu": 0, | ||
| "num_stages": 1, | ||
| "waves_per_eu": 1, | ||
| "matrix_instr_nonkdim": 16, | ||
| "kpack": 1 | ||
| }, | ||
| "fwd_grouped_kernel_stage1_rope_fp8": { | ||
| "BLOCK_N": 64, | ||
| "BLOCK_H": 8, | ||
| "num_stages": 1, | ||
| "waves_per_eu": 1, | ||
| "matrix_instr_nonkdim": 16, | ||
| "kpack": 2 | ||
| }, | ||
| "fwd_kernel_stage2": { | ||
| "num_stages": 2, | ||
| "waves_per_eu": 0, | ||
| "matrix_instr_nonkdim": 16, | ||
| "kpack": 1 | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"child_paths": {"_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.source": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.source", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.ttgir": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.ttgir", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.llir": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.llir", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.amdgcn": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.amdgcn", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.hsaco": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.hsaco", "_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.json": "/root/.triton/cache/NO2HNETERT3ZG2JLGKKSSCQC3PYDOIM2KB3CHXEI2AU5VWPWAHZQ/_fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64.json"}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`typing.Optional` imported but unused