[V1] Add disable-any-whitespace option support for xgrammar (#15316)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant
2025-03-22 11:56:17 -04:00
committed by GitHub
parent 2f4bd358f1
commit eb63ea1e18
4 changed files with 53 additions and 5 deletions

View File

@ -57,6 +57,50 @@ def test_guided_json_completion(
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_disable_any_whitespace(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
guided_decoding_backend: str,
model_name: str,
):
if guided_decoding_backend != "xgrammar":
pytest.skip("disable-any-whitespace is only supported for xgrammar.")
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@ -301,7 +345,6 @@ def test_guided_choice_completion(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None

View File

@ -1486,7 +1486,9 @@ class EngineArgs:
return False
# Only support Xgrammar for guided decoding so far.
SUPPORTED_GUIDED_DECODING = ["xgrammar", "xgrammar:nofallback"]
SUPPORTED_GUIDED_DECODING = [
"xgrammar", "xgrammar:disable-any-whitespace"
]
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
_raise_or_fallback(feature_name="--guided-decoding-backend",
recommend_to_remove=False)

View File

@ -120,7 +120,7 @@ class Processor:
if not params.guided_decoding or not self.decoding_config:
return
supported_backends = ["xgrammar"]
supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is "

View File

@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.disable_any_whitespace = (
"disable-any-whitespace"
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend):
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
if request_type == StructuredOutputOptions.JSON:
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
ctx = self.compiler.compile_json_schema(
grammar_spec, any_whitespace=not self.disable_any_whitespace)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR: