
Commit 4ba91c0

Customizable SM90 prefill kernels. (#704)
1 parent 1312409 commit 4ba91c0

15 files changed: +2011 −596 lines

aot_build_utils/generate_batch_paged_prefill_sm90_inst.py

+47 −37
@@ -18,12 +18,7 @@
 import sys
 from pathlib import Path
 
-from .literal_map import (
-    dtype_literal,
-    idtype_literal,
-    mask_mode_literal,
-    pos_encoding_mode_literal,
-)
+from .literal_map import dtype_literal, idtype_literal, mask_mode_literal
 
 
 def get_cu_file_str(
@@ -36,40 +31,56 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
+    pos_encoding_mode = None
+    allow_fp16_qk_reduction = None
+
     def get_insts(attention_variant):
-        return "\n".join(
-            [
-                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
+        return """
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/true,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/true,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/false,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/false,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
 """.format(
-                    head_dim=head_dim,
-                    pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
-                    allow_fp16_qk_reduction=allow_fp16_qk_reduction,
-                    mask_mode=mask_mode_literal[int(mask_mode)],
-                    attention_variant=attention_variant,
-                )
-            ]
+            head_dim=head_dim,
+            mask_mode=mask_mode_literal[int(mask_mode)],
+            attention_variant=attention_variant,
         )
 
     dtype_q = dtype_literal[dtype_q]
     dtype_kv = dtype_literal[dtype_kv]
     dtype_out = dtype_literal[dtype_out]
     idtype = idtype_literal[idtype]
 
-    content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
+    content = f""" // batch_paged_prefill_sm90 template inst
+#include <flashinfer/attention/hopper/params.cuh>
+#include <flashinfer/attention/hopper/prefill_sm90.cuh>
 #include <flashinfer/attention/hopper/variants.cuh>
 #include <flashinfer/cutlass_utils.cuh>
 
@@ -82,9 +93,9 @@ def get_insts(attention_variant):
 
 using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
 
-{get_insts("LogitsSoftCap")}
+{get_insts("LogitsSoftCap<Params>")}
 
-{get_insts("StandardAttention")}
+{get_insts("StandardAttention<Params>")}
 
 }}"""
     return content
@@ -93,12 +104,11 @@ def get_insts(attention_variant):
 if __name__ == "__main__":
     pattern = (
         r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
-        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
+        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_"
+        r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
     )
     compiled_pattern = re.compile(pattern)
     path = Path(sys.argv[1])
     fname = path.name
     match = compiled_pattern.match(fname)
-
-    with open(path, "w") as f:
-        f.write(get_cu_file_str(*match.groups()))
+    path.write_text(get_cu_file_str(*match.groups()))
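
For reference, a minimal sketch of how the filename regex above maps a target filename onto the arguments of get_cu_file_str. The example filename below is hypothetical, chosen only to illustrate the capture groups; real names are produced by the AOT build system.

```python
import re

# Regex taken from the generator above; it captures, in order: head_dim,
# pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, dtype_q, dtype_kv,
# dtype_out, idtype.
pattern = re.compile(
    r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
    r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_"
    r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)

# Hypothetical filename, for illustration only.
fname = (
    "batch_paged_prefill_head_128_posenc_0_fp16qkred_false_mask_1_"
    "dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32_sm90.cu"
)

match = pattern.match(fname)
print(match.groups())
# -> ('128', '0', 'false', '1', 'f16', 'f16', 'f16', 'i32')
# These groups are passed positionally as get_cu_file_str(*match.groups()).
```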

aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py

+33 −24
@@ -38,39 +38,48 @@ def get_cu_file_str(
 ):
 
     def get_insts(attention_variant):
-        return "\n".join(
-            [
-                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
+        return """
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/true,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
+     {attention_variant}>(Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/true,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
+     {attention_variant}>(Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/false,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
+     {attention_variant}>(Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/false,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
+     {attention_variant}>(Params& params, cudaStream_t stream);
 """.format(
             head_dim=head_dim,
-            pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
-            allow_fp16_qk_reduction=allow_fp16_qk_reduction,
             mask_mode=mask_mode_literal[int(mask_mode)],
             attention_variant=attention_variant,
         )
-        ]
-        )
 
     dtype_q = dtype_literal[dtype_q]
     dtype_kv = dtype_literal[dtype_kv]
     dtype_out = dtype_literal[dtype_out]
     idtype = idtype_literal[idtype]
 
-    content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
+    content = f""" // batch_ragged_prefill_sm90 template inst
+#include <flashinfer/attention/hopper/params.cuh>
+#include <flashinfer/attention/hopper/prefill_sm90.cuh>
 #include <flashinfer/attention/hopper/variants.cuh>
 #include <flashinfer/cutlass_utils.cuh>
 
@@ -83,9 +92,9 @@ def get_insts(attention_variant):
 
 using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
 
-{get_insts("LogitsSoftCap")}
+{get_insts("LogitsSoftCap<Params>")}
 
-{get_insts("StandardAttention")}
+{get_insts("StandardAttention<Params>")}
 
 }}
 """

aot_build_utils/generate_single_prefill_sm90_inst.py

+20 −18
@@ -30,7 +30,9 @@ def get_cu_file_str(
     dtype_kv,
     dtype_out,
 ):
-    content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh>
+    content = """ // single_prefill_sm90 template inst
+#include <flashinfer/attention/hopper/params.cuh>
+#include <flashinfer/attention/hopper/prefill_sm90.cuh>
 #include <flashinfer/attention/hopper/variants.cuh>
 #include <flashinfer/cutlass_utils.cuh>
 
@@ -42,31 +44,32 @@ def get_cu_file_str(
 
 using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
 
-template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>(
-    Params& params,
-    cudaStream_t stream);
+template cudaError_t SinglePrefillWithKVCacheDispatched
+    <{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap<Params>>
+    (Params& params, cudaStream_t stream);
 
-template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>(
-    Params& params,
-    cudaStream_t stream);
+template cudaError_t SinglePrefillWithKVCacheDispatched
+    <{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap<Params>>
+    (Params& params, cudaStream_t stream);
 
-template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>(
-    Params& params,
-    cudaStream_t stream);
+template cudaError_t SinglePrefillWithKVCacheDispatched
+    <{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention<Params>>
+    (Params& params, cudaStream_t stream);
+
+template cudaError_t SinglePrefillWithKVCacheDispatched
+    <{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention<Params>>
+    (Params& params, cudaStream_t stream);
 
-template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>(
-    Params& params,
-    cudaStream_t stream);
 }}
 """.format(
         head_dim=head_dim,
-        pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
-        allow_fp16_qk_reduction=allow_fp16_qk_reduction,
+        # pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
+        # allow_fp16_qk_reduction=allow_fp16_qk_reduction,
         mask_mode=mask_mode_literal[int(mask_mode)],
         dtype_q=dtype_literal[dtype_q],
         dtype_kv=dtype_literal[dtype_kv],
        dtype_out=dtype_literal[dtype_out],
-        use_custom_mask="true" if int(mask_mode) == 2 else "false",
+        # use_custom_mask="true" if int(mask_mode) == 2 else "false",
     )
     return content
 
8184
path = Path(sys.argv[1])
8285
fname = path.name
8386
match = compiled_pattern.match(fname)
84-
with open(path, "w") as f:
85-
f.write(get_cu_file_str(*match.groups()))
87+
path.write_text(get_cu_file_str(*match.groups()))

csrc/batch_prefill_sm90.cu

+14 −12
@@ -29,16 +29,14 @@
 namespace flashinfer {
 
 template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
-          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
-          typename DTypeKV, typename DTypeO, typename IdType>
-cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
-    BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant>
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT& params,
+                                                    cudaStream_t stream);
 
 template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
-          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
-          typename DTypeKV, typename DTypeO, typename IdType>
-cudaError_t BatchPrefillWithPagedKVCacheDispatched(
-    BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant>
+cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT& params,
+                                                   cudaStream_t stream);
 
 }  // namespace flashinfer
 
@@ -110,7 +108,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
   using DTypeO = DTypeQ;
   using IdType = int32_t;
 
-  BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType> params;
+  using BatchPrefillRaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
+  BatchPrefillRaggedParams params;
 
   params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
   params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
@@ -160,7 +159,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
     return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
       return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
         using AttentionVariant =
-            std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
+            std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillRaggedParams>,
+                               StandardAttention<BatchPrefillRaggedParams>>;
         cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
             HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
             params, stream);
@@ -220,7 +220,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
   using DTypeO = DTypeQ;
   using IdType = int32_t;
 
-  BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType> params;
+  using BatchPrefillPagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
+  BatchPrefillPagedParams params;
 
   params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
   params.k_ptr = static_cast<DTypeKV*>(paged_k_cache.data_ptr());
@@ -272,7 +273,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
     return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
       return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
         using AttentionVariant =
-            std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
+            std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillPagedParams>,
+                               StandardAttention<BatchPrefillPagedParams>>;
         cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
             HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
             params, stream);

csrc/single_prefill_sm90.cu

+6 −4
@@ -28,8 +28,8 @@
 namespace flashinfer {
 
 template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
-          typename AttentionVariant, typename DTypeQ, typename DTypeKV, typename DTypeO>
-cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>& params,
+          typename AttentionVariant>
+cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT& params,
                                                cudaStream_t stream);
 
 }  // namespace flashinfer
@@ -59,7 +59,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
   using DTypeQ = cutlass_dtype_t<q_type>;
   using DTypeKV = DTypeQ;
   using DTypeO = DTypeQ;
-  SinglePrefillParams<DTypeQ, DTypeKV, DTypeO> params;
+  using SinglePrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
+  SinglePrefillParams params;
   params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
   params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
   params.v_ptr = static_cast<DTypeKV*>(v.data_ptr());
@@ -96,7 +97,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
   return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
     return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
       using AttentionVariant =
-          std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
+          std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<SinglePrefillParams>,
+                             StandardAttention<SinglePrefillParams>>;
       cudaError_t status =
           SinglePrefillWithKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA, AttentionVariant>(
               params, stream);
