[Model] Upstream Deepseek-OCR model (#27247)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
@@ -639,6 +639,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ |
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
@@ -30,6 +30,7 @@ class ModelRequestData(NamedTuple):
    prompts: list[str]
    stop_token_ids: list[int] | None = None
    lora_requests: list[LoRARequest] | None = None
    sampling_params: list[SamplingParams] | None = None


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -153,23 +154,6 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
    )


# Dots-OCR
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions]
    engine_args = EngineArgs(
        model="rednote-hilab/dots.ocr",
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
@@ -217,6 +201,66 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
    )


def run_deepseek_ocr(questions: list[str], modality: str) -> ModelRequestData:
    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

    assert modality == "image"

    model_name = "deepseek-ai/DeepSeek-OCR"

    engine_args = EngineArgs(
        model=model_name,
        limit_mm_per_prompt={modality: 1},
        logits_processors=[NGramPerReqLogitsProcessor],
    )

    # DeepSeek-OCR uses a plain prompt template.
    prompts = [f"<image>\n{question}" for question in questions]

    # The following sampling params config is taken from
    # the official DeepSeek-OCR inference example.
    # (IMPORTANT) Use the custom logits processor and avoid skipping
    # special tokens for this model to get optimal OCR performance.
    sampling_params = [
        SamplingParams(
            temperature=0.0,
            max_tokens=8192,
            # ngram logit processor args
            extra_args=dict(
                ngram_size=30,
                window_size=90,
                # whitelist: <td>, </td>
                whitelist_token_ids={128821, 128822},
            ),
            skip_special_tokens=False,
        )
        for _ in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        sampling_params=sampling_params,
    )

# Dots-OCR
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions]
    engine_args = EngineArgs(
        model="rednote-hilab/dots.ocr",
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# Ernie4.5-VL
def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
    model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
@@ -1738,9 +1782,10 @@ model_example_map = {
    "bee": run_bee,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "dots_ocr": run_dots_ocr,
    "command_a_vision": run_command_a_vision,
    "deepseek_vl_v2": run_deepseek_vl2,
    "deepseek_ocr": run_deepseek_ocr,
    "dots_ocr": run_dots_ocr,
    "ernie45_vl": run_ernie45_vl,
    "fuyu": run_fuyu,
    "gemma3": run_gemma3,
@@ -2003,8 +2048,12 @@ def main(args):

    # We set temperature to 0.2 so that outputs can differ
    # even when all prompts are identical during batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    sampling_params = (
        SamplingParams(
            temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
        )
        if req_data.sampling_params is None
        else req_data.sampling_params
    )

    assert args.num_prompts > 0
@@ -585,6 +585,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
        transformers_version_reason="HF model is not compatible.",
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    ),
    "DeepseekOCRForCausalLM": _HfExamplesInfo(
        "deepseek-ai/DeepSeek-OCR",
    ),
    "DotsOCRForCausalLM": _HfExamplesInfo(
        "rednote-hilab/dots.ocr", trust_remote_code=True
    ),
vllm/model_executor/models/deepencoder.py (new file, 673 lines)
@@ -0,0 +1,673 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# adapted from
|
||||
# https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/sam_vary_sdpa.py
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPVisionConfig
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .clip import CLIPEncoder, CLIPVisionEmbeddings
|
||||
|
||||
|
||||
class MLPBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
mlp_dim: int,
|
||||
act: type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
||||
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.lin2(self.act(self.lin1(x)))
|
||||
|
||||
|
||||
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
||||
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
||||
|
||||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
class ImageEncoderViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
patch_size (int): Patch size.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
depth (int): Depth of ViT.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_abs_pos (bool): If True, use absolute positional embeddings.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks.
|
||||
global_attn_indexes (list): Indexes for blocks using global attention.
|
||||
""" # noqa: E501
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
kernel_size=(patch_size, patch_size),
|
||||
stride=(patch_size, patch_size),
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.pos_embed: nn.Parameter | None = None
|
||||
if use_abs_pos:
|
||||
# Initialize absolute positional embedding with pretrain image size.
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(
|
||||
1, img_size // patch_size, img_size // patch_size, embed_dim
|
||||
)
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
block = Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
window_size=window_size if i not in global_attn_indexes else 0,
|
||||
input_size=(img_size // patch_size, img_size // patch_size),
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
nn.Conv2d(
|
||||
out_chans,
|
||||
out_chans,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
)
|
||||
|
||||
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.net_3 = nn.Conv2d(
|
||||
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
|
||||
)
|
||||
|
||||
def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int):
|
||||
dtype = abs_pos.dtype
|
||||
|
||||
src_size = abs_pos.size(1)
|
||||
|
||||
if src_size != tgt_size:
|
||||
old_pos_embed = abs_pos.permute(0, 3, 1, 2)
|
||||
old_pos_embed = old_pos_embed.to(torch.float32)
|
||||
new_pos_embed = F.interpolate(
|
||||
old_pos_embed,
|
||||
size=(tgt_size, tgt_size),
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
align_corners=False,
|
||||
).to(dtype)
|
||||
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
|
||||
return new_pos_embed
|
||||
else:
|
||||
return abs_pos
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.get_abs_pos(self.pos_embed, x.size(1))
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
neck_output = self.neck(x.permute(0, 3, 1, 2))
|
||||
conv2_output = self.net_2(neck_output)
|
||||
conv3_output = self.net_3(conv2_output)
|
||||
|
||||
return conv3_output
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation
|
||||
blocks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
input_size: tuple[int, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks. If it equals 0, then
|
||||
use global attention.
|
||||
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
""" # noqa: E501
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = RelPosAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
input_size=input_size if window_size == 0 else (window_size, window_size),
|
||||
)
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLPBlock(
|
||||
embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
|
||||
)
|
||||
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
|
||||
x = self.attn(x)
|
||||
# Reverse window partition
|
||||
if self.window_size > 0:
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class RelPosAttention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
input_size: tuple[int, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
""" # noqa: E501
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert input_size is not None, (
|
||||
"Input size must be provided if using relative positional encoding."
|
||||
)
|
||||
# initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
||||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = (
|
||||
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
)
|
||||
# q, k, v with shape (B * nHead, H * W, C)
|
||||
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
||||
|
||||
rel_h, rel_w = None, None
|
||||
if self.use_rel_pos:
|
||||
rel_h, rel_w = add_decomposed_rel_pos(
|
||||
q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
|
||||
)
|
||||
|
||||
q = q.view(B, self.num_heads, H * W, -1)
|
||||
k = k.view(B, self.num_heads, H * W, -1)
|
||||
v = v.view(B, self.num_heads, H * W, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
rel_h = rel_h.view(
|
||||
B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)
|
||||
)
|
||||
rel_w = rel_w.view(
|
||||
B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)
|
||||
)
|
||||
attn_bias = (rel_h + rel_w).view(
|
||||
B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
|
||||
)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_bias
|
||||
)
|
||||
else:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
x = (
|
||||
x.view(B, self.num_heads, H, W, -1)
|
||||
.permute(0, 2, 3, 1, 4)
|
||||
.reshape(B, H, W, -1)
|
||||
)
|
||||
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(
|
||||
x: torch.Tensor, window_size: int
|
||||
) -> tuple[torch.Tensor, tuple[int, int]]:
|
||||
"""
|
||||
Partition into non-overlapping windows with padding if needed.
|
||||
Args:
|
||||
x (tensor): input tokens with [B, H, W, C].
|
||||
window_size (int): window size.
|
||||
|
||||
Returns:
|
||||
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
||||
(Hp, Wp): padded height and width before partition
|
||||
""" # noqa: E501
|
||||
B, H, W, C = x.shape
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = (
|
||||
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(
|
||||
windows: torch.Tensor,
|
||||
window_size: int,
|
||||
pad_hw: tuple[int, int],
|
||||
hw: tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and removing padding.
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
pad_hw (Tuple): padded height and width (Hp, Wp).
|
||||
hw (Tuple): original height and width (H, W) before padding.
|
||||
|
||||
Returns:
|
||||
x: unpartitioned sequences with [B, H, W, C].
|
||||
""" # noqa: E501
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(
|
||||
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
||||
)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
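# Illustrative sketch, not part of the upstream diff: a quick sanity check that
# window_partition followed by window_unpartition reproduces its input, including
# when padding is needed. Assumes this module is importable as added above.
import torch

from vllm.model_executor.models.deepencoder import window_partition, window_unpartition

x = torch.randn(2, 10, 13, 8)  # (B, H, W, C); H and W are not multiples of 7
windows, pad_hw = window_partition(x, window_size=7)  # pads to (14, 14), 4 windows per image
restored = window_unpartition(windows, 7, pad_hw, (10, 13))
assert torch.equal(restored, x)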
|
||||
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of
|
||||
query and key sizes.
|
||||
Args:
|
||||
q_size (int): size of query q.
|
||||
k_size (int): size of key k.
|
||||
rel_pos (Tensor): relative position embeddings (L, C).
|
||||
|
||||
Returns:
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
dtype = rel_pos.dtype
|
||||
rel_pos = rel_pos.to(torch.float32)
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
).to(dtype)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(
|
||||
k_size / q_size, 1.0
|
||||
)
|
||||
k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(
|
||||
q_size / k_size, 1.0
|
||||
)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: tuple[int, int],
|
||||
k_size: tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
||||
Args:
|
||||
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
||||
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
||||
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
||||
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
||||
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
attn (Tensor): attention map with added relative positional embeddings.
|
||||
""" # noqa: E501
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
rel_h = rel_h.unsqueeze(-1)
|
||||
rel_w = rel_w.unsqueeze(-2)
|
||||
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
|
||||
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
|
||||
|
||||
return rel_h, rel_w
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: tuple[int, int] = (16, 16),
|
||||
stride: tuple[int, int] = (16, 16),
|
||||
padding: tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
kernel_size (Tuple): kernel size of the projection layer.
|
||||
stride (Tuple): stride of the projection layer.
|
||||
padding (Tuple): padding size of the projection layer.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
# B C H W -> B H W C
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
return x
|
||||
|
||||
|
||||
# TODO(Isotr0py): use vision_config to build sam model
|
||||
def build_sam_vit_b():
|
||||
return _build_sam(
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
encoder_global_attn_indexes=[2, 5, 8, 11],
|
||||
)
|
||||
|
||||
|
||||
def _build_sam(
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
):
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
image_encoder = ImageEncoderViT(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
)
|
||||
return image_encoder
|
||||
|
||||
|
||||
class DeepCLIPVisionEmbeddings(CLIPVisionEmbeddings):
|
||||
def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int):
|
||||
# abs_pos: L, C
|
||||
# tgt_size: M
|
||||
# return: M, C
|
||||
dim = abs_pos.size(-1)
|
||||
abs_pos_new = abs_pos.squeeze(0)
|
||||
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
|
||||
|
||||
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
|
||||
tgt_size = int(math.sqrt(tgt_size))
|
||||
dtype = abs_pos.dtype
|
||||
|
||||
if src_size != tgt_size:
|
||||
old_pos_embed = (
|
||||
old_pos_embed.view(1, src_size, src_size, dim)
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
old_pos_embed = old_pos_embed.to(torch.float32)
|
||||
new_pos_embed = F.interpolate(
|
||||
old_pos_embed,
|
||||
size=(tgt_size, tgt_size),
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
align_corners=False,
|
||||
).to(dtype)
|
||||
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
|
||||
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
|
||||
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
|
||||
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
|
||||
return vision_pos_embed
|
||||
else:
|
||||
return abs_pos
|
||||
|
||||
def forward(
|
||||
self, pixel_values: torch.Tensor, patch_embeds: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
if patch_embeds is not None:
|
||||
patch_embeds = patch_embeds
|
||||
else:
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.get_abs_pos(
|
||||
self.position_embedding(self.position_ids), embeddings.size(1)
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
class DeepCLIPVisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = DeepCLIPVisionEmbeddings(config)
|
||||
|
||||
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
|
||||
# the original transformers code and name of the model weights.
|
||||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.transformer = CLIPEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=MultiHeadAttention,
|
||||
)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
if len(self.transformer.layers) > config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"The original encoder only has {num_hidden_layers} "
|
||||
f"layers, but you requested {len(self.transformer.layers)} layers."
|
||||
)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
patch_embeds: torch.Tensor | None = None,
|
||||
*,
|
||||
select_layers: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embeddings(pixel_values, patch_embeds)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
# Produces either the last layer output or all of the hidden states,
# depending on whether `select_layers` is provided.
|
||||
encoder_outputs = self.transformer(
|
||||
inputs_embeds=hidden_states,
|
||||
return_all_hidden_states=select_layers is not None,
|
||||
)
|
||||
return encoder_outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
vllm/model_executor/models/deepseek_ocr.py (new file, 594 lines)
@@ -0,0 +1,594 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature, CLIPVisionConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargs,
|
||||
NestedTensors,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
|
||||
from vllm.transformers_utils.processors.deepseek_ocr import (
|
||||
BASE_SIZE,
|
||||
CROP_MODE,
|
||||
IMAGE_SIZE,
|
||||
DeepseekOCRProcessor,
|
||||
count_tiles,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
AdapterLogitsProcessor,
|
||||
RequestLogitsProcessor,
|
||||
)
|
||||
|
||||
from .deepencoder import DeepCLIPVisionTransformer, build_sam_vit_b
|
||||
from .deepseek_vl2 import MlpProjector
|
||||
|
||||
# The image token id may vary across checkpoints.
_IMAGE_TOKEN = "<image>"
|
||||
|
||||
|
||||
class NoRepeatNGramLogitsProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
ngram_size: int,
|
||||
window_size: int,
|
||||
whitelist_token_ids: set[int] | None = None,
|
||||
):
|
||||
self.ngram_size = ngram_size
|
||||
self.window_size = window_size
|
||||
self.whitelist_token_ids = whitelist_token_ids or set()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
output_ids: list[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if len(output_ids) < self.ngram_size:
|
||||
return logits
|
||||
|
||||
current_prefix = tuple(output_ids[-(self.ngram_size - 1) :])
|
||||
|
||||
search_start = max(0, len(output_ids) - self.window_size)
|
||||
search_end = len(output_ids) - self.ngram_size + 1
|
||||
|
||||
banned_tokens = set()
|
||||
for i in range(search_start, search_end):
|
||||
ngram = tuple(output_ids[i : i + self.ngram_size])
|
||||
if ngram[:-1] == current_prefix:
|
||||
banned_tokens.add(ngram[-1])
|
||||
|
||||
banned_tokens = banned_tokens - self.whitelist_token_ids
|
||||
|
||||
if banned_tokens:
|
||||
logits[list(banned_tokens)] = -float("inf")
|
||||
|
||||
return logits
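# Illustrative sketch, not part of the upstream diff: exercising the n-gram ban
# logic above. Assumes the class is importable from the module added in this PR.
import torch

from vllm.model_executor.models.deepseek_ocr import NoRepeatNGramLogitsProcessor

proc = NoRepeatNGramLogitsProcessor(ngram_size=3, window_size=16)

# The trigram (5, 6, 7) already occurred and the current prefix is (5, 6),
# so token 7 must be banned at the next step.
output_ids = [1, 5, 6, 7, 2, 5, 6]
logits = torch.zeros(10)
out = proc(output_ids, logits)
assert out[7] == -float("inf")
assert torch.isfinite(out[:7]).all()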
|
||||
|
||||
|
||||
class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of overriding the wrapper class `__init__()` in order to utilize
|
||||
info about the device type"""
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
super().__init__(vllm_config, device, is_pin_memory)
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return True
|
||||
|
||||
def new_req_logits_processor(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> RequestLogitsProcessor | None:
|
||||
ngram_size = params.extra_args and params.extra_args.get("ngram_size")
|
||||
window_size = params.extra_args and params.extra_args.get("window_size", 100)
|
||||
whitelist_token_ids = params.extra_args and params.extra_args.get(
|
||||
"whitelist_token_ids", None
|
||||
)
|
||||
if ngram_size is None:
|
||||
return None
|
||||
if not isinstance(ngram_size, int) or ngram_size <= 0:
|
||||
raise ValueError(
|
||||
f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
|
||||
)
|
||||
if not isinstance(window_size, int) or window_size <= 0:
|
||||
raise ValueError(
|
||||
"`window_size` has to be a strictly positive integer, "
|
||||
f"got {window_size}."
|
||||
)
|
||||
if whitelist_token_ids is not None and not isinstance(
|
||||
whitelist_token_ids, Iterable
|
||||
):
|
||||
raise ValueError(
|
||||
"`whitelist_token_ids` has to be a set of integers, "
|
||||
f"got {whitelist_token_ids}."
|
||||
)
|
||||
else:
|
||||
whitelist_token_ids = (
|
||||
set(whitelist_token_ids) if whitelist_token_ids else None
|
||||
)
|
||||
return NoRepeatNGramLogitsProcessor(
|
||||
ngram_size=ngram_size,
|
||||
window_size=window_size,
|
||||
whitelist_token_ids=whitelist_token_ids,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekOCRProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(DeepseekVLV2Config)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self, *, image_width: int, image_height: int, cropping: bool = True
|
||||
) -> int:
|
||||
image_size = IMAGE_SIZE
|
||||
base_size = BASE_SIZE
|
||||
patch_size = 16
|
||||
downsample_ratio = 4
|
||||
|
||||
if CROP_MODE:
|
||||
if image_width <= 640 and image_height <= 640:
|
||||
crop_ratio = [1, 1]
|
||||
else:
|
||||
# find the closest aspect ratio to the target
|
||||
crop_ratio = count_tiles(
|
||||
image_width, image_height, image_size=IMAGE_SIZE
|
||||
)
|
||||
|
||||
num_width_tiles, num_height_tiles = crop_ratio
|
||||
else:
|
||||
num_width_tiles = num_height_tiles = 1
|
||||
|
||||
h = w = math.ceil((base_size // patch_size) / downsample_ratio)
|
||||
|
||||
h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)
|
||||
|
||||
global_views_tokens = h * (w + 1)
|
||||
if num_width_tiles > 1 or num_height_tiles > 1:
|
||||
local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
|
||||
else:
|
||||
local_views_tokens = 0
|
||||
|
||||
return global_views_tokens + local_views_tokens + 1
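# Worked example, not part of the upstream diff: the token count above for a
# 640x640 input under the constants used here (BASE_SIZE=1024, IMAGE_SIZE=640,
# patch_size=16, downsample_ratio=4, CROP_MODE=True). Inputs at or below 640x640
# keep a single 1x1 tile, so only the global view contributes.
import math

base_size, image_size = 1024, 640  # BASE_SIZE, IMAGE_SIZE
patch_size, downsample_ratio = 16, 4

h = w = math.ceil((base_size // patch_size) / downsample_ratio)  # 16
global_views_tokens = h * (w + 1)  # 16 * 17 = 272
local_views_tokens = 0  # single 1x1 tile contributes no local views
print(global_views_tokens + local_views_tokens + 1)  # 273 image tokens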
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
|
||||
return ImageSize(width=1024 * 2, height=1024 * 2)
|
||||
return ImageSize(width=640 * 2, height=640 * 2)
|
||||
|
||||
|
||||
class DeepseekOCRDummyInputsBuilder(BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
processor = self.info.get_hf_processor()
|
||||
image_token = processor.image_token
|
||||
|
||||
return image_token * num_images
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
max_image_size = self.info.get_image_size_with_most_features()
|
||||
|
||||
return {
|
||||
"image": self._get_dummy_images(
|
||||
width=max_image_size.width,
|
||||
height=max_image_size.height,
|
||||
num_images=num_images,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class DeepseekOCRMultiModalProcessor(
|
||||
BaseMultiModalProcessor[DeepseekOCRProcessingInfo]
|
||||
):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if mm_data:
|
||||
processed_outputs = self.info.ctx.call_hf_processor(
|
||||
self.info.get_hf_processor(**mm_kwargs),
|
||||
dict(prompt=prompt, **mm_data),
|
||||
mm_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
processed_outputs = tokenizer(
|
||||
prompt, add_special_tokens=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
images_spatial_crop=MultiModalFieldConfig.batched("image"),
|
||||
images_crop=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
image_token_id = hf_processor.image_token_id
|
||||
assert isinstance(image_token_id, int)
|
||||
|
||||
def get_replacement_deepseek_vl2(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
||||
)
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
num_image_tokens = images.get_feature_size(item_idx)
|
||||
else:
|
||||
size = images.get_image_size(item_idx)
|
||||
|
||||
num_image_tokens = self.info.get_num_image_tokens(
|
||||
image_width=size.width,
|
||||
image_height=size.height,
|
||||
cropping=CROP_MODE,
|
||||
)
|
||||
return [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement_deepseek_vl2,
|
||||
)
|
||||
]
|
||||
|
||||
# TODO(Isotr0py): Check if we still need this workaround for
|
||||
# deepseek-ocr processor.
|
||||
# def _cached_apply_hf_processor(
|
||||
# self,
|
||||
# prompt: str | list[int],
|
||||
# mm_data_items: MultiModalDataItems,
|
||||
# hf_processor_mm_kwargs: Mapping[str, object],
|
||||
# tokenization_kwargs: Mapping[str, object],
|
||||
# mm_uuids: MultiModalUUIDDict | None = None,
|
||||
# ) -> tuple[list[int], MultiModalKwargs, bool]:
|
||||
# # The processor logic is different for len(images) <= 2 vs > 2
|
||||
# # Since the processing cache assumes that the processor output is
|
||||
# # invariant of how many images are passed per prompt, we only
|
||||
# # perform caching for the most common case
|
||||
# if mm_data_items.get_count("image", strict=False) > 2:
|
||||
# # This code path corresponds to the cache being disabled
|
||||
# return self._apply_hf_processor_main(
|
||||
# prompt=prompt,
|
||||
# mm_items=mm_data_items,
|
||||
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
# enable_hf_prompt_update=True,
|
||||
# )
|
||||
|
||||
# return super()._cached_apply_hf_processor(
|
||||
# prompt=prompt,
|
||||
# mm_data_items=mm_data_items,
|
||||
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
# )
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
DeepseekOCRMultiModalProcessor,
|
||||
info=DeepseekOCRProcessingInfo,
|
||||
dummy_inputs=DeepseekOCRDummyInputsBuilder,
|
||||
)
|
||||
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# map prefix for language backbone
|
||||
"model.embed_tokens.": "language_model.model.embed_tokens.",
|
||||
"model.layers.": "language_model.model.layers.",
|
||||
"model.norm.": "language_model.model.norm.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
# remove "model." prefix for other components
|
||||
"model.": "",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.vision_config = config.vision_config
|
||||
self.projector_config = config.projector_config
|
||||
self.text_config = config.text_config
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
|
||||
|
||||
self.sam_model = build_sam_vit_b()
|
||||
clip_vision_config = CLIPVisionConfig(
|
||||
hidden_size=1024,
|
||||
intermediate_size=4096,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=24,
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
projection_dim=512,
|
||||
layer_norm_eps=1e-5,
|
||||
)
|
||||
self.vision_model = DeepCLIPVisionTransformer(
|
||||
config=clip_vision_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.projector = MlpProjector(self.projector_config)
|
||||
self.tile_tag = config.tile_tag
|
||||
self.global_view_pos = config.global_view_pos
|
||||
|
||||
# special token for image token sequence format
|
||||
n_embed = self.projector_config.n_embed
|
||||
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
|
||||
if self.tile_tag == "2D":
|
||||
# <|view_separator|>, <|\n|>
|
||||
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
|
||||
# NOTE: the "seperator" misspelling is kept to match the
# original implementation's weight names.
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
|
||||
)
|
||||
|
||||
if self.text_config.topk_method == "noaux_tc":
|
||||
architectures = ["DeepseekV3ForCausalLM"]
|
||||
elif not self.text_config.use_mla:
|
||||
architectures = ["DeepseekForCausalLM"]
|
||||
else:
|
||||
architectures = ["DeepseekV2ForCausalLM"]
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=self.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
architectures=architectures,
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
|
||||
images_crop = kwargs.pop("images_crop", None)
|
||||
|
||||
if pixel_values is None or torch.sum(pixel_values).item() == 0:
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
|
||||
)
|
||||
|
||||
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of image sizes. "
|
||||
f"Got type: {type(images_spatial_crop)}"
|
||||
)
|
||||
|
||||
if not isinstance(images_crop, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of image crop. Got type: {type(images_crop)}"
|
||||
)
|
||||
|
||||
return [pixel_values, images_crop, images_spatial_crop]
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
global_features_1 = self.sam_model(image_tensor)
|
||||
global_features_2 = self.vision_model(image_tensor, global_features_1)
|
||||
features = torch.cat(
|
||||
(
|
||||
global_features_2[:, 1:],
|
||||
global_features_1.flatten(2).permute(0, 2, 1),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
features = self.projector(features)
|
||||
|
||||
_, hw, dim = features.shape
|
||||
side = int(hw**0.5)
|
||||
|
||||
features = features.view(side, side, dim)
|
||||
newline = self.image_newline[None, None, :].expand(side, 1, dim)
|
||||
features = torch.cat([features, newline], dim=1)
|
||||
return features.view(-1, dim)
|
||||
|
||||
def _encode_local_features(
|
||||
self, patches: torch.Tensor, crop_shape: torch.Tensor
|
||||
) -> torch.Tensor | None:
|
||||
if torch.sum(patches).item() == 0:
|
||||
return None
|
||||
|
||||
local_features_1 = self.sam_model(patches)
|
||||
local_features_2 = self.vision_model(patches, local_features_1)
|
||||
features = torch.cat(
|
||||
(
|
||||
local_features_2[:, 1:],
|
||||
local_features_1.flatten(2).permute(0, 2, 1),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
features = self.projector(features)
|
||||
|
||||
_, hw, dim = features.shape
|
||||
patch_side = int(hw**0.5)
|
||||
|
||||
width_tiles = int(crop_shape[0].item())
|
||||
height_tiles = int(crop_shape[1].item())
|
||||
|
||||
features = (
|
||||
features.view(height_tiles, width_tiles, patch_side, patch_side, dim)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
.reshape(height_tiles * patch_side, width_tiles * patch_side, dim)
|
||||
)
|
||||
newline = self.image_newline[None, None, :].expand(
|
||||
height_tiles * patch_side, 1, dim
|
||||
)
|
||||
features = torch.cat([features, newline], dim=1)
|
||||
|
||||
return features.view(-1, dim)
|
||||
|
||||
def _pixel_values_to_embedding(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
images_crop: torch.Tensor,
|
||||
images_spatial_crop: torch.Tensor,
|
||||
) -> NestedTensors:
|
||||
images_in_this_batch = []
|
||||
|
||||
for jdx in range(images_spatial_crop.size(0)):
|
||||
patches = images_crop[jdx][0].to(torch.bfloat16)
|
||||
image_ori = pixel_values[jdx]
|
||||
crop_shape = images_spatial_crop[jdx][0]
|
||||
|
||||
global_features = self._encode_global_features(image_ori)
|
||||
local_features = self._encode_local_features(patches, crop_shape)
|
||||
|
||||
if local_features is not None:
|
||||
combined = torch.cat(
|
||||
[local_features, global_features, self.view_seperator[None, :]],
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
combined = torch.cat(
|
||||
[global_features, self.view_seperator[None, :]], dim=0
|
||||
)
|
||||
|
||||
images_in_this_batch.append(combined)
|
||||
|
||||
return images_in_this_batch
|
||||
|
||||
def _process_image_input(self, image_input) -> torch.Tensor:
|
||||
pixel_values = image_input[0].to(torch.bfloat16)
|
||||
images_crop = image_input[1]
|
||||
images_spatial_crop = image_input[2].to(dtype=torch.long)
|
||||
|
||||
vision_features = self._pixel_values_to_embedding(
|
||||
pixel_values=pixel_values,
|
||||
images_crop=images_crop,
|
||||
images_spatial_crop=images_spatial_crop,
|
||||
)
|
||||
|
||||
return vision_features
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object
|
||||
) -> MultiModalEmbeddings | None:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
return autoloaded_weights
|
||||
@@ -101,9 +101,10 @@ class MlpProjector(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = cfg
|
||||
self.projector_type = cfg.projector_type
|
||||
assert not cfg.token_pooling, "Token pooling is not supported currently."
|
||||
|
||||
if cfg.projector_type == "downsample_mlp_gelu":
|
||||
if self.projector_type == "downsample_mlp_gelu":
|
||||
mlp_depth = cfg.depth
|
||||
mlp_ratio = cfg.mlp_ratio
|
||||
modules = [
|
||||
@@ -120,7 +121,8 @@ class MlpProjector(nn.Module):
|
||||
modules.append(nn.GELU())
|
||||
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
|
||||
modules = nn.Sequential(*modules)
|
||||
|
||||
elif self.projector_type == "linear":
|
||||
modules = nn.Linear(cfg.input_dim, cfg.n_embed)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported projector type: {cfg.projector_type}"
|
||||
@@ -130,24 +132,25 @@ class MlpProjector(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
bs, hw, input_dim = x.shape
|
||||
h = w = int((hw) ** 0.5)
|
||||
"""compute padding"""
|
||||
if h % self.cfg.downsample_ratio:
|
||||
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
|
||||
else:
|
||||
pad = 0
|
||||
x = x.reshape(bs, h, w, input_dim)
|
||||
if pad > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
|
||||
"""4 to 1 concat"""
|
||||
x = x.permute(0, 3, 1, 2) # B, C, H, W
|
||||
x = F.unfold(
|
||||
x,
|
||||
kernel_size=self.cfg.downsample_ratio,
|
||||
stride=self.cfg.downsample_ratio,
|
||||
padding=0,
|
||||
) # B, C*4, HW // 4
|
||||
x = x.permute(0, 2, 1)
|
||||
if self.projector_type == "downsample_mlp_gelu":
|
||||
h = w = int((hw) ** 0.5)
|
||||
"""compute padding"""
|
||||
if h % self.cfg.downsample_ratio:
|
||||
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
|
||||
else:
|
||||
pad = 0
|
||||
x = x.reshape(bs, h, w, input_dim)
|
||||
if pad > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
|
||||
"""4 to 1 concat"""
|
||||
x = x.permute(0, 3, 1, 2) # B, C, H, W
|
||||
x = F.unfold(
|
||||
x,
|
||||
kernel_size=self.cfg.downsample_ratio,
|
||||
stride=self.cfg.downsample_ratio,
|
||||
padding=0,
|
||||
) # B, C*4, HW // 4
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
return self.layers(x)
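# Illustrative shape sketch, not part of the upstream diff: the "4 to 1 concat"
# in the projector forward above folds each downsample_ratio x downsample_ratio
# patch of vision tokens into one token with concatenated channels (illustrated
# here with downsample_ratio=4, i.e. 16 tokens -> 1).
import torch
import torch.nn.functional as F

bs, h, w, c = 2, 32, 32, 1024
x = torch.randn(bs, h, w, c).permute(0, 3, 1, 2)     # B, C, H, W
x = F.unfold(x, kernel_size=4, stride=4, padding=0)  # B, C*16, (H/4)*(W/4)
x = x.permute(0, 2, 1)
assert x.shape == (2, 64, 16384)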
|
||||
|
||||
|
||||
@@ -258,6 +258,7 @@ _MULTIMODAL_MODELS = {
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
@@ -34,6 +34,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
    "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
    "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
    "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
    "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja",
    "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
    "minicpmv": _get_minicpmv_chat_template_fallback,
    "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja",
@@ -0,0 +1,14 @@
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{% set system_message = '' -%}
{%- endif -%}

{{ bos_token + system_message }}
{%- for message in messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['content'] }}
{%- endfor -%}
vllm/transformers_utils/processors/deepseek_ocr.py (new file, 442 lines)
@@ -0,0 +1,442 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# adapted from https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/image_process.py
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image, ImageOps
|
||||
from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
# TODO(Isotr0py): change modes for variants
|
||||
# see: https://github.com/deepseek-ai/DeepSeek-OCR/blob/8cf003d38821fa1b19c73da3bd1b0dc262ea8136/DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py#L1-L6
|
||||
# Tiny: base_size = 512, image_size = 512, crop_mode = False
|
||||
# Small: base_size = 640, image_size = 640, crop_mode = False
|
||||
# Base: base_size = 1024, image_size = 1024, crop_mode = False
|
||||
# Large: base_size = 1280, image_size = 1280, crop_mode = False
|
||||
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
|
||||
BASE_SIZE = 1024
|
||||
IMAGE_SIZE = 640
|
||||
CROP_MODE = True
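# Hypothetical helper, not part of the upstream diff: the comment above lists the
# per-variant settings, and the TODO suggests they may become selectable instead
# of hard-coded constants. The values below are copied from that comment; the
# name `_RESOLUTION_MODES` is purely illustrative.
_RESOLUTION_MODES: dict[str, tuple[int, int, bool]] = {
    # name: (base_size, image_size, crop_mode)
    "tiny": (512, 512, False),
    "small": (640, 640, False),
    "base": (1024, 1024, False),
    "large": (1280, 1280, False),
    "gundam": (1024, 640, True),
}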
|
||||
|
||||
# TODO(Isotr0py): Expose as mm_kwargs
|
||||
MIN_CROPS = 2
|
||||
MAX_CROPS = 6  # max: 9; if GPU memory is limited, 6 is recommended
|
||||
|
||||
|
||||
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
||||
best_ratio_diff = float("inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
|
||||
def calculate_aspect_ratios(
|
||||
min_num: int = MIN_CROPS, max_num: int = MAX_CROPS
|
||||
) -> list[tuple[int, int]]:
|
||||
target_ratios: set[tuple[int, int]] = set(
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
)
|
||||
sorted_target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
return sorted_target_ratios
|
||||
|
||||
|
||||
def count_tiles(
    orig_width,
    orig_height,
    min_num=MIN_CROPS,
    max_num=MAX_CROPS,
    image_size=640,
    use_thumbnail=False,
):
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = calculate_aspect_ratios(min_num, max_num)

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    return target_aspect_ratio


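# dynamic_preprocess resizes the image to the chosen grid (image_size pixels per
# tile) and slices it into tiles. For example, a 1600x800 page with the defaults
# maps to the (2, 1) grid: it is resized to 1280x640 and split into two 640x640
# crops.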
def dynamic_preprocess(
    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = calculate_aspect_ratios(min_num, max_num)

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio


class ImageTransform:
    def __init__(
        self,
        mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
        std: tuple[float, float, float] = (0.5, 0.5, 0.5),
        normalize: bool = True,
    ):
        self.mean = mean
        self.std = std
        self.normalize = normalize

        transform_pipelines = [T.ToTensor()]

        if normalize:
            transform_pipelines.append(T.Normalize(mean, std))

        self.transform = T.Compose(transform_pipelines)

    def __call__(self, pil_img: Image.Image):
        x = self.transform(pil_img)
        return x


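# DeepseekOCRProcessor expands every <image> placeholder in the prompt into the
# image-token layout expected by the model and returns a BatchFeature holding
# input_ids, the padded 1024x1024 global view (pixel_values), the optional
# 640x640 local crops (images_crop), the image-token mask (images_seq_mask),
# the per-image tile grid (images_spatial_crop) and the per-image token counts
# (num_image_tokens).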
class DeepseekOCRProcessor(ProcessorMixin):
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    attributes = ["tokenizer"]

    def __init__(
        self,
        tokenizer: LlamaTokenizerFast,
        patch_size: int = 16,
        downsample_ratio: int = 4,
        image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
        image_std: tuple[float, float, float] = (0.5, 0.5, 0.5),
        normalize: bool = True,
        image_token: str = "<image>",
        pad_token: str = "<|▁pad▁|>",
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_size = IMAGE_SIZE
        self.base_size = BASE_SIZE
        self.patch_size = 16
        self.image_mean = image_mean
        self.image_std = image_std
        self.normalize = normalize
        self.downsample_ratio = 4

        self.image_transform = ImageTransform(
            mean=image_mean, std=image_std, normalize=normalize
        )

        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "left"  # must set this; padding side makes a difference in batch inference  # noqa: E501

        # add the pad_token as special token to use 'tokenizer.pad_token'
        # and 'tokenizer.pad_token_id'
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": pad_token})

        # add image token
        self.image_token_id = self.tokenizer.vocab.get(image_token)
        self.image_token = image_token
        self.pad_token = pad_token
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            tokenizer,
            **kwargs,
        )

    @property
    def bos_id(self):
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        return self.tokenizer.eos_token_id

    @property
    def pad_id(self):
        return self.tokenizer.pad_token_id

    def encode(self, text: str, bos: bool = True, eos: bool = False):
        t = self.tokenizer.encode(text, add_special_tokens=False)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: list[int], **kwargs) -> str:
        return self.tokenizer.decode(t, **kwargs)

    def process_one(
        self,
        prompt: str,
        images: list[Image.Image],
        crop_mode: bool = CROP_MODE,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            crop_mode (bool): if True, then crop the image;

        Returns:
            prepare (BatchFeature): the output of the processor, containing
                - input_ids (torch.LongTensor): [1, N + image tokens]
                - pixel_values (torch.FloatTensor): [n_images, 3, base_size, base_size]
                - images_crop (torch.FloatTensor): [1, n_crops, 3, image_size, image_size]
                - images_seq_mask (torch.BoolTensor): True at image-token positions
                - images_spatial_crop (torch.LongTensor): [n_images, 2], tile grid per image
                - num_image_tokens (List[int]): the number of image tokens per image
        """

        assert prompt is not None and images is not None, (
            "prompt and images must be used at the same time."
        )

        sft_format = prompt

        (
            input_ids,
            pixel_values,
            images_crop,
            images_seq_mask,
            images_spatial_crop,
            num_image_tokens,
            _,
        ) = self.tokenize_with_images(
            conversation=sft_format,
            images=images,
            bos=True,
            eos=True,
            cropping=crop_mode,
        )

        prepare = BatchFeature(
            data=dict(
                input_ids=input_ids,
                pixel_values=pixel_values,
                images_crop=images_crop,
                images_seq_mask=images_seq_mask,
                images_spatial_crop=images_spatial_crop,
                num_image_tokens=num_image_tokens,
            ),
            tensor_type="pt",
        )
        return prepare

    def __call__(
        self,
        *,
        prompt: str,
        images: list[Image.Image],
        crop_mode: bool = CROP_MODE,
        **kwargs,
    ):
        prepare = self.process_one(
            prompt=prompt,
            images=images,
            crop_mode=crop_mode,
        )

        return prepare

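    # tokenize_with_images interleaves text tokens with image tokens. With the
    # defaults (BASE_SIZE=1024, IMAGE_SIZE=640, patch_size=16,
    # downsample_ratio=4), num_queries_base = 16 and num_queries = 10, so the
    # padded global view always contributes (16 + 1) * 16 + 1 = 273 image
    # tokens, and a w x h local tile grid adds (10 * w + 1) * (10 * h) more
    # (e.g. a 2 x 1 grid adds 210 tokens, 483 in total for that image).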
    def tokenize_with_images(
        self,
        conversation: str,
        images: list[Image.Image],
        bos: bool = True,
        eos: bool = True,
        cropping: bool = True,
    ):
        """Tokenize text with <image> tags."""

        assert conversation.count(self.image_token) == len(images)
        text_splits = conversation.split(self.image_token)
        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
            [],
            [],
            [],
            [],
        )
        image_shapes = []
        num_image_tokens = []
        tokenized_str = []
        for text_sep, image in zip(text_splits, images):
            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            image_shapes.append(image.size)

            images_crop_raw = []
            if image.size[0] <= 640 and image.size[1] <= 640:
                crop_ratio = [1, 1]
            elif cropping:
                images_crop_raw, crop_ratio = dynamic_preprocess(
                    image, image_size=IMAGE_SIZE
                )
            else:
                crop_ratio = [1, 1]

            if self.image_size <= 640 and not cropping:
                image = image.resize((self.image_size, self.image_size))

            global_view = ImageOps.pad(
                image,
                (self.base_size, self.base_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            images_list.append(self.image_transform(global_view))

            num_width_tiles, num_height_tiles = crop_ratio
            images_spatial_crop.append([num_width_tiles, num_height_tiles])

            if num_width_tiles > 1 or num_height_tiles > 1:
                for cropped_image in images_crop_raw:
                    images_crop_list.append(self.image_transform(cropped_image))

            num_queries = math.ceil(
                (self.image_size // self.patch_size) / self.downsample_ratio
            )
            num_queries_base = math.ceil(
                (self.base_size // self.patch_size) / self.downsample_ratio
            )

            tokenized_image = (
                [self.image_token_id] * num_queries_base + [self.image_token_id]
            ) * num_queries_base
            tokenized_image += [self.image_token_id]
            if num_width_tiles > 1 or num_height_tiles > 1:
                local_row = [self.image_token_id] * (num_queries * num_width_tiles + 1)
                tokenized_image += local_row * (num_queries * num_height_tiles)
            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)
            num_image_tokens.append(len(tokenized_image))

        """process the last text split"""
        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        """add the bos and eos tokens"""
        if bos:
            tokenized_str = [self.bos_id] + tokenized_str
            images_seq_mask = [False] + images_seq_mask
        if eos:
            tokenized_str = tokenized_str + [self.eos_id]
            images_seq_mask = images_seq_mask + [False]

        assert len(tokenized_str) == len(images_seq_mask), (
            f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} "
            f"is not equal to images_seq_mask's length {len(images_seq_mask)}."
        )

        masked_tokenized_str = []
        for token_index in tokenized_str:
            if token_index != self.image_token_id:
                masked_tokenized_str.append(token_index)
            else:
                masked_tokenized_str.append(self.ignore_id)

        assert (
            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
        ), (
            f"tokenized_str's length {len(tokenized_str)}, "
            f"input_ids' length {len(masked_tokenized_str)}, "
            f"images_seq_mask's length {len(images_seq_mask)}, are not equal."
        )

        input_ids = torch.LongTensor(tokenized_str)
        target_ids = torch.LongTensor(masked_tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
            self.ignore_id
        )
        input_ids[input_ids < 0] = self.pad_id

        # Remove the ending eos token
        assert input_ids[-1] == self.eos_id
        input_ids = input_ids[:-1]
        target_ids = target_ids[:-1]
        images_seq_mask = images_seq_mask[:-1]

        if len(images_list) == 0:
            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
            images_crop = torch.zeros(
                (1, 3, self.image_size, self.image_size)
            ).unsqueeze(0)
        else:
            pixel_values = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
            else:
                images_crop = torch.zeros(
                    (1, 3, self.image_size, self.image_size)
                ).unsqueeze(0)

        input_ids = input_ids.unsqueeze(0)

        return (
            input_ids,
            pixel_values,
            images_crop,
            images_seq_mask,
            images_spatial_crop,
            num_image_tokens,
            image_shapes,
        )


AutoProcessor.register("DeepseekOCRProcessor", DeepseekOCRProcessor)
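# Example usage (illustrative sketch; assumes the tokenizer shipped with the
# deepseek-ai/DeepSeek-OCR checkpoint):
#
#     tokenizer = LlamaTokenizerFast.from_pretrained("deepseek-ai/DeepSeek-OCR")
#     processor = DeepseekOCRProcessor(tokenizer=tokenizer)
#     features = processor(prompt="<image>\nFree OCR.", images=[Image.open("page.png")])
#     # features["input_ids"]: [1, seq_len]
#     # features["pixel_values"]: [n_images, 3, 1024, 1024]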