[Speculators] Move tests + fix integration (#27308)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Signed-off-by: rahul-tuli <rtuli@redhat.com>
Co-authored-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
@@ -121,6 +121,86 @@ def test_ngram_correctness(
    cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
    "model_path",
    [
        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
        "RedHatAI/Qwen3-8B-speculator.eagle3",
    ],
    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
def test_speculators_model_integration(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_path: str,
):
    """
    Test that speculators models work with the simplified integration.

    This verifies the `vllm serve <speculator-model>` use case where the
    speculative config is automatically detected from the model config
    without requiring an explicit --speculative-config argument.

    Tests:
    1. Speculator model is correctly detected
    2. Verifier model is extracted from the speculator config
    3. Speculative decoding is automatically enabled
    4. Text generation works correctly
    5. Output matches reference (non-speculative) generation
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # Generate test prompts
    test_prompts = get_test_prompts(mm_enabled=False)

    # First run: direct speculator model (simplified integration)
    spec_llm = LLM(model=model_path, max_model_len=1024)
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)

    # Verify speculative config was auto-detected
    assert spec_llm.llm_engine.vllm_config.speculative_config is not None, (
        f"Speculative config should be auto-detected for {model_path}"
    )

    spec_config = spec_llm.llm_engine.vllm_config.speculative_config
    assert spec_config.num_speculative_tokens > 0, (
        f"Expected positive speculative tokens, "
        f"got {spec_config.num_speculative_tokens}"
    )

    # Verify draft model is set to the speculator model
    assert spec_config.model == model_path, (
        f"Draft model should be {model_path}, got {spec_config.model}"
    )

    # Extract verifier model for the reference run
    verifier_model = spec_llm.llm_engine.vllm_config.model_config.model

    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Second run: reference without speculative decoding
    ref_llm = LLM(model=verifier_model, max_model_len=1024)
    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Compare outputs
    matches = sum(
        1
        for ref, spec in zip(ref_outputs, spec_outputs)
        if ref.outputs[0].text == spec.outputs[0].text
    )

    # Heuristic: expect at least 66% of prompts to match exactly
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"Only {matches}/{len(ref_outputs)} outputs matched. "
        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
    )


@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
@@ -22,10 +22,6 @@ from vllm.model_executor.models.interfaces import supports_eagle3
            "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
            id="qwen3-eagle3-speculator-w4a16-verifier",
        ),
        pytest.param(
            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
            id="llama3-eagl3-multiple-layers",
        ),
    ],
)
def test_eagle3_speculators_model(
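For context, the new test above covers the use case its docstring describes: loading a speculator checkpoint directly and letting vLLM pick up the speculative config from the model itself. Below is a minimal standalone sketch of that flow, assuming the RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3 checkpoint from the test parametrization; it is illustrative only and not part of this diff.

from vllm import LLM

# Load the speculator checkpoint directly; no speculative_config is passed.
llm = LLM(
    model="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
    max_model_len=1024,
)

# The speculative config should have been auto-detected from the checkpoint,
# mirroring the assertion in test_speculators_model_integration above.
assert llm.llm_engine.vllm_config.speculative_config is not None

# Generation then runs with speculative decoding enabled automatically.
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)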