[Tests] conftest: Extending VllmRunner and HfRunner to accept token_ids as input (#26295)

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yannick Schnider
2025-10-06 19:19:34 +02:00
committed by GitHub
parent 4727a8afa7
commit 6431be808f
2 changed files with 63 additions and 63 deletions

View File

@ -57,7 +57,7 @@ from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import set_default_torch_num_threads
from vllm.utils import is_list_of, set_default_torch_num_threads
logger = init_logger(__name__)
@ -406,11 +406,11 @@ class HfRunner:
def get_inputs(
self,
prompts: list[str],
prompts: Union[list[str], list[list[int]]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> list[Union[BatchFeature, BatchEncoding]]:
) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]:
if images is not None:
assert len(prompts) == len(images)
@ -420,31 +420,48 @@ class HfRunner:
if audios is not None:
assert len(prompts) == len(audios)
all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
all_inputs: list[
Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]
] = []
for i, prompt in enumerate(prompts):
processor_kwargs: dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and (image := images[i]) is not None:
processor_kwargs["images"] = image
if videos is not None and (video := videos[i]) is not None:
processor_kwargs["videos"] = video
if audios is not None and (audio_inputs := audios[i]) is not None:
# HACK - not all processors take sampling_rate; we should
# clean this up in the future.
if len(audio_inputs) == 2:
audio, sr = audio_inputs
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
else:
processor_kwargs["audio"] = audio_inputs
if isinstance(prompt, str):
processor_kwargs: dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and (image := images[i]) is not None:
processor_kwargs["images"] = image
if videos is not None and (video := videos[i]) is not None:
processor_kwargs["videos"] = video
if audios is not None and (audio_inputs := audios[i]) is not None:
# HACK - not all processors take sampling_rate; we should
# clean this up in the future.
if len(audio_inputs) == 2:
audio, sr = audio_inputs
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
else:
processor_kwargs["audio"] = audio_inputs
inputs = self.processor(**processor_kwargs)
if isinstance(inputs, BatchFeature):
inputs = inputs.to(dtype=self.dtype)
all_inputs.append(inputs)
inputs = self.processor(**processor_kwargs)
if isinstance(inputs, BatchFeature):
inputs = inputs.to(dtype=self.dtype)
all_inputs.append(inputs)
else:
# check that prompt is (batched) list of integers (token ids)
if not is_list_of(prompt, typ=int, check="all"):
raise ValueError(
"Prompt must be a list of ints corresponding to the prompt token ids."
)
# check that no multimodal input is provided
if images or videos or audios:
raise ValueError(
"When providing prompt token ids multimodal inputs are not supported."
)
input_dict = {
"input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0),
}
all_inputs.append(input_dict)
return all_inputs
@ -477,7 +494,7 @@ class HfRunner:
def generate(
self,
prompts: list[str],
prompts: Union[list[str], list[list[int]]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
@ -505,7 +522,7 @@ class HfRunner:
def generate_greedy(
self,
prompts: list[str],
prompts: Union[list[str], list[list[int]]],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
@ -807,7 +824,7 @@ class VllmRunner:
def generate(
self,
prompts: Union[list[str], list[torch.Tensor]],
prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
@ -877,7 +894,7 @@ class VllmRunner:
def generate_greedy(
self,
prompts: Union[list[str], list[torch.Tensor]],
prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,

View File

@ -23,13 +23,10 @@ the 1st will be sampled after the prefill and the 2nd after the first decode
"""
import pytest
import torch
from transformers import AutoModelForCausalLM
from tests.conftest import HfRunner, VllmRunner
from tests.models.utils import check_outputs_equal
from tests.utils import create_new_process_for_each_test
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt
@create_new_process_for_each_test()
@ -43,6 +40,8 @@ from vllm.inputs import TokensPrompt
)
def test_max_context_length(
model: str,
vllm_runner: type[VllmRunner],
hf_runner: type[HfRunner],
prompt_len: int,
max_tokens: int,
) -> None:
@ -57,42 +56,26 @@ def test_max_context_length(
# Construct a prompt of size prompt_len
prompt_ids = [[43] * prompt_len]
# Generate max_tokens new tokens deterministically.
sampling_params = [
SamplingParams(max_tokens=max_tokens, temperature=0.0, ignore_eos=True)
]
# --- vLLM generation ---
llm = LLM(
model=model,
tokenizer=model,
with vllm_runner(
model_name=model,
tokenizer_name=model,
max_model_len=2048,
max_num_seqs=1,
tensor_parallel_size=1,
)
vllm_token_prompts = [TokensPrompt(prompt_token_ids=prompt_ids[0])]
vllm_results = llm.generate(vllm_token_prompts, sampling_params)
vllm_output_ids = vllm_results[0].outputs[0].token_ids
) as vllm_model:
# Generate max_tokens new tokens deterministically.
vllm_outputs = vllm_model.generate_greedy(prompt_ids, max_tokens)
# --- HuggingFace generation ---
with torch.no_grad():
hf_model = AutoModelForCausalLM.from_pretrained(model)
with hf_runner(
model_name=model,
) as hf_model:
hf_outputs = hf_model.generate_greedy(prompt_ids, max_tokens)
# HF expects a tensor of input ids shaped (batch, seq_len).
hf_input_tokens = torch.tensor(prompt_ids[0]).unsqueeze(0)
# Generate max_tokens new tokens deterministically.
hf_generated = hf_model.generate(
hf_input_tokens,
do_sample=False,
min_new_tokens=max_tokens,
max_new_tokens=max_tokens,
)
# HF returns the prompt + generated tokens. Slice off the prompt.
hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]) :]
# vLLM and HF runners return prompt + generated tokens. Slice off the prompt.
vllm_output_ids = vllm_outputs[0][0][prompt_len:]
hf_output_ids = hf_outputs[0][0][prompt_len:]
# check that exactly max_tokens tokens were generated with vLLM and HF
assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens