[Frontend] Support configurable mm placeholder strings & flexible video sampling policies via CLI flags. (#20105)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
This commit is contained in:
@ -6,8 +6,8 @@ import os
|
||||
import uuid
|
||||
from asyncio import CancelledError
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@ -32,6 +32,8 @@ class RequestOutput:
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
use_async_output_proc = True
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
class MockEngine:
|
||||
|
||||
@ -231,6 +231,58 @@ def test_limit_mm_per_prompt_parser(arg, expected):
|
||||
assert args.limit_mm_per_prompt == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("arg", "expected"),
|
||||
[
|
||||
(None, dict()),
|
||||
('{"video": {"num_frames": 123} }', {
|
||||
"video": {
|
||||
"num_frames": 123
|
||||
}
|
||||
}),
|
||||
(
|
||||
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
|
||||
{
|
||||
"video": {
|
||||
"num_frames": 123,
|
||||
"fps": 1.0,
|
||||
"foo": "bar"
|
||||
},
|
||||
"image": {
|
||||
"foo": "bar"
|
||||
}
|
||||
}),
|
||||
])
|
||||
def test_media_io_kwargs_parser(arg, expected):
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
if arg is None:
|
||||
args = parser.parse_args([])
|
||||
else:
|
||||
args = parser.parse_args(["--media-io-kwargs", arg])
|
||||
|
||||
assert args.media_io_kwargs == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("arg", "expected"), [
|
||||
(None, dict()),
|
||||
('{"video":"<|video_placeholder|>"}', {
|
||||
"video": "<|video_placeholder|>"
|
||||
}),
|
||||
('{"video":"<|video_placeholder|>", "image": "<|image_placeholder|>"}', {
|
||||
"video": "<|video_placeholder|>",
|
||||
"image": "<|image_placeholder|>"
|
||||
}),
|
||||
])
|
||||
def test_mm_placeholder_str_override_parser(arg, expected):
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
if arg is None:
|
||||
args = parser.parse_args([])
|
||||
else:
|
||||
args = parser.parse_args(["--mm-placeholder-str-override", arg])
|
||||
|
||||
assert args.mm_placeholder_str_override == expected
|
||||
|
||||
|
||||
def test_compilation_config():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
@ -40,6 +40,8 @@ class MockModelConfig:
|
||||
allowed_local_media_path: str = ""
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@ -167,14 +167,14 @@ async def test_fetch_image_error_conversion():
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
async def test_fetch_video_http(video_url: str, num_frames: int):
|
||||
connector = MediaConnector()
|
||||
connector = MediaConnector(
|
||||
media_io_kwargs={"video": {
|
||||
"num_frames": num_frames,
|
||||
}})
|
||||
|
||||
video_sync = connector.fetch_video(video_url, num_frames=num_frames)
|
||||
video_async = await connector.fetch_video_async(video_url,
|
||||
num_frames=num_frames)
|
||||
# Check that the video frames are equal and metadata are same
|
||||
video_sync = connector.fetch_video(video_url)
|
||||
video_async = await connector.fetch_video_async(video_url)
|
||||
assert np.array_equal(video_sync[0], video_async[0])
|
||||
assert video_sync[1] == video_async[1]
|
||||
|
||||
|
||||
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
|
||||
|
||||
@ -4,7 +4,10 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader
|
||||
from vllm import envs
|
||||
from vllm.multimodal.image import ImageMediaIO
|
||||
from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
|
||||
VideoMediaIO)
|
||||
|
||||
NUM_FRAMES = 10
|
||||
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
|
||||
@ -40,3 +43,46 @@ def test_video_loader_registry():
|
||||
def test_video_loader_type_doesnt_exist():
|
||||
with pytest.raises(AssertionError):
|
||||
VIDEO_LOADER_REGISTRY.load("non_existing_video_loader")
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps")
|
||||
class Assert10Frames1FPSVideoLoader(VideoLoader):
|
||||
|
||||
@classmethod
|
||||
def load_bytes(cls,
|
||||
data: bytes,
|
||||
num_frames: int = -1,
|
||||
fps: float = -1.0,
|
||||
**kwargs) -> npt.NDArray:
|
||||
assert num_frames == 10, "bad num_frames"
|
||||
assert fps == 1.0, "bad fps"
|
||||
return FAKE_OUTPUT_2
|
||||
|
||||
|
||||
def test_video_media_io_kwargs():
|
||||
envs.VLLM_VIDEO_LOADER_BACKEND = "assert_10_frames_1_fps"
|
||||
imageio = ImageMediaIO()
|
||||
|
||||
# Verify that different args pass/fail assertions as expected.
|
||||
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
|
||||
_ = videoio.load_bytes(b"test")
|
||||
|
||||
videoio = VideoMediaIO(
|
||||
imageio, **{
|
||||
"num_frames": 10,
|
||||
"fps": 1.0,
|
||||
"not_used": "not_used"
|
||||
})
|
||||
_ = videoio.load_bytes(b"test")
|
||||
|
||||
with pytest.raises(AssertionError, match="bad num_frames"):
|
||||
videoio = VideoMediaIO(imageio, **{})
|
||||
_ = videoio.load_bytes(b"test")
|
||||
|
||||
with pytest.raises(AssertionError, match="bad num_frames"):
|
||||
videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
|
||||
_ = videoio.load_bytes(b"test")
|
||||
|
||||
with pytest.raises(AssertionError, match="bad fps"):
|
||||
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
|
||||
_ = videoio.load_bytes(b"test")
|
||||
|
||||
@ -346,6 +346,12 @@ class ModelConfig:
|
||||
limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
|
||||
"""Maximum number of data items per modality per prompt. Only applicable
|
||||
for multimodal models."""
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
"""Additional args passed to process media inputs, keyed by modalities.
|
||||
For example, to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
"""Optionally override placeholder string for given modalities."""
|
||||
use_async_output_proc: bool = True
|
||||
"""Whether to use async output processor."""
|
||||
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
|
||||
@ -694,6 +700,8 @@ class ModelConfig:
|
||||
if self.registry.is_multimodal_model(self.architectures):
|
||||
return MultiModalConfig(
|
||||
limit_per_prompt=self.limit_mm_per_prompt,
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
mm_placeholder_str_override=self.mm_placeholder_str_override,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
disable_mm_preprocessor_cache=self.
|
||||
disable_mm_preprocessor_cache)
|
||||
@ -3063,6 +3071,14 @@ class MultiModalConfig:
|
||||
`{"images": 16, "videos": 2}`
|
||||
"""
|
||||
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
"""Additional args passed to process media inputs, keyed by modalities.
|
||||
For example, to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
|
||||
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
"""Optionally override placeholder string for given modalities."""
|
||||
|
||||
mm_processor_kwargs: Optional[dict[str, object]] = None
|
||||
"""
|
||||
Overrides for the multi-modal processor obtained from
|
||||
|
||||
@ -369,6 +369,11 @@ class EngineArgs:
|
||||
get_field(TokenizerPoolConfig, "extra_config")
|
||||
limit_mm_per_prompt: dict[str, int] = \
|
||||
get_field(MultiModalConfig, "limit_per_prompt")
|
||||
media_io_kwargs: dict[str, dict[str,
|
||||
Any]] = get_field(MultiModalConfig,
|
||||
"media_io_kwargs")
|
||||
mm_placeholder_str_override: dict[str, str] = \
|
||||
get_field(MultiModalConfig, "mm_placeholder_str_override")
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = \
|
||||
MultiModalConfig.mm_processor_kwargs
|
||||
disable_mm_preprocessor_cache: bool = \
|
||||
@ -745,6 +750,11 @@ class EngineArgs:
|
||||
)
|
||||
multimodal_group.add_argument("--limit-mm-per-prompt",
|
||||
**multimodal_kwargs["limit_per_prompt"])
|
||||
multimodal_group.add_argument("--media-io-kwargs",
|
||||
**multimodal_kwargs["media_io_kwargs"])
|
||||
multimodal_group.add_argument(
|
||||
"--mm-placeholder-str-override",
|
||||
**multimodal_kwargs["mm_placeholder_str_override"])
|
||||
multimodal_group.add_argument(
|
||||
"--mm-processor-kwargs",
|
||||
**multimodal_kwargs["mm_processor_kwargs"])
|
||||
@ -969,6 +979,8 @@ class EngineArgs:
|
||||
enable_prompt_embeds=self.enable_prompt_embeds,
|
||||
served_model_name=self.served_model_name,
|
||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
mm_placeholder_str_override=self.mm_placeholder_str_override,
|
||||
use_async_output_proc=not self.disable_async_output_proc,
|
||||
config_format=self.config_format,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
|
||||
@ -507,6 +507,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
|
||||
def _placeholder_str(self, modality: ModalityStr,
|
||||
current_count: int) -> Optional[str]:
|
||||
if modality in self._model_config.mm_placeholder_str_override:
|
||||
return self._model_config.mm_placeholder_str_override[modality]
|
||||
|
||||
# TODO: Let user specify how to insert image tokens into prompt
|
||||
# (similar to chat template)
|
||||
hf_config = self._model_config.hf_config
|
||||
@ -725,6 +728,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._tracker = tracker
|
||||
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
)
|
||||
|
||||
@ -763,7 +767,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
return self.parse_audio(audio_url)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
video = self._connector.fetch_video(video_url)
|
||||
video = self._connector.fetch_video(video_url=video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
self._add_placeholder(placeholder)
|
||||
@ -776,7 +780,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
self._tracker = tracker
|
||||
self._connector = MediaConnector(
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
@ -818,7 +823,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
return self.parse_audio(audio_url)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
video = self._connector.fetch_video_async(video_url)
|
||||
video = self._connector.fetch_video_async(video_url=video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
@ -83,6 +83,16 @@ class AudioResampler:
|
||||
|
||||
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(BytesIO(data), sr=None)
|
||||
|
||||
|
||||
@ -44,10 +44,16 @@ def convert_image_mode(image: Image.Image, to_mode: str):
|
||||
|
||||
class ImageMediaIO(MediaIO[Image.Image]):
|
||||
|
||||
def __init__(self, *, image_mode: str = "RGB") -> None:
|
||||
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_mode = image_mode
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load_bytes(self, data: bytes) -> Image.Image:
|
||||
image = Image.open(BytesIO(data))
|
||||
|
||||
@ -38,12 +38,15 @@ class MediaConnector:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
|
||||
connection: HTTPConnection = global_http_connection,
|
||||
*,
|
||||
allowed_local_media_path: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.media_io_kwargs: dict[str, dict[
|
||||
str, Any]] = media_io_kwargs if media_io_kwargs else {}
|
||||
self.connection = connection
|
||||
|
||||
if allowed_local_media_path:
|
||||
@ -149,7 +152,7 @@ class MediaConnector:
|
||||
"""
|
||||
Load audio from a URL.
|
||||
"""
|
||||
audio_io = AudioMediaIO()
|
||||
audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
|
||||
|
||||
return self.load_from_url(
|
||||
audio_url,
|
||||
@ -164,7 +167,7 @@ class MediaConnector:
|
||||
"""
|
||||
Asynchronously fetch audio from a URL.
|
||||
"""
|
||||
audio_io = AudioMediaIO()
|
||||
audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
|
||||
|
||||
return await self.load_from_url_async(
|
||||
audio_url,
|
||||
@ -183,7 +186,8 @@ class MediaConnector:
|
||||
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
image_io = ImageMediaIO(image_mode=image_mode)
|
||||
image_io = ImageMediaIO(image_mode=image_mode,
|
||||
**self.media_io_kwargs.get("image", {}))
|
||||
|
||||
try:
|
||||
return self.load_from_url(
|
||||
@ -206,7 +210,8 @@ class MediaConnector:
|
||||
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
image_io = ImageMediaIO(image_mode=image_mode)
|
||||
image_io = ImageMediaIO(image_mode=image_mode,
|
||||
**self.media_io_kwargs.get("image", {}))
|
||||
|
||||
try:
|
||||
return await self.load_from_url_async(
|
||||
@ -223,13 +228,14 @@ class MediaConnector:
|
||||
video_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
num_frames: int = 32,
|
||||
) -> npt.NDArray:
|
||||
"""
|
||||
Load video from a HTTP or base64 data URL.
|
||||
"""
|
||||
image_io = ImageMediaIO(image_mode=image_mode)
|
||||
video_io = VideoMediaIO(image_io, num_frames=num_frames)
|
||||
image_io = ImageMediaIO(image_mode=image_mode,
|
||||
**self.media_io_kwargs.get("image", {}))
|
||||
video_io = VideoMediaIO(image_io,
|
||||
**self.media_io_kwargs.get("video", {}))
|
||||
|
||||
return self.load_from_url(
|
||||
video_url,
|
||||
@ -242,15 +248,16 @@ class MediaConnector:
|
||||
video_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
num_frames: int = 32,
|
||||
) -> npt.NDArray:
|
||||
"""
|
||||
Asynchronously load video from a HTTP or base64 data URL.
|
||||
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
image_io = ImageMediaIO(image_mode=image_mode)
|
||||
video_io = VideoMediaIO(image_io, num_frames=num_frames)
|
||||
image_io = ImageMediaIO(image_mode=image_mode,
|
||||
**self.media_io_kwargs.get("image", {}))
|
||||
video_io = VideoMediaIO(image_io,
|
||||
**self.media_io_kwargs.get("video", {}))
|
||||
|
||||
return await self.load_from_url_async(
|
||||
video_url,
|
||||
|
||||
@ -54,7 +54,10 @@ class VideoLoader:
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||
def load_bytes(cls,
|
||||
data: bytes,
|
||||
num_frames: int = -1,
|
||||
**kwargs) -> npt.NDArray:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -102,7 +105,8 @@ class OpenCVVideoBackend(VideoLoader):
|
||||
@classmethod
|
||||
def load_bytes(cls,
|
||||
data: bytes,
|
||||
num_frames: int = -1) -> tuple[npt.NDArray, dict]:
|
||||
num_frames: int = -1,
|
||||
**kwargs) -> npt.NDArray:
|
||||
import cv2
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
@ -159,18 +163,26 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
|
||||
def __init__(
|
||||
self,
|
||||
image_io: ImageMediaIO,
|
||||
*,
|
||||
num_frames: int = 32,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_io = image_io
|
||||
self.num_frames = num_frames
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
|
||||
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
|
||||
|
||||
def load_bytes(self, data: bytes) -> npt.NDArray:
|
||||
return self.video_loader.load_bytes(data, self.num_frames)
|
||||
return self.video_loader.load_bytes(data,
|
||||
num_frames=self.num_frames,
|
||||
**self.kwargs)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
|
||||
if media_type.lower() == "video/jpeg":
|
||||
|
||||
Reference in New Issue
Block a user