|
29 | 29 | namespace flashinfer {
|
30 | 30 |
|
31 | 31 | template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
|
32 |
| - bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ, |
33 |
| - typename DTypeKV, typename DTypeO, typename IdType> |
34 |
| -cudaError_t BatchPrefillWithRaggedKVCacheDispatched( |
35 |
| - BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream); |
| 32 | + bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant> |
| 33 | +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT& params, |
| 34 | + cudaStream_t stream); |
36 | 35 |
|
37 | 36 | template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
|
38 |
| - bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ, |
39 |
| - typename DTypeKV, typename DTypeO, typename IdType> |
40 |
| -cudaError_t BatchPrefillWithPagedKVCacheDispatched( |
41 |
| - BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream); |
| 37 | + bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant> |
| 38 | +cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT& params, |
| 39 | + cudaStream_t stream); |
42 | 40 |
|
43 | 41 | } // namespace flashinfer
|
44 | 42 |
|
@@ -110,7 +108,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
|
110 | 108 | using DTypeO = DTypeQ;
|
111 | 109 | using IdType = int32_t;
|
112 | 110 |
|
113 |
| - BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType> params; |
| 111 | + using BatchPrefillRaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; |
| 112 | + BatchPrefillRaggedParams params; |
114 | 113 |
|
115 | 114 | params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
|
116 | 115 | params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
|
@@ -160,7 +159,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
|
160 | 159 | return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
|
161 | 160 | return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
|
162 | 161 | using AttentionVariant =
|
163 |
| - std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>; |
| 162 | + std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillRaggedParams>, |
| 163 | + StandardAttention<BatchPrefillRaggedParams>>; |
164 | 164 | cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
|
165 | 165 | HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
|
166 | 166 | params, stream);
|
@@ -220,7 +220,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
|
220 | 220 | using DTypeO = DTypeQ;
|
221 | 221 | using IdType = int32_t;
|
222 | 222 |
|
223 |
| - BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType> params; |
| 223 | + using BatchPrefillPagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; |
| 224 | + BatchPrefillPagedParams params; |
224 | 225 |
|
225 | 226 | params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
|
226 | 227 | params.k_ptr = static_cast<DTypeKV*>(paged_k_cache.data_ptr());
|
@@ -272,7 +273,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
|
272 | 273 | return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
|
273 | 274 | return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
|
274 | 275 | using AttentionVariant =
|
275 |
| - std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>; |
| 276 | + std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillPagedParams>, |
| 277 | + StandardAttention<BatchPrefillPagedParams>>; |
276 | 278 | cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
|
277 | 279 | HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
|
278 | 280 | params, stream);
|
|
0 commit comments