13
13
from vllm .outputs import RequestOutput
14
14
from vllm .sampling_params import GuidedDecodingParams , SamplingParams
15
15
16
- GUIDED_DECODING_BACKENDS_V1 = ["xgrammar" ]
16
+ GUIDED_DECODING_BACKENDS_V1 = ["xgrammar" , "guidance" ]
17
17
MODELS_TO_TEST = [
18
18
"Qwen/Qwen2.5-1.5B-Instruct" , "mistralai/Ministral-8B-Instruct-2410"
19
19
]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
30
30
model_name : str ,
31
31
):
32
32
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
33
- llm = LLM (model = model_name , max_model_len = 1024 )
34
- sampling_params = SamplingParams (temperature = 1.0 ,
35
- max_tokens = 1000 ,
36
- guided_decoding = GuidedDecodingParams (
37
- json = sample_json_schema ,
38
- backend = guided_decoding_backend ))
33
+ llm = LLM (model = model_name ,
34
+ max_model_len = 1024 ,
35
+ guided_decoding_backend = guided_decoding_backend )
36
+ sampling_params = SamplingParams (
37
+ temperature = 1.0 ,
38
+ max_tokens = 1000 ,
39
+ guided_decoding = GuidedDecodingParams (json = sample_json_schema ))
39
40
outputs = llm .generate (prompts = [
40
41
f"Give an example JSON for an employee profile "
41
42
f"that fits this schema: { sample_json_schema } "
@@ -111,13 +112,14 @@ def test_guided_json_object(
111
112
model_name : str ,
112
113
):
113
114
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
114
- llm = LLM (model = model_name , max_model_len = 1024 )
115
- sampling_params = SamplingParams (temperature = 1.0 ,
116
- max_tokens = 100 ,
117
- n = 2 ,
118
- guided_decoding = GuidedDecodingParams (
119
- json_object = True ,
120
- backend = guided_decoding_backend ))
115
+ llm = LLM (model = model_name ,
116
+ max_model_len = 1024 ,
117
+ guided_decoding_backend = guided_decoding_backend )
118
+ sampling_params = SamplingParams (
119
+ temperature = 1.0 ,
120
+ max_tokens = 100 ,
121
+ n = 2 ,
122
+ guided_decoding = GuidedDecodingParams (json_object = True ))
121
123
122
124
outputs = llm .generate (
123
125
prompts = ("Generate a JSON object with curly braces for a person with "
@@ -137,12 +139,20 @@ def test_guided_json_object(
137
139
138
140
# Parse to verify it is valid JSON
139
141
parsed_json = json .loads (generated_text )
140
- assert isinstance (parsed_json , dict )
142
+ allowed_types : tuple [type , ...] = (dict , )
143
+ if guided_decoding_backend == "xgrammar" :
144
+ # TODO - we are currently too permissive with xgrammar and
145
+ # allow # any valid json (typically comes back as a list or
146
+ # object). We can fix this by specifying a jsonschema of
147
+ # {"type": "object"}, # but we need this fix in a release
148
+ # first: https://github.com/mlc-ai/xgrammar/pull/264
149
+ allowed_types = (dict , list )
150
+ assert isinstance (parsed_json , allowed_types )
141
151
142
152
143
153
@pytest .mark .skip_global_cleanup
144
154
@pytest .mark .parametrize ("guided_decoding_backend" ,
145
- GUIDED_DECODING_BACKENDS_V1 )
155
+ GUIDED_DECODING_BACKENDS_V1 + [ "auto" ] )
146
156
@pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
147
157
def test_guided_json_unsupported_schema (
148
158
monkeypatch : pytest .MonkeyPatch ,
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
151
161
model_name : str ,
152
162
):
153
163
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
154
- llm = LLM (model = model_name , max_model_len = 1024 )
155
- sampling_params = SamplingParams (temperature = 1.0 ,
156
- max_tokens = 1000 ,
157
- guided_decoding = GuidedDecodingParams (
158
- json = unsupported_json_schema ,
159
- backend = guided_decoding_backend ))
160
- with pytest .raises (ValueError ,
161
- match = "The provided JSON schema contains features "
162
- "not supported by xgrammar." ):
163
- llm .generate (prompts = [
164
- f"Give an example JSON for an employee profile "
165
- f"that fits this schema: { unsupported_json_schema } "
166
- ] * 2 ,
167
- sampling_params = sampling_params ,
168
- use_tqdm = True )
164
+ llm = LLM (model = model_name ,
165
+ max_model_len = 1024 ,
166
+ guided_decoding_backend = guided_decoding_backend )
167
+ sampling_params = SamplingParams (
168
+ temperature = 1.0 ,
169
+ max_tokens = 1000 ,
170
+ guided_decoding = GuidedDecodingParams (json = unsupported_json_schema ))
171
+ if guided_decoding_backend == "xgrammar" :
172
+ with pytest .raises (ValueError ,
173
+ match = "The provided JSON schema contains features "
174
+ "not supported by xgrammar." ):
175
+ llm .generate (prompts = [
176
+ f"Give an example JSON for an employee profile "
177
+ f"that fits this schema: { unsupported_json_schema } "
178
+ ] * 2 ,
179
+ sampling_params = sampling_params ,
180
+ use_tqdm = True )
181
+ else :
182
+ # This should work for both "guidance" and "auto".
183
+
184
+ outputs = llm .generate (
185
+ prompts = ("Give an example JSON object for a grade "
186
+ "that fits this schema: "
187
+ f"{ unsupported_json_schema } " ),
188
+ sampling_params = sampling_params ,
189
+ use_tqdm = True )
190
+ assert outputs is not None
191
+ for output in outputs :
192
+ assert output is not None
193
+ assert isinstance (output , RequestOutput )
194
+ generated_text = output .outputs [0 ].text
195
+ assert generated_text is not None
196
+ print (generated_text )
197
+
198
+ # Parse to verify it is valid JSON
199
+ parsed_json = json .loads (generated_text )
200
+ assert isinstance (parsed_json , dict )
169
201
170
202
171
203
@pytest .mark .skip_global_cleanup
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
179
211
model_name : str ,
180
212
):
181
213
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
182
- llm = LLM (model = model_name , max_model_len = 1024 )
183
- sampling_params = SamplingParams (temperature = 0.8 ,
184
- top_p = 0.95 ,
185
- max_tokens = 1000 ,
186
- guided_decoding = GuidedDecodingParams (
187
- grammar = sample_sql_ebnf ,
188
- backend = guided_decoding_backend ))
214
+ llm = LLM (model = model_name ,
215
+ max_model_len = 1024 ,
216
+ guided_decoding_backend = guided_decoding_backend )
217
+ sampling_params = SamplingParams (
218
+ temperature = 0.8 ,
219
+ top_p = 0.95 ,
220
+ max_tokens = 1000 ,
221
+ guided_decoding = GuidedDecodingParams (grammar = sample_sql_ebnf ))
189
222
outputs = llm .generate (
190
223
prompts = ("Generate a sql statement that selects col_1 from "
191
224
"table_1 where it is equal to 1" ),
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
222
255
model_name : str ,
223
256
):
224
257
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
225
- llm = LLM (model = model_name , max_model_len = 1024 )
226
- sampling_params = SamplingParams (temperature = 0.8 ,
227
- top_p = 0.95 ,
228
- max_tokens = 1000 ,
229
- guided_decoding = GuidedDecodingParams (
230
- grammar = sample_sql_lark ,
231
- backend = guided_decoding_backend ))
258
+ llm = LLM (model = model_name ,
259
+ max_model_len = 1024 ,
260
+ guided_decoding_backend = guided_decoding_backend )
261
+ sampling_params = SamplingParams (
262
+ temperature = 0.8 ,
263
+ top_p = 0.95 ,
264
+ max_tokens = 1000 ,
265
+ guided_decoding = GuidedDecodingParams (grammar = sample_sql_lark ))
232
266
outputs = llm .generate (
233
267
prompts = ("Generate a sql statement that selects col_1 from "
234
268
"table_1 where it is equal to 1" ),
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
269
303
model_name : str ,
270
304
):
271
305
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
272
- llm = LLM (model = model_name , max_model_len = 1024 )
273
- sampling_params = SamplingParams (temperature = 0.8 ,
274
- top_p = 0.95 ,
275
- max_tokens = 1000 ,
276
- guided_decoding = GuidedDecodingParams (
277
- grammar = "not a grammar" ,
278
- backend = guided_decoding_backend ))
279
- with pytest .raises (ValueError ,
280
- match = "Failed to convert the grammar "
281
- "from Lark to EBNF." ):
306
+ llm = LLM (model = model_name ,
307
+ max_model_len = 1024 ,
308
+ guided_decoding_backend = guided_decoding_backend )
309
+ sampling_params = SamplingParams (
310
+ temperature = 0.8 ,
311
+ top_p = 0.95 ,
312
+ max_tokens = 1000 ,
313
+ guided_decoding = GuidedDecodingParams (grammar = "not a grammar" ))
314
+ with pytest .raises (ValueError , match = "Failed to convert the grammar " ):
282
315
llm .generate (
283
316
prompts = ("Generate a sql statement that selects col_1 from "
284
317
"table_1 where it is equal to 1" ),
@@ -298,12 +331,13 @@ def test_guided_regex(
298
331
model_name : str ,
299
332
):
300
333
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
301
- llm = LLM (model = model_name , max_model_len = 1024 )
302
- sampling_params = SamplingParams (temperature = 0.8 ,
303
- top_p = 0.95 ,
304
- guided_decoding = GuidedDecodingParams (
305
- regex = sample_regex ,
306
- backend = guided_decoding_backend ))
334
+ llm = LLM (model = model_name ,
335
+ max_model_len = 1024 ,
336
+ guided_decoding_backend = guided_decoding_backend )
337
+ sampling_params = SamplingParams (
338
+ temperature = 0.8 ,
339
+ top_p = 0.95 ,
340
+ guided_decoding = GuidedDecodingParams (regex = sample_regex ))
307
341
outputs = llm .generate (
308
342
prompts = [
309
343
f"Give an example IPv4 address with this regex: { sample_regex } "
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
335
369
model_name : str ,
336
370
):
337
371
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
338
- llm = LLM (model = model_name , max_model_len = 1024 )
339
- sampling_params = SamplingParams (temperature = 0.8 ,
340
- top_p = 0.95 ,
341
- guided_decoding = GuidedDecodingParams (
342
- choice = sample_guided_choice ,
343
- backend = guided_decoding_backend ))
372
+ llm = LLM (model = model_name ,
373
+ max_model_len = 1024 ,
374
+ guided_decoding_backend = guided_decoding_backend )
375
+ sampling_params = SamplingParams (
376
+ temperature = 0.8 ,
377
+ top_p = 0.95 ,
378
+ guided_decoding = GuidedDecodingParams (choice = sample_guided_choice ))
344
379
outputs = llm .generate (
345
380
prompts = "The best language for type-safe systems programming is " ,
346
381
sampling_params = sampling_params ,
0 commit comments