[VLM] Qwen2.5-VL

This commit is contained in:
Roger Wang
2025-02-05 13:31:38 -08:00
committed by GitHub
parent 9a5b1554b4
commit bf3b79efb8
14 changed files with 1315 additions and 52 deletions

View File

@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
- * `Qwen2_5_VLForConditionalGeneration`
* Qwen2.5-VL
* T + I<sup>E+</sup> + V<sup>E+</sup>
* `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
*
* ✅︎
* ✅︎
- * `UltravoxModel`
* Ultravox
* T + A<sup>E+</sup>
@ -880,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf
A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
:::
:::{note}
To use Qwen2.5-VL series models, you have to install Huggingface `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
### Pooling Models
See [this page](pooling-models) for more information on how to use pooling models.

View File

@ -531,6 +531,36 @@ def run_qwen2_vl(question: str, modality: str):
return llm, prompt, stop_token_ids
# Qwen2.5-VL
def run_qwen2_5_vl(question: str, modality: str):
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
llm = LLM(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids
model_example_map = {
"aria": run_aria,
"blip-2": run_blip2,
@ -557,6 +587,7 @@ model_example_map = {
"pixtral_hf": run_pixtral_hf,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
}

View File

@ -392,6 +392,63 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
)
def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
process_vision_info = None
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
llm = LLM(
model=model_name,
max_model_len=32768 if process_vision_info is None else 4096,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages,
return_video_sample_fps=False)
return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=image_data,
chat_template=None,
)
model_example_map = {
"aria": load_aria,
"deepseek_vl_v2": load_deepseek_vl2,
@ -404,6 +461,7 @@ model_example_map = {
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
"qwen2_5_vl": load_qwen2_5_vl,
}

View File

@ -121,6 +121,8 @@ VLM_TEST_SETTINGS = {
else ("half", "float")),
marks=[pytest.mark.core_model],
),
# TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
# once we upgraded to transformers>=4.49.0.
"qwen2_vl": VLMTestInfo(
models=["Qwen/Qwen2-VL-2B-Instruct"],
test_type=(
@ -138,6 +140,26 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_vl": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.skipif(
TRANSFORMERS_VERSION < "4.49.0",
reason="HF model requires transformers>=4.49.0",
), pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],

View File

@ -161,6 +161,7 @@ def _test_processing_correctness(
"nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_3",
])

View File

@ -264,6 +264,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
trust_remote_code=True),
# [Encoder-decoder]

View File

@ -410,7 +410,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<image>"
if model_type == "mllama":
return "<|image|>"
if model_type == "qwen2_vl":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
@ -430,7 +430,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)"

View File

@ -27,6 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp
@ -772,8 +773,12 @@ class MRotaryEmbedding(RotaryEmbedding):
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
# a larger the cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4
super().__init__(head_size, rotary_dim, self.cache_max_position_num,
base, is_neox_style, dtype)
self.mrope_section = mrope_section
if self.mrope_section:
@ -831,13 +836,10 @@ class MRotaryEmbedding(RotaryEmbedding):
@staticmethod
def get_input_positions(
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[List[List[int]], int]:
@ -845,16 +847,13 @@ class MRotaryEmbedding(RotaryEmbedding):
llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
input_tokens,
image_grid_thw,
video_grid_thw,
image_token_id,
video_token_id,
vision_start_token_id,
vision_end_token_id,
spatial_merge_size,
context_len,
seq_len,
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)
return llm_positions.tolist(), mrope_position_delta
@ -862,18 +861,22 @@ class MRotaryEmbedding(RotaryEmbedding):
@staticmethod
def get_input_positions_tensor(
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@ -892,6 +895,7 @@ class MRotaryEmbedding(RotaryEmbedding):
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
@ -915,9 +919,13 @@ class MRotaryEmbedding(RotaryEmbedding):
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_second_per_grid_t = 1.0
if second_per_grid_ts is not None:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
@ -927,8 +935,10 @@ class MRotaryEmbedding(RotaryEmbedding):
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
tokens_per_second).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(

File diff suppressed because it is too large Load Diff

View File

@ -650,8 +650,8 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data, modality)
@ -683,26 +683,26 @@ class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
return self.data
class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "image")
class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "video")
class Qwen2MultiModalDataParser(MultiModalDataParser):
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="image")
return Qwen2VLEmbeddingItems(data, modality="image")
return super()._parse_image_data(data)
@ -711,7 +711,7 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="video")
return Qwen2VLEmbeddingItems(data, modality="video")
return super()._parse_video_data(data)
@ -948,7 +948,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
return Qwen2VLMultiModalDataParser()
def _get_prompt_replacements(
self,

View File

@ -172,6 +172,7 @@ _MULTIMODAL_MODELS = {
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]

View File

@ -285,6 +285,7 @@ class GPUModelRunner:
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
second_per_grid_ts = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
@ -292,6 +293,9 @@ class GPUModelRunner:
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.extend(
mm_input["second_per_grid_ts"])
hf_config = self.model_config.hf_config
@ -299,14 +303,10 @@ class GPUModelRunner:
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
second_per_grid_ts=second_per_grid_ts,
)
req_ids_to_add.append(req_id)

View File

@ -386,20 +386,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
mrope_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
second_per_grid_ts=second_per_grid_ts,
context_len=computed_len,
)
seq_data.mrope_position_delta = mrope_position_delta

View File

@ -702,6 +702,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
hf_config = self.runner.model_config.hf_config
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
@ -713,14 +714,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
second_per_grid_ts=second_per_grid_ts,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
)