@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
 from torch import nn
-from transformers import PaliGemmaConfig
+from transformers import BatchFeature, PaliGemmaConfig
 
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputs, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptIndexTargets,
+                                        PromptInsertion, PromptReplacement,
+                                        PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
-from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
+from .interfaces import SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -46,79 +50,6 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                              PaliGemmaImageEmbeddingInputs]
 
 
-def get_max_paligemma_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-
-    return get_max_siglip_image_tokens(vision_config)
-
-
-def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
-                             mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    seq_data, ranges = dummy_seq_data_for_siglip(
-        vision_config,
-        seq_len,
-        num_images,
-        image_token_id=hf_config.image_token_index,
-    )
-
-    mm_data = dummy_image_for_siglip(vision_config, num_images)
-    return DummyData(seq_data, mm_data, ranges)
-
-
-def input_processor_for_paligemma(ctx: InputContext,
-                                  inputs: DecoderOnlyInputs):
-
-    """
-    The correct prompt format needs to be:
-    '<image>' * image_feature_size + '<bos>' + prompt + '\n'
-
-    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
-    """ # noqa
-
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-
-    tokenizer = cached_tokenizer_from_config(model_config)
-    image_feature_size = hf_config.text_config.num_image_tokens
-    image_token_str = tokenizer.decode(hf_config.image_token_index)
-    bos_token = tokenizer.decode(hf_config.bos_token_id)
-    image_token_str_pad = image_token_str * image_feature_size
-    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
-
-    orig_prompt = inputs.get("prompt")
-    orig_prompt_ids = inputs.get("prompt_token_ids")
-
-    if orig_prompt is not None and image_token_str in orig_prompt:
-        logger.warning(
-            "The image token '%s' was detected in the prompt and "
-            "will be removed. Please follow the proper prompt format"
-            " documented on HuggingFace.", image_token_str)
-        orig_prompt = orig_prompt.replace(image_token_str, "")
-        orig_prompt_ids.remove(hf_config.image_token_index)
-
-    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
-
-    # The PaliGemma 2 tokenizer does not include a starting BOS token
-    if orig_prompt_ids[0] != hf_config.bos_token_id:
-        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
-
-    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
-
-
 class PaliGemmaMultiModalProjector(nn.Module):
 
     def __init__(self, vision_hidden_size: int, projection_dim: int):
@@ -131,12 +62,140 @@ class PaliGemmaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
+class PaliGemmaProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(PaliGemmaConfig)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        return get_max_siglip_image_tokens(vision_config)
+
+
+class PaliGemmaDummyInputsBuilder(
+        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        max_image_size = vision_config.image_size
+
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class PaliGemmaMultiModalProcessor(
+        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        tokenizer = self.info.get_tokenizer()
+        if not mm_data:
+            prompt_ids = tokenizer.encode(prompt)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        tokenizer = self.info.get_tokenizer()
+        num_image_tokens = self.info.get_num_image_tokens()
+        image_tokens = [image_token_id] * num_image_tokens
+
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
+
+        # Paligemma 1 and 2 have different tokenizer.add_bos_token
+        # Insert <image>*n + <bos> after <bos> for Paligemma 1
+        # Insert <image>*n + <bos> for Paligemma 2
+        return [
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.prefix(
+                    [bos_token_id] if tokenizer.add_bos_token else []),
+                insertion=PromptUpdateDetails(
+                    full=image_tokens + [bos_token_id],
+                    features=image_tokens,
+                ),
+            )
+        ]
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputs:
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+
+        tokenizer = self.info.get_tokenizer()
+        newline_prompt = "\n"
+        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
+        # Force a newline at the end of the prompt for PaliGemma's format
+        # This step can NOT be replaced by the current PromptUpdate methods
+        if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
+            prompt_token_ids.append(newline_token_id)
+            mm_inputs["prompt_token_ids"] = prompt_token_ids
+            mm_inputs["prompt"] += newline_prompt
+
+        return mm_inputs
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    PaliGemmaMultiModalProcessor,
+    info=PaliGemmaProcessingInfo,
+    dummy_inputs=PaliGemmaDummyInputsBuilder)
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP, SupportsV0Only):
+                                        SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",