server : chunked prefill support #10718

Draft · wants to merge 1 commit into base: master
26 changes: 25 additions & 1 deletion examples/server/server.cpp
@@ -2418,6 +2418,14 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
+        // there are currently slots with ongoing text generation
+        const bool is_tg = batch.n_tokens > 0;
+
+        // limit the batch to avoid blocking the processing
+        if (is_tg) {
+            n_batch = 32; // TODO: configurable
+        }
+
         // track if this is an embedding or non-embedding batch
         // if we've added sampled tokens above, we are in non-embedding mode
         // -1: none, 0: non-embedding, 1: embedding
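The hard-coded cap means that while any slot is generating, a long prompt is processed as many small decode calls, so a generating slot waits for at most one 32-token prefill step instead of a full default-sized batch. A rough illustration of the split, with a hypothetical prompt length (the 4096 below is not from the patch):

// Illustration only: how the 32-token cap chops a pending prompt into
// chunks. The prompt length is a made-up example, not from the patch.
#include <cstdio>

int main() {
    const int n_prompt = 4096; // hypothetical pending prompt
    const int n_batch  = 32;   // cap applied while is_tg == true
    const int n_chunks = (n_prompt + n_batch - 1) / n_batch;

    // generation tokens can be scheduled between any two of these chunks
    printf("prefill runs as %d decode calls of at most %d tokens\n", n_chunks, n_batch);
    return 0;
}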
@@ -2426,6 +2434,18 @@
 
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
+            // count how many slots are currently processing a prompt
+            int n_slots_pp = 0;
+            for (auto & slot : slots) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
+                    n_slots_pp++;
+                }
+            }
+
+            // determine the chunk size for chunked prefill
+            // a slot cannot submit more than this many tokens in a single batch while other slots are processing
+            const int32_t n_chunk_pp = std::max(n_slots_pp > 0 ? (n_batch / n_slots_pp) : n_batch, 8);
+
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
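The chunk size simply divides the (possibly capped) batch among the slots that still have prompt tokens, with a floor of 8 so every slot makes some progress per pass. A standalone copy of the chunk-size rule above, evaluated for a few assumed slot counts:

// Standalone sketch of the chunk-size rule; the inputs are assumed.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int32_t chunk_pp(int32_t n_batch, int32_t n_slots_pp) {
    return std::max(n_slots_pp > 0 ? (n_batch / n_slots_pp) : n_batch, 8);
}

int main() {
    printf("%d\n", chunk_pp(  32, 4)); // 8    : capped batch split across 4 prefilling slots
    printf("%d\n", chunk_pp(  32, 8)); // 8    : the floor of 8 keeps each slot moving
    printf("%d\n", chunk_pp(2048, 2)); // 1024 : no generation in flight, full batch split in two
    return 0;
}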
@@ -2609,8 +2629,10 @@ struct server_context {
                     // remove the non-common part from the cache
                     slot.cache_tokens.resize(slot.n_past);
 
+                    int n_cur = 0;
+
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch && n_cur < n_chunk_pp) {
                         common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
 
                         if (slot.params.cache_prompt) {
@@ -2619,6 +2641,8 @@
 
                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
+
+                        n_cur++;
                     }
 
                     SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
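Putting the pieces together: each pass over the slots now takes at most n_chunk_pp tokens from each pending prompt, so concurrent prompts advance in lockstep rather than first-come-first-served. A toy model of the fill loop under the new condition, with hypothetical sizes (two slots, 100-token prompts):

// Toy model of the modified fill loop: two slots with pending prompts
// share each batch instead of the first slot consuming all of it.
// All sizes are hypothetical.
#include <cstdio>

int main() {
    const int n_prompt = 100, n_batch = 32, n_chunk_pp = 8;
    int n_past[2] = {0, 0};

    int n_batches = 0;
    while (n_past[0] < n_prompt || n_past[1] < n_prompt) {
        int n_tokens = 0; // tokens placed in this batch
        for (int s = 0; s < 2; s++) {
            int n_cur = 0;
            while (n_past[s] < n_prompt && n_tokens < n_batch && n_cur < n_chunk_pp) {
                n_past[s]++;
                n_tokens++;
                n_cur++;
            }
        }
        n_batches++;
    }
    // each batch carries 8 tokens per slot, so both 100-token prompts
    // finish together after ceil(100 / 8) = 13 passes
    printf("both prompts done in %d batches\n", n_batches);
    return 0;
}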