Commit ce8784b

server : fix format_infill (#10724)
* server : fix format_infill
* fix
* rename
* update test
* use another model
* update test
* update test
* test_invalid_input_extra_req
1 parent e52522b commit ce8784b

File tree

examples/server/server.cpp
examples/server/tests/unit/test_infill.py
examples/server/tests/utils.py

3 files changed: +53 -8 lines changed

examples/server/server.cpp

+22
@@ -3484,6 +3484,11 @@ int main(int argc, char ** argv) {
         json data = json::parse(req.body);
 
         // validate input
+        if (data.contains("prompt") && !data.at("prompt").is_string()) {
+            // prompt is optional
+            res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+        }
+
         if (!data.contains("input_prefix")) {
             res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
         }
@@ -3493,9 +3498,11 @@ int main(int argc, char ** argv) {
         }
 
         if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+            // input_extra is optional
             res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
+
         json input_extra = json_value(data, "input_extra", json::array());
         for (const auto & chunk : input_extra) {
             // { "text": string, "filename": string }
@@ -3511,6 +3518,21 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
+        std::string prompt = json_value(data, "prompt", std::string());
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+        data["prompt"] = format_infill(
+            ctx_server.ctx,
+            data.at("input_prefix"),
+            data.at("input_suffix"),
+            data.at("input_extra"),
+            ctx_server.params_base.n_batch,
+            ctx_server.params_base.n_predict,
+            ctx_server.slots[0].n_ctx, // TODO: there should be a better way
+            ctx_server.params_base.spm_infill,
+            tokenized_prompts[0]
+        );
+
         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
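For orientation, format_infill combines these fields into a single fill-in-the-middle prompt, with the input_extra chunks serving as repo-level context and the tokenized "prompt" appended as the already-typed middle text. A rough string-level sketch in Python (the bracketed token names are placeholders for the model's FIM special tokens; the real implementation works on token IDs, trims everything to the n_batch/n_ctx budget, and can emit suffix-first ordering when spm_infill is set — treat the details here as assumptions):

def format_infill_sketch(input_prefix: str, input_suffix: str,
                         input_extra: list, prompt: str) -> str:
    # Repo-level context: each extra chunk is introduced by a separator
    # token plus its filename (placeholder names, not the real token IDs).
    extra = "".join(
        f"[FIM_SEP]{chunk.get('filename', '')}\n{chunk.get('text', '')}"
        for chunk in input_extra
    )
    # PSM (prefix-suffix-middle) ordering: the model is asked to continue
    # the middle, starting from the user's in-progress "prompt" text.
    return (extra
            + "[FIM_PRE]" + input_prefix
            + "[FIM_SUF]" + input_suffix
            + "[FIM_MID]" + prompt)

print(format_infill_sketch(
    "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
    "}\n",
    [{"filename": "llama.h", "text": "LLAMA_API int32_t llama_n_threads();\n"}],
    "    int n_threads = llama_",
))

This is why the tests below split the example at the cursor: the prefix ends at the opening brace, and the half-typed line travels in "prompt".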
examples/server/tests/unit/test_infill.py

+28 -8
@@ -13,28 +13,28 @@ def test_infill_without_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [{
             "filename": "llama.h",
             "text": "LLAMA_API int32_t llama_n_threads();\n"
         }],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
+    assert match_regex("(Dad|excited|park)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
@@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [input_extra],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_qwen_model():
+    global server
+    server.model_file = None
+    server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
+    server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
+    server.start(timeout_seconds=600)
+    res = server.make_request("POST", "/infill", data={
+        "input_extra": [{
+            "filename": "llama.h",
+            "text": "LLAMA_API int32_t llama_n_threads();\n"
+        }],
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
+        "input_suffix": "}\n",
+    })
+    assert res.status_code == 200
+    assert res.body["content"] == "n_threads();\n    printf(\"Number of threads: %d\\n\", n_threads);\n    return 0;\n"

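The request shape exercised by these tests also works against a running llama-server instance. A minimal sketch using only the Python standard library (the 127.0.0.1:8080 address assumes a locally started server with default settings):

import json
import urllib.request

# "prompt" carries the in-progress line; "input_prefix"/"input_suffix"
# carry the code around the cursor, mirroring the tests above.
body = {
    "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
    "prompt": "    int n_threads = llama_",
    "input_suffix": "}\n",
}
req = urllib.request.Request(
    "http://127.0.0.1:8080/infill",  # assumed default llama-server address
    data=json.dumps(body).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as res:
    print(json.loads(res.read())["content"])

Per the validation added in server.cpp, a non-string "prompt" is rejected with an invalid-request error, and a malformed "input_extra" entry returns status 400, which is what test_invalid_input_extra_req asserts.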
examples/server/tests/utils.py

+3
@@ -371,3 +371,6 @@ def match_regex(regex: str, text: str) -> bool:
         ).search(text)
         is not None
     )
+
+def is_slow_test_allowed():
+    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
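With this helper in place, the new test_with_qwen_model case runs only when the SLOW_TESTS environment variable is set to "1" or "ON"; otherwise its skipif marker excludes it from the default test run.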
