From eb63ea1e185846b4d02333b61f73a02fe60a242e Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sat, 22 Mar 2025 11:56:17 -0400 Subject: [PATCH] [V1] Add `disable-any-whitespace` option support for xgrammar (#15316) Signed-off-by: Russell Bryant --- .../llm/test_struct_output_generate.py | 45 ++++++++++++++++++- vllm/engine/arg_utils.py | 4 +- vllm/v1/engine/processor.py | 2 +- vllm/v1/structured_output/backend_xgrammar.py | 7 ++- 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index b4eb475c23..d99ae59ddd 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -57,6 +57,50 @@ def test_guided_json_completion( jsonschema.validate(instance=output_json, schema=sample_json_schema) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +@pytest.mark.parametrize("model_name", MODELS_TO_TEST) +def test_guided_json_completion_disable_any_whitespace( + monkeypatch: pytest.MonkeyPatch, + sample_json_schema: dict[str, Any], + guided_decoding_backend: str, + model_name: str, +): + if guided_decoding_backend != "xgrammar": + pytest.skip("disable-any-whitespace is only supported for xgrammar.") + guided_decoding_backend = 'xgrammar:disable-any-whitespace' + + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=model_name, + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend) + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + assert "\n" not in generated_text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS_V1) @@ -301,7 +345,6 @@ def test_guided_choice_completion( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, use_tqdm=True) - assert outputs is not None for output in outputs: assert output is not None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e396e68f82..35c60a6026 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1486,7 +1486,9 @@ class EngineArgs: return False # Only support Xgrammar for guided decoding so far. - SUPPORTED_GUIDED_DECODING = ["xgrammar", "xgrammar:nofallback"] + SUPPORTED_GUIDED_DECODING = [ + "xgrammar", "xgrammar:disable-any-whitespace" + ] if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING: _raise_or_fallback(feature_name="--guided-decoding-backend", recommend_to_remove=False) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 55e0fdcd65..8ba06336be 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -120,7 +120,7 @@ class Processor: if not params.guided_decoding or not self.decoding_config: return - supported_backends = ["xgrammar"] + supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"] engine_level_backend = self.decoding_config.guided_decoding_backend if engine_level_backend not in supported_backends: raise ValueError(f"Only {supported_backends} structured output is " diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index ce93ca5c75..9bfb644c58 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend): def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config + self.disable_any_whitespace = ( + "disable-any-whitespace" + in vllm_config.decoding_config.guided_decoding_backend) tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, @@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend): def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: - ctx = self.compiler.compile_json_schema(grammar_spec, - any_whitespace=False) + ctx = self.compiler.compile_json_schema( + grammar_spec, any_whitespace=not self.disable_any_whitespace) elif request_type == StructuredOutputOptions.JSON_OBJECT: ctx = self.compiler.compile_builtin_json_grammar() elif request_type == StructuredOutputOptions.GRAMMAR: