[VLM] Report multi_modal_placeholders in output (#10407)
Signed-off-by: Linkun Chen <lkchen+anyscale@github.com>
This commit is contained in:
@ -8,13 +8,17 @@ from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
from mistral_common.multimodal import download_image
|
||||
from mistral_common.protocol.instruct.messages import ImageURLChunk
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
|
||||
from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
|
||||
TextPrompt, TokensPrompt)
|
||||
from vllm.multimodal import MultiModalDataBuiltins
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sequence import Logprob, SampleLogprobs
|
||||
|
||||
from ....utils import VLLM_PATH, large_gpu_test
|
||||
@ -49,6 +53,20 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
|
||||
}]
|
||||
|
||||
|
||||
def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]:
|
||||
return [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"content": PROMPT,
|
||||
}, *({
|
||||
"type": "image",
|
||||
"image": download_image(url)
|
||||
} for url in urls)],
|
||||
}]
|
||||
|
||||
|
||||
def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
|
||||
msg = _create_msg_format(urls)
|
||||
|
||||
@ -70,6 +88,23 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
|
||||
return engine_inputs
|
||||
|
||||
|
||||
def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt:
|
||||
msg = _create_msg_format_hf(urls)
|
||||
|
||||
tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
|
||||
prompt = tokenizer.apply_chat_template(msg)
|
||||
|
||||
images = []
|
||||
for chunk in msg[0]["content"]:
|
||||
if chunk["type"] == "image":
|
||||
images.append(chunk["image"])
|
||||
|
||||
mm_data = MultiModalDataBuiltins(image=images)
|
||||
engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data)
|
||||
|
||||
return engine_inputs
|
||||
|
||||
|
||||
MSGS = [
|
||||
_create_msg_format(IMG_URLS[:1]),
|
||||
_create_msg_format(IMG_URLS[:2]),
|
||||
@ -191,3 +226,45 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
|
||||
outputs_1_lst=logprobs,
|
||||
name_0="h100_ref",
|
||||
name_1="output")
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=24)
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,expected_ranges",
|
||||
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{
|
||||
"offset": 10,
|
||||
"length": 494
|
||||
}]),
|
||||
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{
|
||||
"offset": 10,
|
||||
"length": 266
|
||||
}, {
|
||||
"offset": 276,
|
||||
"length": 1056
|
||||
}, {
|
||||
"offset": 1332,
|
||||
"length": 418
|
||||
}])])
|
||||
def test_multi_modal_placeholders(
|
||||
vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
|
||||
with vllm_runner(
|
||||
"mistral-community/pixtral-12b",
|
||||
max_model_len=8192,
|
||||
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
|
||||
) as vllm_model:
|
||||
outputs = vllm_model.model.generate(prompt)
|
||||
|
||||
assert len(outputs) == 1, f"{len(outputs)=}"
|
||||
output: RequestOutput = outputs[0]
|
||||
assert hasattr(output,
|
||||
"multi_modal_placeholders"), f"{output.__dict__=}"
|
||||
assert "image" in output.multi_modal_placeholders, \
|
||||
f"{output.multi_modal_placeholders.keys()=}"
|
||||
image_placeholder_ranges: list[
|
||||
PlaceholderRange] = output.multi_modal_placeholders["image"]
|
||||
assert len(image_placeholder_ranges) == len(
|
||||
expected_ranges), f"{image_placeholder_ranges=}"
|
||||
for real_range, expected_range in zip(image_placeholder_ranges,
|
||||
expected_ranges):
|
||||
assert real_range == expected_range, \
|
||||
f"{real_range=} {expected_range=}"
|
||||
|
||||
Reference in New Issue
Block a user