@@ -13,28 +13,28 @@ def test_infill_without_input_extra():
13
13
global server
14
14
server .start ()
15
15
res = server .make_request ("POST" , "/infill" , data = {
16
- "prompt " : "Complete this " ,
17
- "input_prefix " : "#include <cstdio> \n #include \" llama.h \" \n \n int main() { \n int n_threads = llama_" ,
16
+ "input_prefix " : "#include <cstdio> \n #include \" llama.h \" \n \n int main() { \n " ,
17
+ "prompt " : " int n_threads = llama_" ,
18
18
"input_suffix" : "}\n " ,
19
19
})
20
20
assert res .status_code == 200
21
- assert match_regex ("(One|day|she|saw|big|scary|bird )+" , res .body ["content" ])
21
+ assert match_regex ("(Ann|small|shiny )+" , res .body ["content" ])
22
22
23
23
24
24
def test_infill_with_input_extra ():
25
25
global server
26
26
server .start ()
27
27
res = server .make_request ("POST" , "/infill" , data = {
28
- "prompt" : "Complete this" ,
29
28
"input_extra" : [{
30
29
"filename" : "llama.h" ,
31
30
"text" : "LLAMA_API int32_t llama_n_threads();\n "
32
31
}],
33
- "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n int n_threads = llama_" ,
32
+ "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n " ,
33
+ "prompt" : " int n_threads = llama_" ,
34
34
"input_suffix" : "}\n " ,
35
35
})
36
36
assert res .status_code == 200
37
- assert match_regex ("(cuts|Jimmy|mom|came|into|the|room )+" , res .body ["content" ])
37
+ assert match_regex ("(Dad|excited|park )+" , res .body ["content" ])
38
38
39
39
40
40
@pytest .mark .parametrize ("input_extra" , [
@@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
48
48
global server
49
49
server .start ()
50
50
res = server .make_request ("POST" , "/infill" , data = {
51
- "prompt" : "Complete this" ,
52
51
"input_extra" : [input_extra ],
53
- "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n int n_threads = llama_" ,
52
+ "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n " ,
53
+ "prompt" : " int n_threads = llama_" ,
54
54
"input_suffix" : "}\n " ,
55
55
})
56
56
assert res .status_code == 400
57
57
assert "error" in res .body
58
+
59
+
60
+ @pytest .mark .skipif (not is_slow_test_allowed (), reason = "skipping slow test" )
61
+ def test_with_qwen_model ():
62
+ global server
63
+ server .model_file = None
64
+ server .model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
65
+ server .model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
66
+ server .start (timeout_seconds = 600 )
67
+ res = server .make_request ("POST" , "/infill" , data = {
68
+ "input_extra" : [{
69
+ "filename" : "llama.h" ,
70
+ "text" : "LLAMA_API int32_t llama_n_threads();\n "
71
+ }],
72
+ "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n " ,
73
+ "prompt" : " int n_threads = llama_" ,
74
+ "input_suffix" : "}\n " ,
75
+ })
76
+ assert res .status_code == 200
77
+ assert res .body ["content" ] == "n_threads();\n printf(\" Number of threads: %d\\ n\" , n_threads);\n return 0;\n "
0 commit comments