Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Fail fast on empty query for BatchPrefillWithPagedKVCacheKernel #377

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1484,22 +1484,28 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
static_assert(sizeof(DTypeOut) == 2);
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));
auto block = cg::this_thread_block();
const uint32_t kv_chunk_size = *kv_chunk_size_ptr;

const uint32_t bx = blockIdx.x, lane_idx = threadIdx.x,
warp_idx = get_warp_idx<num_warps_x, num_warps_z>(), kv_head_idx = blockIdx.z;
if (block_valid_mask && !block_valid_mask[bx]) {
return;
}
const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size;
float alibi_slopes[num_frags_x][2];

const uint32_t request_idx = request_indices[bx], qo_tile_idx = q_tile_indices[bx],
kv_tile_idx = kv_tile_indices[bx];
const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx];

if (qo_len == 0) {
// Fail fast if query is empty. May happen with CUDA graphs.
return;
}

const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size;
float alibi_slopes[num_frags_x][2];

constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16;
const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx],
kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx])
const uint32_t kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx])
? (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] -
1) * paged_kv.page_size +
paged_kv.last_page_len[request_idx]
Expand All @@ -1514,6 +1520,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DTypeIn>();
constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b<DTypeOut>();

auto block = cg::this_thread_block();
extern __shared__ uint8_t smem[];

DTypeQKAccum s_frag[num_frags_x][num_frags_z][8];
Expand Down