[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
afeldman-nm
2025-02-24 11:29:41 -05:00
committed by GitHub
parent 444b0f0f62
commit befc402d34
5 changed files with 640 additions and 8 deletions

View File

@ -1,21 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Dict, List, Optional, Tuple
import pytest
from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams
MODEL = "facebook/opt-125m"
DTYPE = "half"
def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
"""Test passes if LLMEngine raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""
def _vllm_model(apc: bool, vllm_runner, monkeypatch):
"""Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1")
# TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
return vllm_runner(
MODEL,
dtype=DTYPE,
max_model_len=128,
enforce_eager=True,
enable_prefix_caching=apc,
gpu_memory_utilization=0.5,
)
@pytest.fixture(
# Function scope decouples tests & allows
# env var adjustment via monkeypatch
scope="function",
# Prefix caching
params=[False, True])
def vllm_model(vllm_runner, request, monkeypatch):
"""VllmRunner test fixture parameterized by APC True/False."""
with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model:
yield vllm_model
@pytest.fixture(scope="function")
def vllm_model_apc(vllm_runner, monkeypatch):
"""VllmRunner test fixture with APC."""
with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model:
yield vllm_model
def _get_test_sampling_params(
prompt_list: List[str],
seed: Optional[int] = 42,
) -> Tuple[List[SamplingParams], List[int]]:
"""Generate random sampling params for a batch."""
def get_mostly_n_gt1() -> int:
"""Mostly n \in [2,20], ~1/3 n=1"""
x = random.randint(0, 28)
if x < 10:
return 1
else:
return x - 8
n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
# High temperature to maximize the chance of unique completions
return [
SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
for n in n_list
], n_list
def test_parallel_sampling(vllm_model, example_prompts) -> None:
"""Test passes if parallel sampling `n>1` yields `n` unique completions.
Args:
vllm_model: VllmRunner instance under test.
example_prompt: test fixture providing prompts for testing.
"""
sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
model: LLM = vllm_model.model
outputs = model.generate(example_prompts, sampling_params_list)
# Validate each request response
for out, n in zip(outputs, n_list):
completion_counts: Dict[str, int] = {}
# Assert correct number of completions
assert len(out.outputs) == n, (
f"{len(out.outputs)} completions; {n} expected.")
for idx in range(n):
comp = out.outputs[idx]
# Assert correct completion indices
assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
text = comp.text
completion_counts[text] = completion_counts.get(text, 0) + 1
# Assert unique completions
if len(completion_counts) != n:
repeats = {
txt: num
for (txt, num) in completion_counts.items() if num > 1
}
raise AssertionError(
f"{len(completion_counts)} unique completions; expected"
f" {n}. Repeats: {repeats}")
def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
"""Test passes if LLMEngine raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""
model: LLM = vllm_model_apc.model
with pytest.raises(ValueError) as excinfo:
LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
model.generate(
"Hello, my name is",
SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

View File

@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
model_name: str):
"""Parallel sampling without streaming.
A single request output contains a list of completions.
"""
prompt = "What is an LLM?"
n = 3
max_tokens = 5
# High temperature to maximize chance of unique completions.
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
stream=False,
seed=42)
# Assert `n` completions
num_completions = len(completion.choices)
assert num_completions == n, (
f"Num completions {num_completions} but expected {n}.")
completion_repeats: Dict[str, int] = {}
for idx, choice in enumerate(completion.choices):
# Assert correct completion index & some finish reason.
assert choice.index == idx, (
f"Index {choice.index} but expected {idx}.")
assert choice.finish_reason is not None, (
"None finish_reason is invalid.")
text = choice.text
completion_repeats[text] = completion_repeats.get(text, 0) + 1
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n:
repeats = {
txt: num
for (txt, num) in completion_repeats.items() if num > 1
}
raise AssertionError(
f"Expected {n} unique completions, got {num_unique};"
f" repeats: {repeats}.")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""
prompt = "What is an LLM?"
n = 3
max_tokens = 5
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
stream=True,
seed=42)
chunks: List[List[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
text = chunk.choices[0].text
chunks[index].append(text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# Assert `n` completions with correct finish reasons
assert finish_reason_count == n, (
f"Expected {n} completions with valid indices and finish_reason.")
completion_repeats: Dict[str, int] = {}
for chunk in chunks:
chunk_len = len(chunk)
# Assert correct number of completion tokens
assert chunk_len == max_tokens, (
f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
text = "".join(chunk)
completion_repeats[text] = completion_repeats.get(text, 0) + 1
print(text)
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n:
repeats = {
txt: num
for (txt, num) in completion_repeats.items() if num > 1
}
raise AssertionError(f"{num_unique} unique completions, expected {n};"
f" repeats: {repeats}")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",