committed by
GitHub
parent
775f00f81e
commit
d394787e52
64
tests/models/test_pixtral.py
Normal file
64
tests/models/test_pixtral.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_mistral.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
MODELS = ["mistralai/Pixtral-12B-2409"]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Model is too big, test passed on A100 locally but will OOM on CI machine."
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
image_urls = [
|
||||
"https://picsum.photos/id/237/200/300",
|
||||
"https://picsum.photos/seed/picsum/200/300"
|
||||
]
|
||||
expected = [
|
||||
"The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa
|
||||
"The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa
|
||||
]
|
||||
prompt = "Describe the image in one short sentence."
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
|
||||
|
||||
with vllm_runner(model, dtype=dtype,
|
||||
tokenizer_mode="mistral") as vllm_model:
|
||||
|
||||
for i, image_url in enumerate(image_urls):
|
||||
messages = [
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}]
|
||||
},
|
||||
]
|
||||
|
||||
outputs = vllm_model.model.chat(messages,
|
||||
sampling_params=sampling_params)
|
||||
assert outputs[0].outputs[0].text == expected[i]
|
||||
Reference in New Issue
Block a user