Commit a09ad90

russellb, JC1DA, and mmoskal authored
[V1] guidance backend for structured output + auto fallback mode (#14779)
Signed-off-by: Russell Bryant <[email protected]> Co-authored-by: Loc Huynh <[email protected]> Co-authored-by: Michal Moskal <[email protected]>
1 parent 10b34e3 commit a09ad90

9 files changed, +345 −111 lines changed

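For reference, the new backend is selected at the engine level rather than per request. A minimal usage sketch based on the tests added in this commit (the model name and schema below are illustrative, not part of the change):

    from vllm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    # "guidance" selects the llguidance backend added here; "auto" lets
    # vLLM pick between xgrammar and guidance based on request contents.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="guidance")

    # Illustrative schema; any JSON schema the backend supports works.
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    sampling_params = SamplingParams(
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=schema))

    outputs = llm.generate(
        prompts=[f"Give an example JSON that fits this schema: {schema}"],
        sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)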

requirements/common.txt

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
-llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
+llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 102 additions & 67 deletions
@@ -13,7 +13,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
+GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
@@ -111,13 +112,14 @@ def test_guided_json_object(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(json_object=True))
 
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -137,12 +139,20 @@ def test_guided_json_object(
 
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
+        allowed_types: tuple[type, ...] = (dict, )
+        if guided_decoding_backend == "xgrammar":
+            # TODO - we are currently too permissive with xgrammar and
+            # allow any valid json (typically comes back as a list or
+            # object). We can fix this by specifying a jsonschema of
+            # {"type": "object"}, but we need this fix in a release
+            # first: https://github.com/mlc-ai/xgrammar/pull/264
+            allowed_types = (dict, list)
+        assert isinstance(parsed_json, allowed_types)
 
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
+                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
 def test_guided_json_unsupported_schema(
     monkeypatch: pytest.MonkeyPatch,
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json_schema,
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="The provided JSON schema contains features "
-                       "not supported by xgrammar."):
-        llm.generate(prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {unsupported_json_schema}"
-        ] * 2,
-                     sampling_params=sampling_params,
-                     use_tqdm=True)
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+    if guided_decoding_backend == "xgrammar":
+        with pytest.raises(ValueError,
+                           match="The provided JSON schema contains features "
+                           "not supported by xgrammar."):
+            llm.generate(prompts=[
+                f"Give an example JSON for an employee profile "
+                f"that fits this schema: {unsupported_json_schema}"
+            ] * 2,
+                         sampling_params=sampling_params,
+                         use_tqdm=True)
+    else:
+        # This should work for both "guidance" and "auto".
+
+        outputs = llm.generate(
+            prompts=("Give an example JSON object for a grade "
+                     "that fits this schema: "
+                     f"{unsupported_json_schema}"),
+            sampling_params=sampling_params,
+            use_tqdm=True)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(generated_text)
+
+            # Parse to verify it is valid JSON
+            parsed_json = json.loads(generated_text)
+            assert isinstance(parsed_json, dict)
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_ebnf,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_lark,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar="not a grammar",
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="Failed to convert the grammar "
-                       "from Lark to EBNF."):
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
+    with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
             prompts=("Generate a sql statement that selects col_1 from "
                      "table_1 where it is equal to 1"),
@@ -298,12 +331,13 @@ def test_guided_regex(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
             f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,

vllm/config.py

Lines changed: 7 additions & 2 deletions
@@ -2800,12 +2800,17 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        v0_valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar'
         ]
+        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
+        if envs.VLLM_USE_V1:
+            valid_guided_backends = v1_valid_guided_backends
+        else:
+            valid_guided_backends = v0_valid_guided_backends
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")

vllm/engine/arg_utils.py

Lines changed: 9 additions & 12 deletions
@@ -391,16 +391,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default='xgrammar',
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/outlines-dev/outlines, '
-            'https://github.com/mlc-ai/xgrammar, and '
-            'https://github.com/noamgat/lm-format-enforcer.'
-            ' Can be overridden per request via guided_decoding_backend'
-            ' parameter.\n'
-            'Backend-specific options can be supplied in a comma-separated '
-            'list following a colon after the backend name. Valid backends and '
-            'all available options are: [xgrammar:no-fallback, '
-            'xgrammar:disable-any-whitespace, '
-            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
+            'https://github.com/mlc-ai/xgrammar and '
+            'https://github.com/guidance-ai/llguidance.'
+            'Valid backend values are "xgrammar", "guidance", and "auto". '
+            'With "auto", we will make opinionated choices based on request'
+            'contents and what the backend libraries currently support, so '
+            'the behavior is subject to change in each release. '
+            'The default is xgrammar.')
         parser.add_argument(
             '--logits-processor-pattern',
             type=nullable_str,
@@ -1539,9 +1536,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
-        # Only support Xgrammar for guided decoding so far.
+        # Xgrammar and Guidance are supported.
         SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
        ]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",

vllm/v1/engine/processor.py

Lines changed: 32 additions & 7 deletions
@@ -4,7 +4,6 @@
 from collections.abc import Mapping
 from typing import Optional, Union
 
-import vllm.platforms
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
@@ -20,7 +19,10 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.structured_output.utils import validate_structured_output_request
+from vllm.v1.structured_output.backend_guidance import (
+    validate_guidance_grammar)
+from vllm.v1.structured_output.utils import (
+    validate_structured_output_request_xgrammar)
 
 
 class Processor:
@@ -120,7 +122,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
+        supported_backends = [
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+        ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
             raise ValueError(f"Only {supported_backends} structured output is "
@@ -134,10 +138,31 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         else:
             params.guided_decoding.backend = engine_level_backend
 
-        if vllm.platforms.current_platform.is_tpu():
-            raise ValueError("Structured output is not supported on TPU.")
-
-        validate_structured_output_request(params)
+        # Request content validation
+
+        if engine_level_backend == "xgrammar":
+            # xgrammar with no fallback
+            validate_structured_output_request_xgrammar(params)
+            params.guided_decoding.backend = "xgrammar"
+        elif engine_level_backend == "auto":
+            # "auto" is an opt-in to opinionated behavior where we try to
+            # choose a backend based on request contents. This is not the
+            # default as it is less predictable and subject to change
+            # between releases as feature support changes.
+            try:
+                validate_structured_output_request_xgrammar(params)
+                params.guided_decoding.backend = "xgrammar"
+            except ValueError:
+                # The request includes some jsonschema feature(s) that
+                # are not supported in xgrammar. Fall back to guidance.
+                params.guided_decoding.backend = "guidance"
+
+        if params.guided_decoding.backend == "guidance":
+            # TODO ideally we would have the LLTokenizer here as Lark syntax
+            # allows <|special_token|> and similar, see
+            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
+            # Without tokenizer these are disallowed in grammars.
+            validate_guidance_grammar(params, tokenizer=None)
 
     def process_inputs(
         self,
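The heart of "auto" mode is a validate-then-fall-back dispatch. A condensed, self-contained sketch of that pattern (validate_xgrammar is a stand-in for validate_structured_output_request_xgrammar, not the vLLM API):

    # Prefer xgrammar; on "auto", fall back to guidance when the
    # request uses features xgrammar's validator rejects.
    def resolve_backend(params, engine_level_backend, validate_xgrammar):
        if engine_level_backend == "xgrammar":
            validate_xgrammar(params)  # raises ValueError, no fallback
            return "xgrammar"
        if engine_level_backend == "auto":
            try:
                validate_xgrammar(params)
                return "xgrammar"
            except ValueError:
                # Schema uses features xgrammar rejects: use guidance.
                return "guidance"
        return engine_level_backend  # e.g. "guidance" chosen explicitly

Note that after this step the per-request params.guided_decoding.backend always names a concrete backend ("xgrammar" or "guidance"), so the rest of the engine never sees "auto".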

vllm/v1/structured_output/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.structured_output.backend_guidance import GuidanceBackend
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
 
@@ -50,6 +51,8 @@ def grammar_init(self, request: Request) -> None:
                 XgrammarBackend)
 
             self.backend = XgrammarBackend(self.vllm_config)
+        elif backend_name == "guidance":
+            self.backend = GuidanceBackend(self.vllm_config)
         else:
             raise ValueError(
                 f"Unsupported structured output backend: {backend_name}")
