[Model]: support KeyeVL-1_5-8B (#23838)

Signed-off-by: wangruitao <wangruitao@kuaishou.com>
Co-authored-by: wangruitao <wangruitao@kuaishou.com>
Kwai-Keye
2025-09-01 18:50:27 +08:00
committed by GitHub
parent 3e330fcb21
commit 7c8271cd1e
9 changed files with 1123 additions and 278 deletions


@ -634,7 +634,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
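For quick reference, a minimal offline-inference sketch for the new table entry; it is not part of the diff, the chat-format tokens mirror the example scripts later in this commit, and the image path plus `max_model_len` are placeholders.

# Hedged sketch: load the newly supported checkpoint through vLLM's offline API.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="Kwai-Keye/Keye-VL-1_5-8B",
    trust_remote_code=True,
    max_model_len=8192,                 # placeholder; size to your hardware
    limit_mm_per_prompt={"image": 1},
)
prompt = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "Describe this image.<|im_end|>\n<|im_start|>assistant\n"
)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("example.jpg")}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)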


@ -683,6 +683,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
)
# Keye-VL-1.5
def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1.5-8B"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
trust_remote_code=True,
limit_mm_per_prompt={modality: 1},
)
if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"
prompts = [
(
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
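The `keye_vl1_5` key added to `model_example_map` below makes this function selectable from the example script. As a rough illustration (not part of the diff), the returned `ModelRequestData` could be consumed like this; the image asset is a placeholder and the driver logic is simplified.

# Hedged sketch of a driver for the function above.
from dataclasses import asdict

from PIL import Image
from vllm import LLM, SamplingParams

req = run_keye_vl1_5(["What is shown in this picture?"], modality="image")
llm = LLM(**asdict(req.engine_args))
image = Image.open("example.jpg")  # placeholder asset
outputs = llm.generate(
    [{"prompt": p, "multi_modal_data": {"image": image}} for p in req.prompts],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)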
# Kimi-VL
def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@ -1648,6 +1679,7 @@ model_example_map = {
"interns1": run_interns1,
"internvl_chat": run_internvl,
"keye_vl": run_keye_vl,
"keye_vl1_5": run_keye_vl1_5,
"kimi_vl": run_kimi_vl,
"llama4": run_llama4,
"llava": run_llava,


@ -542,6 +542,43 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1_5-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
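As with the single-image example, a hedged sketch (not part of the diff) of how this multi-image `ModelRequestData` would be consumed; the URLs are placeholders.

# Hedged sketch: run the multi-image request end to end.
from dataclasses import asdict

from vllm import LLM, SamplingParams

req = load_keye_vl1_5(
    "What are the differences between these images?",
    ["https://example.com/a.jpg", "https://example.com/b.jpg"],  # placeholders
)
llm = LLM(**asdict(req.engine_args))
outputs = llm.generate(
    {"prompt": req.prompt, "multi_modal_data": {"image": req.image_data}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)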
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
@ -1209,6 +1246,7 @@ model_example_map = {
"interns1": load_interns1,
"internvl_chat": load_internvl,
"keye_vl": load_keye_vl,
"keye_vl1_5": load_keye_vl1_5,
"kimi_vl": load_kimi_vl,
"llama4": load_llama4,
"llava": load_llava,


@ -293,6 +293,7 @@ def _test_processing_correctness_one(
"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview",
"OpenGVLab/InternVL3_5-30B-A3B",
"Kwai-Keye/Keye-VL-8B-Preview",
"Kwai-Keye/Keye-VL-1_5-8B",
"moonshotai/Kimi-VL-A3B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf",


@ -438,6 +438,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
trust_remote_code=True),
"KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501
trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),


@ -402,6 +402,15 @@ class MRotaryEmbedding(RotaryEmbedding):
context_len=context_len,
seq_len=seq_len,
)
elif "KeyeVL1_5" in hf_config.model_type:
return cls._keye_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
@ -636,6 +645,126 @@ class MRotaryEmbedding(RotaryEmbedding):
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _keye_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Keye series)."""
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
def split_thw(
grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w)).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
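To make the indexing above concrete, a small standalone sketch with toy numbers; it is independent of the class and assumes `spatial_merge_size=2`.

# Toy illustration of the Keye mrope layout: a video grid [2, 4, 4] is first
# split into per-frame rows [1, 4, 4]; with spatial_merge_size=2 each frame
# contributes a 1 x 2 x 2 block of (t, h, w) positions offset by the text length.
import torch

grid_thw = torch.tensor([[2, 4, 4]])
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
frames = torch.cat([torch.ones_like(hw[:, :1]), hw],
                   dim=1).repeat_interleave(t, dim=0)
print(frames.tolist())          # [[1, 4, 4], [1, 4, 4]]

merge = 2
text_len = 3                    # tokens preceding the vision placeholder
llm_t, llm_h, llm_w = 1, 4 // merge, 4 // merge
t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
block = torch.stack([t_idx, h_idx, w_idx]) + text_len
print(block)                    # 3 x 4 position block for the first frame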
@classmethod
def _vl_get_input_positions_tensor(
cls,


@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, Optional, Union
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
import numpy as np
import torch
@ -57,16 +58,13 @@ from .vision import get_vit_attn_backend
logger = init_logger(__name__)
_MAX_FRAMES_PER_VIDEO = 16
_MAX_IMAGE_SIZE = 9999999
def smart_resize(
height: int,
width: int,
factor: int = 28,
min_pixels: int = 28 * 28 * 130,
max_pixels: int = 28 * 28 * 1280,
factor: int,
min_pixels: int,
max_pixels: int,
):
if height < factor:
logger.warning(
@ -887,9 +885,9 @@ class Projector(nn.Module):
def forward(
self,
image_features: torch.Tensor,
image_features: Union[torch.Tensor, list[torch.Tensor]],
image_grid_thw: list[tuple[int, int, int]],
) -> torch.Tensor:
) -> Union[torch.Tensor, list[torch.Tensor]]:
m1, m2 = self.merge_kernel_size
if isinstance(image_features, (list, tuple)):
processed_features = list()
@ -986,6 +984,12 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
class KeyeProcessingInfo(BaseProcessingInfo):
def get_max_image_size(self) -> int:
return 9999999  # _MAX_IMAGE_SIZE
def get_max_frame_per_video(self) -> int:
return 16  # _MAX_FRAMES_PER_VIDEO
def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor
@ -1077,8 +1081,8 @@ class KeyeProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self, ) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=_MAX_IMAGE_SIZE,
image_height=_MAX_IMAGE_SIZE,
image_width=self.get_max_image_size(),
image_height=self.get_max_image_size(),
image_processor=None,
)
return max_image_size
@ -1123,7 +1127,7 @@ class KeyeProcessingInfo(BaseProcessingInfo):
max_image_tokens)
max_frames_per_video = min(
max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO,
self.get_max_frame_per_video(),
)
return max(max_frames_per_video, 1)
@ -1139,7 +1143,10 @@ class KeyeProcessingInfo(BaseProcessingInfo):
)
class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
_I = TypeVar("_I", bound=KeyeProcessingInfo)
class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@ -1183,6 +1190,10 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
return mm_data
class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
...
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
@ -1231,13 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
return _keye_field_config(hf_inputs)
@MULTIMODAL_REGISTRY.register_processor(
KeyeMultiModalProcessor,
info=KeyeProcessingInfo,
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
SupportsPP):
class BaseKeyeModule(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -1264,6 +1269,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
raise ValueError("Only image or video modality is supported")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
@ -1278,7 +1288,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = Projector(
self.mlp_AR = self._build_projector(
config,
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
@ -1294,13 +1305,287 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
@abstractmethod
def _build_projector(self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
raise ValueError("Need projector")
def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
name: str) -> torch.Tensor:
def _process_image_input(self,
image_input: Any) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list()
image_grid_hws = list()
sample_indices = list()
cu_seqlens = [0]
image_grid_thw = image_input["image_grid_thw"]
assert image_grid_thw.ndim == 2
for idx, thaw in enumerate(image_grid_thw):
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple)
image_grid_hws.append(thw_tuple)
image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
siglip_position_ids.append(image_position_ids)
sample_indices.append(torch.full((numel, ), idx,
dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel)
if image_input["type"] == "image_embeds":
raise ValueError(
"Image embeddings are not supported for this processing path.")
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids,
dim=0).to(pixel_values.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
pixel_values.device)
sample_indices = torch.concat(sample_indices,
dim=0).to(pixel_values.device)
image_embeds = self.visual(
pixel_values=pixel_values,
image_grid_thw=image_grid_hws,
position_ids=siglip_position_ids,
vision_return_embed_list=False,
interpolate_pos_encoding=True,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
return image_embeds
def _process_video_embeds(
self,
video_type: Literal["video_embeds", "pixel_values_videos"],
video_grid_thw: list[torch.Tensor],
pixel_values_videos: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
siglip_position_ids = list()
video_grid_hws = list()
sample_indices = list()
cu_seqlens = [0]
assert video_grid_thw.ndim == 2
for idx, sub_thw in enumerate(video_grid_thw):
thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple)
video_grid_hws.append(thw_tuple)
video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
siglip_position_ids.append(video_position_ids)
sample_indices.append(torch.full((numel, ), idx,
dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel)
if video_type == "video_embeds":
raise ValueError(
"Video embeddings are not supported for this processing path.")
else:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
pixel_values_videos.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
pixel_values_videos.device)
sample_indices = torch.concat(sample_indices,
dim=0).to(pixel_values_videos.device)
video_embeds = self.visual(
pixel_values=pixel_values_videos,
image_grid_thw=video_grid_hws,
position_ids=siglip_position_ids,
vision_return_embed_list=True,
interpolate_pos_encoding=True,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
return video_embeds
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
for input_key in kwargs:
if (input_key in ("pixel_values", "image_embeds")
and "images" not in modalities):
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if (input_key in ("pixel_values_videos", "video_embeds")
and "videos" not in modalities):
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
[
self.config.image_token_id,
self.config.video_token_id,
],
)
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[Any] = None,
video_input: Optional[Any] = None,
) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Keye-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
open-source models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,)`.
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
video_input=video_input,
)
input_ids = None
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""Get the module prefix in multimodal models."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="mlp_AR.",
tower_model="visual.",
)
@MULTIMODAL_REGISTRY.register_processor(
KeyeMultiModalProcessor,
info=KeyeProcessingInfo,
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def _build_projector(self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
return Projector(text_config, vision_config, quant_config, prefix)
def _validate_and_reshape_mm_tensor(
self, mm_input: NestedTensors,
name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
@ -1388,257 +1673,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
video_grid_thw=video_grid_thw,
)
def _process_image_input(
self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list()
image_grid_hws = list()
sample_indices = list()
cu_seqlens = [0]
image_grid_thw = image_input["image_grid_thw"]
assert image_grid_thw.ndim == 2
for idx, thaw in enumerate(image_grid_thw):
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple)
image_grid_hws.append(thw_tuple)
image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
siglip_position_ids.append(image_position_ids)
sample_indices.append(torch.full((numel, ), idx,
dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel)
if image_input["type"] == "image_embeds":
raise ValueError(
"Image embeddings are not supported for this processing path.")
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids,
dim=0).to(pixel_values.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
pixel_values.device)
sample_indices = torch.concat(sample_indices,
dim=0).to(pixel_values.device)
image_embeds = self.visual(
pixel_values=pixel_values,
image_grid_thw=image_grid_hws,
position_ids=siglip_position_ids,
vision_return_embed_list=False,
interpolate_pos_encoding=True,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
return image_embeds
def _process_video_input(
self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list()
video_grid_hws = list()
sample_indices = list()
cu_seqlens = [0]
video_type = video_input["type"]
video_grid_thw = video_input["video_grid_thw"]
assert video_grid_thw.ndim == 2
pixel_values_videos = video_input.get("pixel_values_videos", None)
for idx, thaw in enumerate(video_grid_thw):
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple)
video_grid_hws.append(thw_tuple)
video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
siglip_position_ids.append(video_position_ids)
sample_indices.append(torch.full((numel, ), idx,
dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel)
if video_input["type"] == "video_embeds":
raise ValueError(
"Video embeddings are not supported for this processing path.")
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
pixel_values_videos.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
pixel_values_videos.device)
sample_indices = torch.concat(sample_indices,
dim=0).to(pixel_values_videos.device)
video_embeds = self.visual(
pixel_values=pixel_values_videos,
image_grid_thw=video_grid_hws,
position_ids=siglip_position_ids,
vision_return_embed_list=True,
interpolate_pos_encoding=True,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw))
return video_embeds
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
for input_key in kwargs:
if (input_key in ("pixel_values", "image_embeds")
and "images" not in modalities):
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if (input_key in ("pixel_values_videos", "video_embeds")
and "videos" not in modalities):
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
[
self.config.image_token_id,
self.config.video_token_id,
],
)
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[KeyeImagePixelInputs] = None,
video_input: Optional[KeyeVideoPixelInputs] = None,
) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Qwen2-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
video_input=video_input,
)
input_ids = None
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""Get the module prefix in multimodal models."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="visual.",
tower_model="mlp_AR.",
)
return tuple(
self._process_video_embeds(video_type, video_grid_thw,
pixel_values_videos))


@ -0,0 +1,601 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig,
MultiModalKwargsItems, VideoItem)
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .keye import (BaseKeyeModule, BaseMultiModalProcessor,
KeyeBaseDummyInputsBuilder, KeyeProcessingInfo)
logger = init_logger(__name__)
def split_thw(grid_thw: torch.Tensor) -> torch.Tensor:
"""
Split grid_thw in t dimension.
Args:
grid_thw: [N, 3] tensor of [t, h, w]
Returns:
[Σt, 3] tensor where each row is [1, h, w]
Example:
>>> grid_thw = torch.tensor([[2, 3, 4], [1, 5, 6]])
>>> split_thw(grid_thw)
tensor([[1, 3, 4],
[1, 3, 4],
[1, 5, 6]])
"""
t = grid_thw[:, 0]
h_w = grid_thw[:, 1:]
ones = torch.ones_like(h_w[:, :1])
return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0)
def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int],
torch.Tensor]):
"""
Return the number of patches per video.
Args:
grid_thw: [N, 3] tensor of per-frame [t, h, w] grids.
num_frames: list or tensor giving how many rows of ``grid_thw``
belong to each video.
Returns:
1-D tensor with the total patch count of each video.
Examples:
>>> # Suppose there are 2 videos with a total of 3 grids
>>> grid_thw = torch.tensor([[2, 2, 2],  # grid 0: 2*2*2=8 patches
...                          [2, 2, 2],  # grid 1: 2*2*2=8 patches
...                          [1, 1, 1]]) # grid 2: 1*1*1=1 patch
>>> num_frames = [2, 1]  # first video has 2 grids, second has 1
>>> get_num_patches(grid_thw, num_frames)
tensor([16, 1])  # first video: 8+8=16 patches, second video: 1
"""
assert len(grid_thw.shape) == 2
if isinstance(num_frames, torch.Tensor):
num_frames = num_frames.clone().tolist()
num_grids_per_frame = grid_thw.prod(dim=1)
start_idx_per_video = [0, *itertools.accumulate(num_frames)]
num_patches = [
num_grids_per_frame[start_idx_per_video[i]:start_idx_per_video[i + 1]].
sum() for i in range(len(num_frames))
]
return torch.stack(num_patches) if num_patches else torch.zeros(
0, dtype=grid_thw.dtype, device=grid_thw.device)
class KeyeVL1_5ImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- c: Number of channels
- ps: Patch size
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values"]
pixel_values: Annotated[
torch.Tensor,
TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
class KeyeVL1_5ImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size (must match the hidden size of language model
backbone)
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["image_embeds"]
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
KeyeVL1_5ImageEmbeddingInputs]
class KeyeVL1_5VideoPixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- c: Number of channels
- ps: Patch size
- nv: Number of videos
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[
torch.Tensor,
TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
num_frames: torch.Tensor
class KeyeVL1_5VideoEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of video features
- hs: Hidden size (must match the hidden size of language model
backbone)
- nv: Number of videos
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["video_embeds"]
video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
num_frames: torch.Tensor
KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs,
KeyeVL1_5VideoEmbeddingInputs]
class KeyeVL1_5Projector(nn.Module):
def __init__(
self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.text_config = text_config
self.vision_config = vision_config
self.merge_kernel_size = (2, 2)
self.hidden_size = (self.vision_config.hidden_size *
self.merge_kernel_size[0] *
self.merge_kernel_size[1])
self.pre_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-05)
self.act = GELUActivation()
self.linear_1 = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_1",
)
self.linear_2 = RowParallelLinear(
self.hidden_size,
self.text_config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_2",
)
def forward(
self,
image_features: Union[torch.Tensor, tuple[torch.Tensor],
list[torch.Tensor]],
image_grid_thw: list[tuple[int, int, int]],
) -> Union[torch.Tensor, list[torch.Tensor]]:
m1, m2 = self.merge_kernel_size
if isinstance(image_features, (list, tuple)):
processed_features = list()
for image_feature, image_grid in zip(image_features,
image_grid_thw):
t, h, w = image_grid
image_feature = rearrange(
image_feature,
"(t h p1 w p2) d -> (t h w) (p1 p2 d)",
t=t,
h=h // m1,
p1=m1,
w=w // m2,
p2=m2,
)
image_feature = self.pre_norm(image_feature)
hidden_states, _ = self.linear_1(image_feature)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_2(hidden_states)
processed_features.append(hidden_states)
return processed_features
dims = image_features.shape[:-1]
dim = image_features.shape[-1]
image_features = image_features.view(np.prod(dims), dim)
hidden_states = self.pre_norm(image_features.view(
-1, self.hidden_size))
hidden_states, _ = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_2(hidden_states)
return hidden_states.view(*dims, -1)
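A standalone shape check of the `(2, 2)` merge kernel above, using toy dimensions; it only exercises the einops pattern, not the projector itself.

# Shape sketch: a 1 x 4 x 4 grid of vision features of width d collapses to
# t*(h//2)*(w//2) = 4 tokens of width 2*2*d.
import torch
from einops import rearrange

t, h, w, d = 1, 4, 4, 8
m1 = m2 = 2
feats = torch.randn(t * h * w, d)                      # (16, 8)
merged = rearrange(
    feats,
    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
    t=t, h=h // m1, p1=m1, w=w // m2, p2=m2,
)
print(merged.shape)                                    # torch.Size([4, 32])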
class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
def get_max_frame_per_video(self) -> int:
return 2048
def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]:
return {"image": None, "video": 1}
def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ):
image_grid_thw = hf_inputs.get("image_grid_thw",
torch.empty((0, 3), dtype=torch.int64))
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw",
torch.empty((0, 3), dtype=torch.int64))
video_grid_thw = split_thw(video_grid_thw)
num_frames = hf_inputs.get("num_frames",
video_grid_thw[:, 0]).clone().tolist()
video_num_patches = get_num_patches(video_grid_thw, num_frames)
video_num_grids = []
if len(num_frames) > 0:
i = 0
j = 1
cur_frames = num_frames[i]
for t, _, _ in video_grid_thw.tolist():
cur_frames -= t
if cur_frames == 0:
video_num_grids.append(j)
i += 1
if i < len(num_frames):
cur_frames = num_frames[i]
j = 1
else:
j += 1
video_num_grids = torch.tensor(video_num_grids)
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches),
video_grid_thw=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_grids),
num_frames=MultiModalFieldConfig.batched("video"))
class KeyeVL1_5MultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="image",
required_fields={
"image_embeds",
"image_grid_thw",
},
fields_factory=_keye_field_config,
)
return super()._parse_image_data(data)
def _parse_video_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="video",
required_fields={
"video_embeds",
"video_grid_thw",
},
fields_factory=_keye_field_config,
)
return super()._parse_video_data(data)
class KeyeVL1_5MultiModalProcessor(
BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return KeyeVL1_5MultiModalDataParser()
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token]
video_token_id = vocab[hf_processor.video_token]
placeholder = {"image": image_token_id, "video": video_token_id}
merge_length = image_processor.merge_size**2
out_mm_kwargs_data = out_mm_kwargs.get_data()
frame_types: list[torch.Tensor] = \
hf_processor_mm_kwargs.get("frame_types", None)
timestamps: list[torch.Tensor] = \
hf_processor_mm_kwargs.get("timestamps", None)
num_videos = mm_items.get_count("video", strict=False)
if frame_types is None:
frame_types = [None] * num_videos
assert len(frame_types) == num_videos, \
f"Number of frame_types={len(frame_types)} " \
f"doesn't equal to number of videos={num_videos}"
if timestamps is None:
timestamps = [None] * num_videos
assert len(timestamps) == num_videos, \
f"Number of timestamps={len(timestamps)} " \
f"doesn't equal to number of videos={num_videos}"
video_grid_thw = out_mm_kwargs_data.get(
'video_grid_thw', torch.empty((0, 3), dtype=torch.int64))
num_frames = out_mm_kwargs_data.get(
'num_frames', torch.tensor([], dtype=torch.int64))
assert len(num_frames) == num_videos, \
f"Size of num_frames={len(num_frames)} " \
f"doesn't equal to number of videos={num_videos}"
video_grid_hws = split_thw(video_grid_thw)
assert int(num_frames.sum().tolist()) == video_grid_hws.shape[0], (
f"The first dimension of `video_grid_hws`={video_grid_hws.shape[0]}"
f"doesn't equal to num of frames.")
cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()),
dim=-1)
def get_replacement_keye(item_idx: int, modality: str):
"""
Args:
item_idx(int): The item index of modality to replace
modality(str): The modality
"""
if modality == "image":
out_item = out_mm_kwargs[modality][item_idx]
grid_thw = out_item[f"{modality}_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
return [image_token_id] * num_tokens
elif modality == "video":
placeholders = []
video_timestamps = timestamps[item_idx]
video_frame_types = frame_types[item_idx]
grid_thw = video_grid_hws[
cu_seqlens[item_idx]:cu_seqlens[item_idx + 1]]
nframes = grid_thw.shape[0]
if video_timestamps is None:
video_timestamps = [""] * nframes
else:
video_timestamps = [
format(ts, ".1f") for ts in video_timestamps
]
if video_frame_types is None:
video_frame_types = [0] * nframes
for i, sub_thw in enumerate(grid_thw):
s = f"{hf_processor.frame_token}{video_timestamps[i]}"
if video_frame_types[i] == 1:
s += hf_processor.fast_start
placeholders.extend(tokenizer.encode(s))
num_frame_tokens = int(sub_thw.prod()) // merge_length
placeholders.extend([video_token_id] * num_frame_tokens)
if video_frame_types[i] == 1:
placeholders.append(vocab[hf_processor.fast_end])
return PromptUpdateDetails.select_token_id(
placeholders, embed_token_id=video_token_id)
else:
raise ValueError(f"Unsupported modality {modality}")
return [
PromptReplacement(
modality=modality,
target=[placeholder[modality]],
replacement=partial(get_replacement_keye, modality=modality),
) for modality in ("image", "video")
]
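For intuition, an illustrative sketch of the per-frame placeholder layout produced above; the token strings are stand-ins for `hf_processor.frame_token`, `fast_start` and `fast_end`, whose real values come from the HF processor.

# Illustrative only: placeholder text for a 2-frame video whose second frame
# is a "fast" frame; token strings and counts are hypothetical.
frame_token, fast_start, fast_end = "<|frame|>", "<|fast_start|>", "<|fast_end|>"
video_pad = "<|video_pad|>"
timestamps = [0.0, 0.5]
frame_types = [0, 1]
tokens_per_frame = 4        # grid.prod() // merge_size**2 for a toy grid

pieces = []
for ts, ftype in zip(timestamps, frame_types):
    pieces.append(f"{frame_token}{ts:.1f}")
    if ftype == 1:
        pieces.append(fast_start)
    pieces.append(video_pad * tokens_per_frame)
    if ftype == 1:
        pieces.append(fast_end)
print("".join(pieces))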
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _keye_field_config(hf_inputs)
class KeyeVL1_5DummyInputsBuilder(
KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo]):
...
@MULTIMODAL_REGISTRY.register_processor(
KeyeVL1_5MultiModalProcessor,
info=KeyeVL1_5ProcessingInfo,
dummy_inputs=KeyeVL1_5DummyInputsBuilder,
)
class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def _build_projector(self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
return KeyeVL1_5Projector(text_config, vision_config, quant_config,
prefix)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig = vllm_config.model_config.hf_config
self.merge_size = config.vision_config.spatial_merge_size
super().__init__(vllm_config=vllm_config, prefix=prefix)
def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
expected_dim: int, name: str):
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == expected_dim:
return mm_input
elif mm_input.ndim == expected_dim + 1:
return torch.concat(list(mm_input))
else:
raise ValueError(
f"{name} should be {expected_dim}D or "
f"batched {expected_dim}D tensor."
f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
else:
return torch.concat(list(mm_input))
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, expected_dim=4, name="image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, expected_dim=2, name="image grid_thw")
return KeyeVL1_5ImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, expected_dim=2, name="image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, expected_dim=2, name="image grid_thw")
return KeyeVL1_5ImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[KeyeVL1_5VideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
num_frames = kwargs.pop("num_frames", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos,
expected_dim=4,
name="video pixel values",
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, expected_dim=2, name="video grid_thw")
num_frames = self._validate_and_reshape_mm_tensor(
num_frames, expected_dim=1, name="video num frames")
return KeyeVL1_5VideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
num_frames=num_frames)
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, expected_dim=2, name="video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, expected_dim=2, name="video grid_thw")
return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
num_frames=num_frames)
def _process_video_input(
self,
video_input: KeyeVL1_5VideoInputs) -> tuple[torch.Tensor, ...]:
video_type = video_input["type"]
video_grid_thw = split_thw(video_input["video_grid_thw"])
pixel_values_videos = video_input.get("pixel_values_videos", None)
video_embeds = self._process_video_embeds(video_type, video_grid_thw,
pixel_values_videos)
video_embeds = torch.concat(video_embeds, dim=0)
num_frames = video_input["num_frames"].clone().tolist()
num_patches = get_num_patches(video_grid_thw, num_frames).tolist()
patch_cu_seqlens = torch.cumsum(
torch.tensor([0] + num_patches).detach().clone(), dim=-1)
patch_cu_seqlens = torch.div(patch_cu_seqlens,
self.merge_size**2,
rounding_mode="floor")
new_video_embeds = []
for idx in range(patch_cu_seqlens.shape[0] - 1):
start = patch_cu_seqlens[idx]
end = patch_cu_seqlens[idx + 1]
new_video_embeds.append(video_embeds[start:end])
return tuple(new_video_embeds)
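A quick numeric check of the per-video slicing above, with toy grids and an assumed `merge_size=2`; it reproduces the patch counting and cumulative-sum arithmetic only.

# Toy check: two 1x4x4 frames belong to one video, one 1x2x2 frame to a second;
# after 2x2 merging the concatenated embeds split into 8 and 1 tokens.
import torch

video_grid_thw = torch.tensor([[1, 4, 4], [1, 4, 4], [1, 2, 2]])
num_frames = [2, 1]
merge_size = 2

patches_per_frame = video_grid_thw.prod(dim=1)          # [16, 16, 4]
bounds = [0, *torch.cumsum(torch.tensor(num_frames), 0).tolist()]
num_patches = [int(patches_per_frame[s:e].sum())
               for s, e in zip(bounds, bounds[1:])]
cu = torch.cumsum(torch.tensor([0] + num_patches), dim=-1) // merge_size**2
print(num_patches, cu.tolist())                          # [32, 4] [0, 8, 9]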


@ -227,6 +227,7 @@ _MULTIMODAL_MODELS = {
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
"KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"), # noqa: E501
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),