[VLM] Qwen2.5-VL
@@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ
  * ✅︎
  * ✅︎
  * ✅︎
- * `Qwen2_5_VLForConditionalGeneration`
  * Qwen2.5-VL
  * T + I<sup>E+</sup> + V<sup>E+</sup>
  * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
  *
  * ✅︎
  * ✅︎
- * `UltravoxModel`
  * Ultravox
  * T + A<sup>E+</sup>
@@ -880,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf
A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
:::

:::{note}
To use the Qwen2.5-VL series models, you have to install the Hugging Face `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
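
Since the tests added later in this diff gate the HF implementation on `transformers>=4.49.0` (see the `pytest.mark.skipif` marker and `min_transformers_version="4.49"`), a quick sanity check after installing from source might look like the following sketch; the use of `packaging` for the version comparison is an assumption, not part of this change:

```python
# Sketch: verify the source install of transformers is new enough for Qwen2.5-VL.
import transformers
from packaging.version import Version  # assumption: packaging is available

assert Version(transformers.__version__) >= Version("4.49.0"), (
    "Qwen2.5-VL requires transformers>=4.49.0; install it from source.")
```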

### Pooling Models

See [this page](pooling-models) for more information on how to use pooling models.

@@ -531,6 +531,36 @@ def run_qwen2_vl(question: str, modality: str):
    return llm, prompt, stop_token_ids


# Qwen2.5-VL
def run_qwen2_5_vl(question: str, modality: str):

    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "aria": run_aria,
    "blip-2": run_blip2,
@@ -557,6 +587,7 @@ model_example_map = {
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
}
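
For context, the single-image example script consumes the helper above roughly as in the sketch below. This is not part of the diff: it assumes a PIL image bound to `image`, and `run_qwen2_5_vl` itself relies on the script-level `args` (for `disable_mm_preprocessor_cache`):

```python
from vllm import SamplingParams

# Build the engine, prompt, and stop tokens for a single-image request.
llm, prompt, stop_token_ids = run_qwen2_5_vl("What is in this image?", "image")

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

outputs = llm.generate(
    {
        "prompt": prompt,
        # `image` is assumed to be a PIL.Image loaded elsewhere.
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```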
@@ -392,6 +392,63 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
    )


def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData:
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              '`pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=32768 if process_vision_info is None else 4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role":
        "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        image_data, _ = process_vision_info(messages,
                                            return_video_sample_fps=False)

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=image_data,
        chat_template=None,
    )


model_example_map = {
    "aria": load_aria,
    "deepseek_vl_v2": load_deepseek_vl2,
@@ -404,6 +461,7 @@ model_example_map = {
    "pixtral_hf": load_pixtral_hf,
    "qwen_vl_chat": load_qwen_vl_chat,
    "qwen2_vl": load_qwen2_vl,
    "qwen2_5_vl": load_qwen2_5_vl,
}
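
As with the other loaders in the multi-image example, the returned `ModelRequestData` is then used roughly as follows. This is a sketch rather than part of the diff, and the image URLs are placeholders:

```python
from vllm import SamplingParams

req_data = load_qwen2_5_vl(
    "What do these two images have in common?",
    ["https://example.com/a.jpg", "https://example.com/b.jpg"],  # placeholder URLs
)

outputs = req_data.llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},
    },
    sampling_params=SamplingParams(max_tokens=64,
                                   stop_token_ids=req_data.stop_token_ids))
print(outputs[0].outputs[0].text)
```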
@@ -121,6 +121,8 @@ VLM_TEST_SETTINGS = {
                else ("half", "float")),
        marks=[pytest.mark.core_model],
    ),
    # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
    # once we have upgraded to transformers>=4.49.0.
    "qwen2_vl": VLMTestInfo(
        models=["Qwen/Qwen2-VL-2B-Instruct"],
        test_type=(
@@ -138,6 +140,26 @@ VLM_TEST_SETTINGS = {
        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
    ),
    "qwen2_5_vl": VLMTestInfo(
        models=["Qwen/Qwen2.5-VL-3B-Instruct"],
        test_type=(
            VLMTestType.IMAGE,
            VLMTestType.MULTI_IMAGE,
            VLMTestType.VIDEO
        ),
        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
        img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
        video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
        max_model_len=4096,
        max_num_seqs=2,
        auto_cls=AutoModelForVision2Seq,
        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
        marks=[pytest.mark.skipif(
            TRANSFORMERS_VERSION < "4.49.0",
            reason="HF model requires transformers>=4.49.0",
        ), pytest.mark.core_model, pytest.mark.cpu_model],
    ),
    #### Extended model tests
    "aria": VLMTestInfo(
        models=["rhymes-ai/Aria"],

@@ -161,6 +161,7 @@ def _test_processing_correctness(
    "nvidia/NVLM-D-72B",
    "Qwen/Qwen-VL-Chat",
    "Qwen/Qwen2-VL-2B-Instruct",
    "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2-Audio-7B-Instruct",
    "fixie-ai/ultravox-v0_3",
])

@@ -264,6 +264,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                        trust_remote_code=True),
    "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
    "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
                                                          min_transformers_version="4.49"), # noqa: E501
    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
                                     trust_remote_code=True),
    # [Encoder-decoder]

@@ -410,7 +410,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                return "<image>"
            if model_type == "mllama":
                return "<|image|>"
            if model_type == "qwen2_vl":
            if model_type in ("qwen2_vl", "qwen2_5_vl"):
                return "<|vision_start|><|image_pad|><|vision_end|>"
            if model_type == "molmo":
                return ""
@@ -430,7 +430,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                return "(<audio>./</audio>)"
            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "video":
            if model_type == "qwen2_vl":
            if model_type in ("qwen2_vl", "qwen2_5_vl"):
                return "<|vision_start|><|video_pad|><|vision_end|>"
            if model_type in ("minicpmo", "minicpmv"):
                return "(<video>./</video>)"
@@ -27,6 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.model_executor.custom_op import CustomOp

@@ -772,8 +773,12 @@ class MRotaryEmbedding(RotaryEmbedding):
        dtype: torch.dtype,
        mrope_section: Optional[List[int]] = None,
    ) -> None:
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings to 4 times to get
        # a larger cos and sin cache.
        self.cache_max_position_num = max_position_embeddings * 4
        super().__init__(head_size, rotary_dim, self.cache_max_position_num,
                         base, is_neox_style, dtype)

        self.mrope_section = mrope_section
        if self.mrope_section:
@@ -831,13 +836,10 @@ class MRotaryEmbedding(RotaryEmbedding):
    @staticmethod
    def get_input_positions(
        input_tokens: List[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[List[List[int]], torch.Tensor],
        video_grid_thw: Union[List[List[int]], torch.Tensor],
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
        second_per_grid_ts: Optional[List[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> Tuple[List[List[int]], int]:
@@ -845,16 +847,13 @@ class MRotaryEmbedding(RotaryEmbedding):

        llm_positions, mrope_position_delta = \
            MRotaryEmbedding.get_input_positions_tensor(
                input_tokens,
                image_grid_thw,
                video_grid_thw,
                image_token_id,
                video_token_id,
                vision_start_token_id,
                vision_end_token_id,
                spatial_merge_size,
                context_len,
                seq_len,
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                context_len=context_len,
                seq_len=seq_len,
            )

        return llm_positions.tolist(), mrope_position_delta
@@ -862,18 +861,22 @@ class MRotaryEmbedding(RotaryEmbedding):
    @staticmethod
    def get_input_positions_tensor(
        input_tokens: List[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[List[List[int]], torch.Tensor],
        video_grid_thw: Union[List[List[int]], torch.Tensor],
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
        second_per_grid_ts: Optional[List[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value."""

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config,
                                    "tokens_per_second", 1.0)

        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()
        if isinstance(video_grid_thw, torch.Tensor):
@@ -892,6 +895,7 @@ class MRotaryEmbedding(RotaryEmbedding):

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
@@ -915,9 +919,13 @@ class MRotaryEmbedding(RotaryEmbedding):
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts is not None:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st
@@ -927,8 +935,10 @@ class MRotaryEmbedding(RotaryEmbedding):
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w).flatten()
            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                       tokens_per_second).long().flatten()

            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
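
To make the new temporal scaling concrete, here is a small standalone illustration (not vLLM code) of the `t_index` expression added in the hunk above; the grid sizes and timing values are made up:

```python
import torch

# With 4 temporal grid positions, a 2x2 spatial grid, and
# video_second_per_grid_t * tokens_per_second = 4.0, the temporal positions
# become 0, 4, 8, 12 instead of 0, 1, 2, 3; this is why the cos/sin cache in
# MRotaryEmbedding.__init__ is enlarged to 4x max_position_embeddings.
llm_grid_t, llm_grid_h, llm_grid_w = 4, 2, 2
video_second_per_grid_t, tokens_per_second = 2.0, 2.0

t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
    -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
           tokens_per_second).long().flatten()
print(t_index)
# tensor([ 0,  0,  0,  0,  4,  4,  4,  4,  8,  8,  8,  8, 12, 12, 12, 12])
```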
vllm/model_executor/models/qwen2_5_vl.py: new file, 1133 lines (file diff suppressed because it is too large).
@@ -650,8 +650,8 @@ class Qwen2VisionTransformer(nn.Module):
        return loaded_params


class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
                                            dict[str, torch.Tensor]]):
class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
                                              dict[str, torch.Tensor]]):

    def __init__(self, data: dict, modality: str) -> None:
        super().__init__(data, modality)
@@ -683,26 +683,26 @@ class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
        return self.data


class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems):

    def __init__(self, data: dict) -> None:
        super().__init__(data, "image")


class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems):

    def __init__(self, data: dict) -> None:
        super().__init__(data, "video")


class Qwen2MultiModalDataParser(MultiModalDataParser):
class Qwen2VLMultiModalDataParser(MultiModalDataParser):

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return Qwen2EmbeddingItems(data, modality="image")
            return Qwen2VLEmbeddingItems(data, modality="image")

        return super()._parse_image_data(data)

@@ -711,7 +711,7 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return Qwen2EmbeddingItems(data, modality="video")
            return Qwen2VLEmbeddingItems(data, modality="video")

        return super()._parse_video_data(data)

@@ -948,7 +948,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):

    def _get_data_parser(self) -> MultiModalDataParser:
        return Qwen2MultiModalDataParser()
        return Qwen2VLMultiModalDataParser()

    def _get_prompt_replacements(
        self,

@@ -172,6 +172,7 @@ _MULTIMODAL_MODELS = {
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]

@@ -285,6 +285,7 @@ class GPUModelRunner:
            if self.model_config.uses_mrope:
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                for mm_input in self.requests[req_id].mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
@@ -292,6 +293,9 @@ class GPUModelRunner:
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])

                hf_config = self.model_config.hf_config

@@ -299,14 +303,10 @@ class GPUModelRunner:
                self.requests[req_id].mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        self.requests[req_id].prompt_token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        image_token_id=hf_config.image_token_id,
                        video_token_id=hf_config.video_token_id,
                        vision_start_token_id=hf_config.vision_start_token_id,
                        vision_end_token_id=hf_config.vision_end_token_id,
                        spatial_merge_size=hf_config.vision_config.
                        spatial_merge_size,
                        second_per_grid_ts=second_per_grid_ts,
                    )

            req_ids_to_add.append(req_id)

@@ -386,20 +386,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
                    "mrope embedding type requires multi-modal input mapper "
                    "returns 'image_grid_thw' or 'video_grid_thw'.")

                second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
                hf_config = self.runner.model_config.hf_config
                token_ids = seq_data.get_token_ids()

                mrope_positions, mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions(
                        token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        image_token_id=hf_config.image_token_id,
                        video_token_id=hf_config.video_token_id,
                        vision_start_token_id=hf_config.vision_start_token_id,
                        vision_end_token_id=hf_config.vision_end_token_id,
                        spatial_merge_size=hf_config.vision_config.
                        spatial_merge_size,
                        second_per_grid_ts=second_per_grid_ts,
                        context_len=computed_len,
                    )
                seq_data.mrope_position_delta = mrope_position_delta

@@ -702,6 +702,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
            hf_config = self.runner.model_config.hf_config

            inter_data.mrope_input_positions = [None] * inter_data.n_seqs
@@ -713,14 +714,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            mrope_input_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    hf_config=hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    second_per_grid_ts=second_per_grid_ts,
                    context_len=inter_data.context_lens[seq_idx],
                    seq_len=inter_data.seq_lens[seq_idx],
                )