server: add OpenAI compatible response format for legacy /completions with b… #10645

Open · wants to merge 1 commit into base: master
151 changes: 77 additions & 74 deletions examples/server/server.cpp
@@ -924,7 +924,7 @@ struct server_context {
slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
slot.params.sampling.n_probs = json_value(data, "n_probs", json_value(data, "logprobs", defaults.sampling.n_probs));
slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);

slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
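With this fallback, a request that only sets the OpenAI-style `logprobs` field now populates `n_probs` when `n_probs` itself is absent. A minimal sketch of the intended effect, assuming a llama-server listening on localhost:8080 (exact response fields depend on the rest of this patch):

```python
import requests  # sketch only; assumes a running llama-server on localhost:8080

resp = requests.post("http://localhost:8080/completions", json={
    "prompt": "Hello",
    "max_tokens": 8,
    "logprobs": 3,  # with this change, used as n_probs when n_probs is absent
})
resp.raise_for_status()
print(resp.json())
```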
@@ -1340,7 +1340,8 @@ struct server_context {
}
slot.n_sent_token_probs = probs_stop_pos;

res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
// TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output, true);
}

if (slot.oaicompat) {
@@ -1379,7 +1380,7 @@ struct server_context {
{"timings", slot.get_formated_timings()},
{"index", slot.index},
};

if (slot.params.sampling.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
@@ -1395,7 +1396,8 @@ struct server_context {
slot.generated_token_probs.end());
}

res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
// TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs, true);
}

if (slot.oaicompat) {
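The new boolean on `probs_vector_to_json` selects between the two logprobs layouts the OpenAI API defines for the two endpoints. Roughly, the shapes being distinguished look like this (assumed from the OpenAI spec, not taken from this diff):

```python
# Legacy /completions logprobs layout (per the OpenAI spec; assumed here):
legacy_logprobs = {
    "tokens": ["Hello", ","],
    "token_logprobs": [-0.12, -1.05],
    "top_logprobs": [{"Hello": -0.12, "Hi": -2.3}, {",": -1.05, "!": -1.9}],
    "text_offset": [0, 5],
}

# /chat/completions logprobs layout (one entry per generated token):
chat_logprobs = {
    "content": [
        {"token": "Hello", "logprob": -0.12,
         "top_logprobs": [{"token": "Hello", "logprob": -0.12},
                          {"token": "Hi", "logprob": -2.3}]},
    ],
}
```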
@@ -2901,46 +2903,99 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};

const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
const auto handle_completions_generic = [&ctx_server, &params, &res_error, &res_ok, verbose](server_task_inf_type inf_type, json & data, httplib::Response & res, bool is_chat = false) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

// Parse data for /chat/completions format if needed
if (is_chat) {
data = oaicompat_completion_params_parse(ctx_server.model, data, params.chat_template);
}

std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);

bool stream = json_value(data, "stream", false);
bool oai_compat = json_value(data, "oai_compat", true);
const auto task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();

if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
if (results.size() == 1) {
if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
if (is_chat) {
// multitask is never supported in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
/*.streaming =*/ false, verbose, /*.legacy_format =*/ !is_chat);
res_ok(res, result_oai);
} else {
if (results.size() == 1) {
// single result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
/*.streaming =*/ false, verbose, /*.legacy_format =*/ true);
res_ok(res, result_oai);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & result : results) {
arr.push_back(format_final_response_oaicompat(data, result.data, completion_id,
/*.streaming =*/ false, verbose, /*.legacy_format =*/ true));
}
res_ok(res, arr);
}
}
}
else{
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
res_ok(res, arr);
}
res_ok(res, arr);
}
}, [&](const json & error_data) {
res_error(res, error_data);
});

ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, is_chat, inf_type, oai_compat](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
return server_sent_event(sink, "data", result.data);
if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {

std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, !is_chat);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token
}
if (!server_sent_event(sink, "data", event_data)) {
return false; // connection is closed
}
}
return true; // ok

}
return server_sent_event(sink, "data", result.data);

}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);

});

if (is_chat) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
}
sink.done();
return false;
return true;
};

auto on_complete = [task_ids, &ctx_server] (bool) {
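As refactored, both endpoints share this lambda: `/chat/completions` passes `is_chat = true`, while the legacy `/completions` path now defaults to the OpenAI response shape but can opt out per request. A hedged sketch of both behaviors (the `oai_compat` field name is taken from the diff above; endpoint paths and response keys assumed):

```python
import requests  # sketch only; assumes a running llama-server on localhost:8080

# OpenAI-shaped response (the new default for /completions in this PR)
oai = requests.post("http://localhost:8080/completions", json={
    "prompt": "Hello", "max_tokens": 8,
}).json()
print(oai["choices"][0]["text"])

# Opt back into the raw llama.cpp format via the new flag
raw = requests.post("http://localhost:8080/completions", json={
    "prompt": "Hello", "max_tokens": 8, "oai_compat": False,
}).json()
print(raw["content"])
```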
Expand All @@ -2953,7 +3008,12 @@ int main(int argc, char ** argv) {

const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, false);
};

const auto handle_chat_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
};

const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
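On the streaming side, the shared provider now emits OpenAI-style chunks for both paths, and per the diff the `data: [DONE]` sentinel is only written when `is_chat` is true. A sketch of a client reading the legacy SSE stream under those assumptions:

```python
import json
import requests  # sketch only; assumes a running llama-server on localhost:8080

with requests.post("http://localhost:8080/completions",
                   json={"prompt": "Hello", "max_tokens": 8, "stream": True},
                   stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue  # skip keep-alives and non-data SSE fields
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":  # written for /chat/completions only
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0].get("text", ""), end="", flush=True)
```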
@@ -3006,63 +3066,6 @@ int main(int argc, char ** argv) {
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
};

// TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);

bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();

if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
}, [&](const json & error_data) {
res_error(res, error_data);
});

ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token
}
if (!server_sent_event(sink, "data", event_data)) {
return false; // connection is closed
}
}
return true; // ok
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
sink.done();
return true;
};

auto on_complete = [task_ids, &ctx_server] (bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};

res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};

const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
json models = {
{"object", "list"},
14 changes: 12 additions & 2 deletions examples/server/tests/unit/test_basic.py
@@ -40,9 +40,19 @@ def test_load_split_model():
server.model_alias = "tinyllama-split"
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 16,
"max_tokens": 16,
"prompt": "Hello",
"temperature": 0.0,
})
assert res.status_code == 200
assert match_regex("(little|girl)+", res.body["content"])
# Verify response structure
assert "id" in res.body
assert "object" in res.body
assert "created" in res.body
assert "model" in res.body
assert "choices" in res.body
assert isinstance(res.body["choices"], list)
assert len(res.body["choices"]) > 0
assert "text" in res.body["choices"][0]
# Verify the actual content
assert match_regex("(little|girl)+", res.body["choices"][0]["text"])
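The updated assertions imply a response of roughly this shape (values illustrative; the `object` string is an assumption, not checked by the test):

```python
# Illustrative only — values are made up; the structure mirrors the asserts above.
expected_shape = {
    "id": "chatcmpl-...",
    "object": "text_completion",  # assumed; the test only checks the key exists
    "created": 1700000000,
    "model": "tinyllama-split",
    "choices": [
        {"index": 0, "text": "little girl ...", "finish_reason": "length"},
    ],
}
```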