
Commit 5eb7358

ngxsonarthw authored and committed
server : bring back info of final chunk in stream mode (ggml-org#10722)
* server : bring back info to final chunk in stream mode
* clarify a bit
* trailing space
1 parent e136690 commit 5eb7358
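
With this change, a streamed completion again ends in a dedicated final chunk: for OAI-compatible chat requests it is a `chat.completion.chunk` with an empty `delta`, a non-null `finish_reason`, and the `usage` counts. A minimal client sketch, assuming a llama.cpp server on `http://localhost:8080` (the address and request fields are illustrative, not taken from this commit):

```python
import json
import requests

# Stream an OAI-compatible chat completion (server address is an assumption).
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
    stream=True,
)

for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue  # skip keep-alives and empty lines
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    choice = chunk["choices"][0]
    if choice["finish_reason"] is not None:
        # the final chunk restored by this commit: empty delta, usage attached
        print("\nfinish_reason:", choice["finish_reason"])
        print("usage:", chunk.get("usage"))
    else:
        print(choice["delta"].get("content", ""), end="")
```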

File tree

2 files changed: +94 -86 lines changed


examples/server/server.cpp (+88 -86)
```diff
@@ -392,7 +392,7 @@ struct server_task_result {
         return false;
     }
     virtual bool is_stop() {
-        // only used by server_task_result_cmpl_partial
+        // only used by server_task_result_cmpl_*
         return false;
     }
     virtual int get_index() {
@@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result {
         return index;
     }
 
+    virtual bool is_stop() override {
+        return true; // in stream mode, final responses are considered stop
+    }
+
     virtual json to_json() override {
-        return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat();
+        return oaicompat
+            ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
+            : to_json_non_oaicompat();
     }
 
     json to_json_non_oaicompat() {
         json res = json {
             {"index", index},
-            {"content", content},
+            {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
             {"id_slot", id_slot},
             {"stop", true},
             {"model", oaicompat_model},
@@ -546,18 +552,46 @@ struct server_task_result_cmpl_final : server_task_result {
 
         return res;
     }
+
+    json to_json_oaicompat_chat_stream() {
+        std::time_t t = std::time(0);
+        std::string finish_reason = "length";
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = "stop";
+        }
+
+        json choices = json::array({json{{"finish_reason", finish_reason},
+                                         {"index", 0},
+                                         {"delta", json::object()}}});
+
+        json ret = json {
+            {"choices", choices},
+            {"created", t},
+            {"id", oaicompat_cmpl_id},
+            {"model", oaicompat_model},
+            {"object", "chat.completion.chunk"},
+            {"usage", json {
+                {"completion_tokens", n_decoded},
+                {"prompt_tokens", n_prompt_tokens},
+                {"total_tokens", n_decoded + n_prompt_tokens},
+            }},
+        };
+
+        if (timings.prompt_n >= 0) {
+            ret.push_back({"timings", timings.to_json()});
+        }
+
+        return ret;
+    }
 };
 
 struct server_task_result_cmpl_partial : server_task_result {
     int index = 0;
     std::string content;
 
-    bool truncated;
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
-    stop_type stop = STOP_TYPE_NONE;
-
     std::vector<completion_token_output> probs_output;
     result_timings timings;
 
```
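For reference, the final chunk that `to_json_oaicompat_chat_stream()` above builds serializes to roughly the following shape; the concrete values are made up for illustration:

```python
# Illustrative final stream chunk (values are placeholders, not real output).
final_chunk = {
    "choices": [{"finish_reason": "stop", "index": 0, "delta": {}}],
    "created": 1733779200,
    "id": "chatcmpl-123",          # oaicompat_cmpl_id
    "model": "llama",              # oaicompat_model
    "object": "chat.completion.chunk",
    "usage": {
        "completion_tokens": 12,
        "prompt_tokens": 8,
        "total_tokens": 20,
    },
    # "timings": {...}  # only appended when timings.prompt_n >= 0
}
```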
```diff
@@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
     virtual bool is_stop() override {
-        return stop != STOP_TYPE_NONE;
+        return false; // in stream mode, partial responses are not considered stop
     }
 
     virtual json to_json() override {
-        if (oaicompat) {
-            return to_json_oaicompat();
-        }
-        bool is_stop = stop != STOP_TYPE_NONE;
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
         // non-OAI-compat JSON
         json res = json {
             {"index", index},
             {"content", content},
-            {"stop_type", stop_type_to_str(stop)},
-            {"stop", is_stop},
+            {"stop", false},
             {"id_slot", id_slot},
             {"tokens_predicted", n_decoded},
             {"tokens_evaluated", n_prompt_tokens},
@@ -598,72 +631,54 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (!probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
-        if (is_stop) {
-            res.push_back({"truncated", truncated});
-        }
         return res;
     }
 
     json to_json_oaicompat() {
         bool first = n_decoded == 0;
-
-        std::string finish_reason;
-        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            finish_reason = "stop";
-        } else if (stop == STOP_TYPE_LIMIT) {
-            finish_reason = "length";
-        }
-
         std::time_t t = std::time(0);
-
         json choices;
 
-        if (!finish_reason.empty()) {
-            choices = json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}});
-        } else {
-            if (first) {
-                if (content.empty()) {
-                    choices = json::array({json{{"finish_reason", nullptr},
-                                                {"index", 0},
-                                                {"delta", json{{"role", "assistant"}}}}});
-                } else {
-                    // We have to send this as two updates to conform to openai behavior
-                    json initial_ret = json{{"choices", json::array({json{
-                                            {"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{
-                                                {"role", "assistant"}
-                                            }}}})},
-                                {"created", t},
-                                {"id", oaicompat_cmpl_id},
-                                {"model", oaicompat_model},
-                                {"object", "chat.completion.chunk"}};
-
-                    json second_ret = json{
-                                {"choices", json::array({json{{"finish_reason", nullptr},
-                                                              {"index", 0},
-                                                              {"delta", json{
-                                                              {"content", content}}}
-                                                              }})},
-                                {"created", t},
-                                {"id", oaicompat_cmpl_id},
-                                {"model", oaicompat_model},
-                                {"object", "chat.completion.chunk"}};
-
-                    return std::vector<json>({initial_ret, second_ret});
-                }
-            } else {
-                choices = json::array({json{
-                    {"finish_reason", nullptr},
-                    {"index", 0},
-                    {"delta",
-                    json{
-                        {"content", content},
-                    }},
-                }});
-            }
+        if (first) {
+            if (content.empty()) {
+                choices = json::array({json{{"finish_reason", nullptr},
+                                            {"index", 0},
+                                            {"delta", json{{"role", "assistant"}}}}});
+            } else {
+                // We have to send this as two updates to conform to openai behavior
+                json initial_ret = json{{"choices", json::array({json{
+                                        {"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{
+                                            {"role", "assistant"}
+                                        }}}})},
+                            {"created", t},
+                            {"id", oaicompat_cmpl_id},
+                            {"model", oaicompat_model},
+                            {"object", "chat.completion.chunk"}};
+
+                json second_ret = json{
+                            {"choices", json::array({json{{"finish_reason", nullptr},
+                                                          {"index", 0},
+                                                          {"delta", json{
+                                                          {"content", content}}}
+                                                          }})},
+                            {"created", t},
+                            {"id", oaicompat_cmpl_id},
+                            {"model", oaicompat_model},
+                            {"object", "chat.completion.chunk"}};
+
+                return std::vector<json>({initial_ret, second_ret});
+            }
+        } else {
+            choices = json::array({json{
+                {"finish_reason", nullptr},
+                {"index", 0},
+                {"delta",
+                json{
+                    {"content", content},
+                }},
+            }});
         }
 
         json ret = json {
@@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result {
             ret.push_back({"timings", timings.to_json()});
         }
 
-        if (!finish_reason.empty()) {
-            ret.push_back({"usage", json {
-                {"completion_tokens", n_decoded},
-                {"prompt_tokens", n_prompt_tokens},
-                {"total_tokens", n_decoded + n_prompt_tokens},
-            }});
-        }
-
         return std::vector<json>({ret});
     }
 };
```
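Since partial chunks now never carry `finish_reason` or `usage`, a stream consumer can distinguish the two chunk kinds by those fields alone. A small sketch; the helper name is hypothetical:

```python
def is_final_chunk(chunk: dict) -> bool:
    """Hypothetical helper: after this change, only the final OAI-compat
    stream chunk has a non-null finish_reason (and carries usage)."""
    choices = chunk.get("choices", [])
    return bool(choices) and choices[0].get("finish_reason") is not None
```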
```diff
@@ -1888,12 +1895,9 @@ struct server_context {
         res->index = slot.index;
         res->content = tkn.text_to_send;
 
-        res->truncated = slot.truncated;
         res->n_decoded = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
 
-        res->stop = slot.stop;
-
         res->verbose = slot.params.verbose;
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_chat = slot.params.oaicompat_chat;
@@ -1924,12 +1928,6 @@ struct server_context {
     }
 
     void send_final_response(server_slot & slot) {
-        if (slot.params.stream) {
-            // if in stream mode, send the last partial response
-            send_partial_response(slot, {0, "", {}});
-            return;
-        }
-
         auto res = std::make_unique<server_task_result_cmpl_final>();
         res->id = slot.id_task;
         res->id_slot = slot.id;
@@ -1948,6 +1946,7 @@ struct server_context {
         res->stop = slot.stop;
 
         res->verbose = slot.params.verbose;
+        res->stream = slot.params.stream;
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_chat = slot.params.oaicompat_chat;
         res->oaicompat_model = slot.params.oaicompat_model;
@@ -2100,7 +2099,10 @@ struct server_context {
             return;
         }
 
-        GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
+        GGML_ASSERT(
+            dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+            || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+        );
         if (!result_handler(result)) {
             cancel_tasks(id_tasks);
             break;
```

examples/server/tests/unit/test_completion.py (+6)

```diff
@@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
     })
     content = ""
     for data in res:
+        assert "stop" in data and type(data["stop"]) == bool
        if data["stop"]:
             assert data["timings"]["prompt_n"] == n_prompt
             assert data["timings"]["predicted_n"] == n_predicted
             assert data["truncated"] == truncated
+            assert data["stop_type"] == "limit"
+            assert "generation_settings" in data
+            assert server.n_predict is not None
+            assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
+            assert data["generation_settings"]["seed"] == server.seed
             assert match_regex(re_content, content)
         else:
             content += data["content"]
```
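
The updated test pins down the non-OAI `/completion` stream contract: every chunk carries a boolean `stop`, and `truncated`, `stop_type`, `timings`, and `generation_settings` appear on the final chunk. A minimal consumer along the same lines (server address and request fields are assumptions):

```python
import json
import requests

# Stream from the server's /completion endpoint (address is an assumption).
resp = requests.post(
    "http://localhost:8080/completion",
    json={"prompt": "Hello", "n_predict": 16, "stream": True},
    stream=True,
)

content = ""
for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue
    data = json.loads(line[len(b"data: "):])
    if data["stop"]:
        # final chunk: content was already delivered in the partial chunks
        print(content)
        print("timings:", data["timings"])
    else:
        content += data["content"]
```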
