[Model] Upstream Deepseek-OCR model (#27247)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Isotr0py
2025-10-22 22:59:15 +08:00
committed by GitHub
parent 3ae082c373
commit 675aa2ec64
10 changed files with 1821 additions and 40 deletions

View File

@@ -639,6 +639,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ |
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |

View File

@@ -30,6 +30,7 @@ class ModelRequestData(NamedTuple):
prompts: list[str]
stop_token_ids: list[int] | None = None
lora_requests: list[LoRARequest] | None = None
sampling_params: list[SamplingParams] | None = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -153,23 +154,6 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
)
# Dots-OCR
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions]
engine_args = EngineArgs(
model="rednote-hilab/dots.ocr",
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -217,6 +201,66 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
)
def run_deepseek_ocr(questions: list[str], modality: str) -> ModelRequestData:
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
assert modality == "image"
model_name = "deepseek-ai/DeepSeek-OCR"
engine_args = EngineArgs(
model=model_name,
limit_mm_per_prompt={modality: 1},
logits_processors=[NGramPerReqLogitsProcessor],
)
# deepseek-ocr uses a plain prompt template
prompts = [f"<image>\n{question}" for question in questions]
# The following sampling params config is taken from
# the official Deepseek-OCR inference example.
# (IMPORTANT) Use the custom logits processor and avoid skipping
# special tokens for this model for optimal OCR performance.
sampling_params = [
SamplingParams(
temperature=0.0,
max_tokens=8192,
# ngram logit processor args
extra_args=dict(
ngram_size=30,
window_size=90,
# whitelist: <td>, </td>
whitelist_token_ids={128821, 128822},
),
skip_special_tokens=False,
)
for _ in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
sampling_params=sampling_params,
)
# Dots-OCR
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions]
engine_args = EngineArgs(
model="rednote-hilab/dots.ocr",
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Ernie4.5-VL
def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
@@ -1738,9 +1782,10 @@ model_example_map = {
"bee": run_bee,
"blip-2": run_blip2,
"chameleon": run_chameleon,
"dots_ocr": run_dots_ocr,
"command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2,
"deepseek_ocr": run_deepseek_ocr,
"dots_ocr": run_dots_ocr,
"ernie45_vl": run_ernie45_vl,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
@@ -2003,8 +2048,12 @@ def main(args):
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
sampling_params = (
SamplingParams(
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)
if req_data.sampling_params is None
else req_data.sampling_params
)
assert args.num_prompts > 0
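For reference, here is a minimal end-to-end sketch of what the example script above wires together (not part of the commit; the image path and prompt text are illustrative, the engine and sampling arguments mirror the diff):

from PIL import Image
from vllm import LLM, SamplingParams
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

llm = LLM(
    model="deepseek-ai/DeepSeek-OCR",
    limit_mm_per_prompt={"image": 1},
    logits_processors=[NGramPerReqLogitsProcessor],
)
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=8192,
    skip_special_tokens=False,
    # Per-request args consumed by NGramPerReqLogitsProcessor
    extra_args=dict(
        ngram_size=30,
        window_size=90,
        whitelist_token_ids={128821, 128822},  # <td>, </td>
    ),
)
outputs = llm.generate(
    {
        "prompt": "<image>\nFree OCR.",
        "multi_modal_data": {"image": Image.open("page.png")},
    },
    sampling_params,
)
print(outputs[0].outputs[0].text)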

View File

@@ -585,6 +585,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
transformers_version_reason="HF model is not compatible.",
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
),
"DeepseekOCRForCausalLM": _HfExamplesInfo(
"deepseek-ai/DeepSeek-OCR",
),
"DotsOCRForCausalLM": _HfExamplesInfo(
"rednote-hilab/dots.ocr", trust_remote_code=True
),

View File

@@ -0,0 +1,673 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from
# https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/sam_vary_sdpa.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from collections.abc import Iterable
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .clip import CLIPEncoder, CLIPVisionEmbeddings
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: type[nn.Module] = nn.LayerNorm,
act_layer: type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
""" # noqa: E501
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: nn.Parameter | None = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(
1, img_size // patch_size, img_size // patch_size, embed_dim
)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)
def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int):
dtype = abs_pos.dtype
src_size = abs_pos.size(1)
if src_size != tgt_size:
old_pos_embed = abs_pos.permute(0, 3, 1, 2)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode="bicubic",
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
return new_pos_embed
else:
return abs_pos
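# Illustrative: pos_embed is pretrained for a 64x64 grid (1024 // 16); for a
# 640x640 input the patch grid is 40x40, so (1, 64, 64, C) is bicubically
# resized to (1, 40, 40, C) in float32 and cast back to the parameter dtype.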
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.get_abs_pos(self.pos_embed, x.size(1))
for blk in self.blocks:
x = blk(x)
neck_output = self.neck(x.permute(0, 3, 1, 2))
conv2_output = self.net_2(neck_output)
conv3_output = self.net_3(conv2_output)
return conv3_output
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation
blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: type[nn.Module] = nn.LayerNorm,
act_layer: type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: tuple[int, int] | None = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
""" # noqa: E501
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = RelPosAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(
embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class RelPosAttention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: tuple[int, int] | None = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
""" # noqa: E501
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert input_size is not None, (
"Input size must be provided if using relative positional encoding."
)
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = (
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
rel_h, rel_w = None, None
if self.use_rel_pos:
rel_h, rel_w = add_decomposed_rel_pos(
q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
)
q = q.view(B, self.num_heads, H * W, -1)
k = k.view(B, self.num_heads, H * W, -1)
v = v.view(B, self.num_heads, H * W, -1)
if self.use_rel_pos:
rel_h = rel_h.view(
B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)
)
rel_w = rel_w.view(
B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)
)
attn_bias = (rel_h + rel_w).view(
B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
)
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_bias
)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = (
x.view(B, self.num_heads, H, W, -1)
.permute(0, 2, 3, 1, 4)
.reshape(B, H, W, -1)
)
x = self.proj(x)
return x
def window_partition(
x: torch.Tensor, window_size: int
) -> tuple[torch.Tensor, tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
""" # noqa: E501
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows, (Hp, Wp)
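# Worked example (illustrative): for a 40x40 token grid with window_size=14,
# pad_h = pad_w = 2, so the grid is padded to 42x42 and partitioned into
# 3 * 3 = 9 windows of 14x14 tokens per image.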
def window_unpartition(
windows: torch.Tensor,
window_size: int,
pad_hw: tuple[int, int],
hw: tuple[int, int],
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
""" # noqa: E501
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
dtype = rel_pos.dtype
rel_pos = rel_pos.to(torch.float32)
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
).to(dtype)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(
k_size / q_size, 1.0
)
k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(
q_size / k_size, 1.0
)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: tuple[int, int],
k_size: tuple[int, int],
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
rel_h, rel_w (Tensor): decomposed relative position terms for the height and width axes; the caller combines them into the attention bias.
""" # noqa: E501
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
return rel_h, rel_w
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: tuple[int, int] = (16, 16),
stride: tuple[int, int] = (16, 16),
padding: tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
# TODO(Isotr0py): use vision_config to build sam model
def build_sam_vit_b():
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
)
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_encoder = ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
return image_encoder
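# Illustrative: build_sam_vit_b() returns a ViT-B encoder (768-dim, 12 layers,
# 12 heads) with 14x14 windowed attention except at global blocks 2/5/8/11;
# for a 1024x1024 input, the 64x64x256 neck output is downsampled by
# net_2/net_3 to 16x16x1024.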
class DeepCLIPVisionEmbeddings(CLIPVisionEmbeddings):
def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int):
# abs_pos: L, C
# tgt_size: M
# return: M, C
dim = abs_pos.size(-1)
abs_pos_new = abs_pos.squeeze(0)
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
old_pos_embed = (
old_pos_embed.view(1, src_size, src_size, dim)
.permute(0, 3, 1, 2)
.contiguous()
)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode="bicubic",
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
return vision_pos_embed
else:
return abs_pos
def forward(
self, pixel_values: torch.Tensor, patch_embeds: torch.Tensor | None = None
) -> torch.Tensor:
batch_size = pixel_values.shape[0]
if patch_embeds is not None:
patch_embeds = patch_embeds
else:
patch_embeds = self.patch_embedding(pixel_values)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.get_abs_pos(
self.position_embedding(self.position_ids), embeddings.size(1)
)
return embeddings
class DeepCLIPVisionTransformer(nn.Module):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = DeepCLIPVisionEmbeddings(config)
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.transformer = CLIPEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention,
)
num_hidden_layers = config.num_hidden_layers
if len(self.transformer.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} layers."
)
@property
def dtype(self):
return next(self.parameters()).dtype
@property
def device(self):
return next(self.parameters()).device
def forward(
self,
pixel_values: torch.Tensor,
patch_embeds: torch.Tensor | None = None,
*,
select_layers: list[int] | None = None,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values, patch_embeds)
hidden_states = self.pre_layrnorm(hidden_states)
# Produces either the last layer output or all of the hidden states,
# depending on if we have select_layers or not
encoder_outputs = self.transformer(
inputs_embeds=hidden_states,
return_all_hidden_states=select_layers is not None,
)
return encoder_outputs
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@@ -0,0 +1,594 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
import torch
import torch.nn as nn
from transformers import BatchFeature, CLIPVisionConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsPP,
)
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargs,
NestedTensors,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.processors.deepseek_ocr import (
BASE_SIZE,
CROP_MODE,
IMAGE_SIZE,
DeepseekOCRProcessor,
count_tiles,
)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
)
from .deepencoder import DeepCLIPVisionTransformer, build_sam_vit_b
from .deepseek_vl2 import MlpProjector
# The image token id may vary across models
_IMAGE_TOKEN = "<image>"
class NoRepeatNGramLogitsProcessor:
def __init__(
self,
ngram_size: int,
window_size: int,
whitelist_token_ids: set[int] | None = None,
):
self.ngram_size = ngram_size
self.window_size = window_size
self.whitelist_token_ids = whitelist_token_ids or set()
def __call__(
self,
output_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
if len(output_ids) < self.ngram_size:
return logits
current_prefix = tuple(output_ids[-(self.ngram_size - 1) :])
search_start = max(0, len(output_ids) - self.window_size)
search_end = len(output_ids) - self.ngram_size + 1
banned_tokens = set()
for i in range(search_start, search_end):
ngram = tuple(output_ids[i : i + self.ngram_size])
if ngram[:-1] == current_prefix:
banned_tokens.add(ngram[-1])
banned_tokens = banned_tokens - self.whitelist_token_ids
if banned_tokens:
logits[list(banned_tokens)] = -float("inf")
return logits
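# Worked example (illustrative): with ngram_size=3 and
# output_ids = [..., 5, 7, 9, 5, 7], the current prefix is (5, 7); the earlier
# occurrence of (5, 7, 9) inside the last window_size tokens bans token 9
# (its logit is set to -inf) unless 9 is whitelisted.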
class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
super().__init__(vllm_config, device, is_pin_memory)
def is_argmax_invariant(self) -> bool:
return True
def new_req_logits_processor(
self,
params: SamplingParams,
) -> RequestLogitsProcessor | None:
ngram_size = params.extra_args and params.extra_args.get("ngram_size")
window_size = params.extra_args and params.extra_args.get("window_size", 100)
whitelist_token_ids = params.extra_args and params.extra_args.get(
"whitelist_token_ids", None
)
if ngram_size is None:
return None
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(
f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
)
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(
"`window_size` has to be a strictly positive integer, "
f"got {window_size}."
)
if whitelist_token_ids is not None and not isinstance(
whitelist_token_ids, Iterable
):
raise ValueError(
"`whitelist_token_ids` has to be a set of integers, "
f"got {whitelist_token_ids}."
)
else:
whitelist_token_ids = (
set(whitelist_token_ids) if whitelist_token_ids else None
)
return NoRepeatNGramLogitsProcessor(
ngram_size=ngram_size,
window_size=window_size,
whitelist_token_ids=whitelist_token_ids,
)
class DeepseekOCRProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
self, *, image_width: int, image_height: int, cropping: bool = True
) -> int:
image_size = IMAGE_SIZE
base_size = BASE_SIZE
patch_size = 16
downsample_ratio = 4
if CROP_MODE:
if image_width <= 640 and image_height <= 640:
crop_ratio = [1, 1]
else:
# find the closest aspect ratio to the target
crop_ratio = count_tiles(
image_width, image_height, image_size=IMAGE_SIZE
)
num_width_tiles, num_height_tiles = crop_ratio
else:
num_width_tiles = num_height_tiles = 1
h = w = math.ceil((base_size // patch_size) / downsample_ratio)
h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)
global_views_tokens = h * (w + 1)
if num_width_tiles > 1 or num_height_tiles > 1:
local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
else:
local_views_tokens = 0
return global_views_tokens + local_views_tokens + 1
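# Worked example (illustrative) for the default "Gundam" mode
# (BASE_SIZE=1024, IMAGE_SIZE=640, patch 16, 4x downsample):
# h = w = ceil(64 / 4) = 16, so the global view costs 16 * 17 = 272 tokens;
# for a 2x2 crop grid, h2 = w2 = 10 and the local views cost
# (2 * 10) * (2 * 10 + 1) = 420 tokens, giving 272 + 420 + 1 = 693 in total.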
def get_image_size_with_most_features(self) -> ImageSize:
if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
return ImageSize(width=1024 * 2, height=1024 * 2)
return ImageSize(width=640 * 2, height=640 * 2)
class DeepseekOCRDummyInputsBuilder(BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
max_image_size = self.info.get_image_size_with_most_features()
return {
"image": self._get_dummy_images(
width=max_image_size.width,
height=max_image_size.height,
num_images=num_images,
)
}
class DeepseekOCRMultiModalProcessor(
BaseMultiModalProcessor[DeepseekOCRProcessingInfo]
):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(
prompt, add_special_tokens=True, return_tensors="pt"
)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = hf_processor.image_token_id
assert isinstance(image_token_id, int)
def get_replacement_deepseek_vl2(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=size.width,
image_height=size.height,
cropping=CROP_MODE,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_deepseek_vl2,
)
]
# TODO(Isotr0py): Check if we still need this workaround for
# deepseek-ocr processor.
# def _cached_apply_hf_processor(
# self,
# prompt: str | list[int],
# mm_data_items: MultiModalDataItems,
# hf_processor_mm_kwargs: Mapping[str, object],
# tokenization_kwargs: Mapping[str, object],
# mm_uuids: MultiModalUUIDDict | None = None,
# ) -> tuple[list[int], MultiModalKwargs, bool]:
# # The processor logic is different for len(images) <= 2 vs > 2
# # Since the processing cache assumes that the processor output is
# # invariant of how many images are passed per prompt, we only
# # perform caching for the most common case
# if mm_data_items.get_count("image", strict=False) > 2:
# # This code path corresponds to the cache being disabled
# return self._apply_hf_processor_main(
# prompt=prompt,
# mm_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# enable_hf_prompt_update=True,
# )
# return super()._cached_apply_hf_processor(
# prompt=prompt,
# mm_data_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# )
@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
info=DeepseekOCRProcessingInfo,
dummy_inputs=DeepseekOCRDummyInputsBuilder,
)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# map prefix for language backbone
"model.embed_tokens.": "language_model.model.embed_tokens.",
"model.layers.": "language_model.model.layers.",
"model.norm.": "language_model.model.norm.",
"lm_head.": "language_model.lm_head.",
# remove "model." prefix for other components
"model.": "",
}
)
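# Illustrative mapping: "model.embed_tokens.weight" loads into
# "language_model.model.embed_tokens.weight", while "model.sam_model.*",
# "model.vision_model.*" and "model.projector.*" simply drop the "model." prefix.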
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.vision_config = config.vision_config
self.projector_config = config.projector_config
self.text_config = config.text_config
model_config = vllm_config.model_config
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
self.sam_model = build_sam_vit_b()
clip_vision_config = CLIPVisionConfig(
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
num_hidden_layers=24,
image_size=224,
patch_size=14,
projection_dim=512,
layer_norm_eps=1e-5,
)
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
)
self.projector = MlpProjector(self.projector_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# special token for image token sequence format
n_embed = self.projector_config.n_embed
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
if self.tile_tag == "2D":
# <|view_separator|>, <|\n|>
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
# NOTE: the "seperator" typo is kept to match the original implementation's weight names
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
)
if self.text_config.topk_method == "noaux_tc":
architectures = ["DeepseekV3ForCausalLM"]
elif not self.text_config.use_mla:
architectures = ["DeepseekForCausalLM"]
else:
architectures = ["DeepseekV2ForCausalLM"]
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=architectures,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
if pixel_values is None or torch.sum(pixel_values).item() == 0:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image sizes. "
f"Got type: {type(images_spatial_crop)}"
)
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image crop. Got type: {type(images_crop)}"
)
return [pixel_values, images_crop, images_spatial_crop]
raise AssertionError("This line should be unreachable.")
def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
global_features_1 = self.sam_model(image_tensor)
global_features_2 = self.vision_model(image_tensor, global_features_1)
features = torch.cat(
(
global_features_2[:, 1:],
global_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
features = self.projector(features)
_, hw, dim = features.shape
side = int(hw**0.5)
features = features.view(side, side, dim)
newline = self.image_newline[None, None, :].expand(side, 1, dim)
features = torch.cat([features, newline], dim=1)
return features.view(-1, dim)
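# Illustrative shapes: with a 1024x1024 global view, SAM yields a 16x16 grid,
# and (with the linear projector) side = 16, so the return value is 16 rows of
# 16 tokens plus one image_newline per row: 16 * 17 = 272 tokens, matching
# get_num_image_tokens().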
def _encode_local_features(
self, patches: torch.Tensor, crop_shape: torch.Tensor
) -> torch.Tensor | None:
if torch.sum(patches).item() == 0:
return None
local_features_1 = self.sam_model(patches)
local_features_2 = self.vision_model(patches, local_features_1)
features = torch.cat(
(
local_features_2[:, 1:],
local_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
features = self.projector(features)
_, hw, dim = features.shape
patch_side = int(hw**0.5)
width_tiles = int(crop_shape[0].item())
height_tiles = int(crop_shape[1].item())
features = (
features.view(height_tiles, width_tiles, patch_side, patch_side, dim)
.permute(0, 2, 1, 3, 4)
.reshape(height_tiles * patch_side, width_tiles * patch_side, dim)
)
newline = self.image_newline[None, None, :].expand(
height_tiles * patch_side, 1, dim
)
features = torch.cat([features, newline], dim=1)
return features.view(-1, dim)
def _pixel_values_to_embedding(
self,
pixel_values: torch.Tensor,
images_crop: torch.Tensor,
images_spatial_crop: torch.Tensor,
) -> NestedTensors:
images_in_this_batch = []
for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx][0].to(torch.bfloat16)
image_ori = pixel_values[jdx]
crop_shape = images_spatial_crop[jdx][0]
global_features = self._encode_global_features(image_ori)
local_features = self._encode_local_features(patches, crop_shape)
if local_features is not None:
combined = torch.cat(
[local_features, global_features, self.view_seperator[None, :]],
dim=0,
)
else:
combined = torch.cat(
[global_features, self.view_seperator[None, :]], dim=0
)
images_in_this_batch.append(combined)
return images_in_this_batch
def _process_image_input(self, image_input) -> torch.Tensor:
pixel_values = image_input[0].to(torch.bfloat16)
images_crop = image_input[1]
images_spatial_crop = image_input[2].to(dtype=torch.long)
vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
)
return vision_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object
) -> MultiModalEmbeddings | None:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
):
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return autoloaded_weights

View File

@@ -101,9 +101,10 @@ class MlpProjector(nn.Module):
super().__init__()
self.cfg = cfg
self.projector_type = cfg.projector_type
assert not cfg.token_pooling, "Token pooling is not supported currently."
if cfg.projector_type == "downsample_mlp_gelu":
if self.projector_type == "downsample_mlp_gelu":
mlp_depth = cfg.depth
mlp_ratio = cfg.mlp_ratio
modules = [
@@ -120,7 +121,8 @@
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
modules = nn.Sequential(*modules)
elif self.projector_type == "linear":
modules = nn.Linear(cfg.input_dim, cfg.n_embed)
else:
raise NotImplementedError(
f"Unsupported projector type: {cfg.projector_type}"
@@ -130,24 +132,25 @@
def forward(self, x):
bs, hw, input_dim = x.shape
h = w = int((hw) ** 0.5)
"""compute padding"""
if h % self.cfg.downsample_ratio:
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
else:
pad = 0
x = x.reshape(bs, h, w, input_dim)
if pad > 0:
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(
x,
kernel_size=self.cfg.downsample_ratio,
stride=self.cfg.downsample_ratio,
padding=0,
) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
if self.projector_type == "downsample_mlp_gelu":
h = w = int((hw) ** 0.5)
"""compute padding"""
if h % self.cfg.downsample_ratio:
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
else:
pad = 0
x = x.reshape(bs, h, w, input_dim)
if pad > 0:
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(
x,
kernel_size=self.cfg.downsample_ratio,
stride=self.cfg.downsample_ratio,
padding=0,
) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
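# Illustrative: with downsample_ratio=2, unfold maps a (B, C, 24, 24) token
# grid to (B, 4 * C, 144), concatenating each 2x2 group of tokens into one,
# so the sequence length is quartered before the MLP.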
return self.layers(x)

View File

@@ -258,6 +258,7 @@ _MULTIMODAL_MODELS = {
"Cohere2VisionForConditionalGeneration",
),
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
"DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
"Ernie4_5_VLMoeForConditionalGeneration": (
"ernie45_vl",

View File

@@ -34,6 +34,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
"clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
"deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja",
"fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
"minicpmv": _get_minicpmv_chat_template_fallback,
"paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja",

View File

@@ -0,0 +1,14 @@
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{% set system_message = '' -%}
{%- endif -%}
{{ bos_token + system_message }}
{%- for message in messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['content'] }}
{%- endfor -%}

View File

@@ -0,0 +1,442 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/image_process.py
import math
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps
from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin
# TODO(Isotr0py): change modes for variants
# see: https://github.com/deepseek-ai/DeepSeek-OCR/blob/8cf003d38821fa1b19c73da3bd1b0dc262ea8136/DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py#L1-L6
# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
# TODO(Isotr0py): Expose as mm_kwargs
MIN_CROPS = 2
MAX_CROPS = 6  # max: 9; if GPU memory is limited, 6 is recommended.
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def calculate_aspect_ratios(
min_num: int = MIN_CROPS, max_num: int = MAX_CROPS
) -> list[tuple[int, int]]:
target_ratios: set[tuple[int, int]] = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
sorted_target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
return sorted_target_ratios
def count_tiles(
orig_width,
orig_height,
min_num=MIN_CROPS,
max_num=MAX_CROPS,
image_size=640,
use_thumbnail=False,
):
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = calculate_aspect_ratios(min_num, max_num)
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
return target_aspect_ratio
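# Worked example (illustrative): a 1280x800 image has aspect ratio 1.6; among
# grids (i, j) with 2 <= i * j <= 6, the closest ratio is (3, 2) = 1.5, so
# count_tiles(1280, 800) returns (3, 2) and dynamic_preprocess() resizes the
# image to 1920x1280 and splits it into six 640x640 tiles.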
def dynamic_preprocess(
image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = calculate_aspect_ratios(min_num, max_num)
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
class ImageTransform:
def __init__(
self,
mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
std: tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
):
self.mean = mean
self.std = std
self.normalize = normalize
transform_pipelines = [T.ToTensor()]
if normalize:
transform_pipelines.append(T.Normalize(mean, std))
self.transform = T.Compose(transform_pipelines)
def __call__(self, pil_img: Image.Image):
x = self.transform(pil_img)
return x
class DeepseekOCRProcessor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
def __init__(
self,
tokenizer: LlamaTokenizerFast,
patch_size: int = 16,
downsample_ratio: int = 4,
image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<▁pad▁>",
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
self.image_size = IMAGE_SIZE
self.base_size = BASE_SIZE
self.patch_size = 16
self.image_mean = image_mean
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = 4
self.image_transform = ImageTransform(
mean=image_mean, std=image_std, normalize=normalize
)
self.tokenizer = tokenizer
self.tokenizer.padding_side = "left"  # must set this; padding side makes a difference in batch inference  # noqa: E501
# add the pad_token as special token to use 'tokenizer.pad_token'
# and 'tokenizer.pad_token_id'
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({"pad_token": pad_token})
# add image token
self.image_token_id = self.tokenizer.vocab.get(image_token)
self.image_token = image_token
self.pad_token = pad_token
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(
tokenizer,
**kwargs,
)
@property
def bos_id(self):
return self.tokenizer.bos_token_id
@property
def eos_id(self):
return self.tokenizer.eos_token_id
@property
def pad_id(self):
return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: list[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str,
images: list[Image.Image],
crop_mode: bool = CROP_MODE,
):
"""
Args:
prompt (str): the formatted prompt;
images (List[ImageType]): the list of images;
crop_mode (bool): if True, then crop the image;
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
assert prompt is not None and images is not None, (
"prompt and images must be used at the same time."
)
sft_format = prompt
(
input_ids,
pixel_values,
images_crop,
images_seq_mask,
images_spatial_crop,
num_image_tokens,
_,
) = self.tokenize_with_images(
conversation=sft_format,
images=images,
bos=True,
eos=True,
cropping=crop_mode,
)
prepare = BatchFeature(
data=dict(
input_ids=input_ids,
pixel_values=pixel_values,
images_crop=images_crop,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
num_image_tokens=num_image_tokens,
),
tensor_type="pt",
)
return prepare
def __call__(
self,
*,
prompt: str,
images: list[Image.Image],
crop_mode: bool = CROP_MODE,
**kwargs,
):
prepare = self.process_one(
prompt=prompt,
images=images,
crop_mode=crop_mode,
)
return prepare
def tokenize_with_images(
self,
conversation: str,
images: list[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,
):
"""Tokenize text with <image> tags."""
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
[],
[],
[],
[],
)
image_shapes = []
num_image_tokens = []
tokenized_str = []
for text_sep, image in zip(text_splits, images):
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
image_shapes.append(image.size)
images_crop_raw = []
if image.size[0] <= 640 and image.size[1] <= 640:
crop_ratio = [1, 1]
elif cropping:
images_crop_raw, crop_ratio = dynamic_preprocess(
image, image_size=IMAGE_SIZE
)
else:
crop_ratio = [1, 1]
if self.image_size <= 640 and not cropping:
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(
image,
(self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
images_list.append(self.image_transform(global_view))
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
for cropped_image in images_crop_raw:
images_crop_list.append(self.image_transform(cropped_image))
num_queries = math.ceil(
(self.image_size // self.patch_size) / self.downsample_ratio
)
num_queries_base = math.ceil(
(self.base_size // self.patch_size) / self.downsample_ratio
)
tokenized_image = (
[self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
local_row = [self.image_token_id] * (num_queries * num_width_tiles + 1)
tokenized_image += local_row * (num_queries * num_height_tiles)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
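# Illustrative count: the global block above contributes
# (num_queries_base + 1) * num_queries_base + 1 tokens (16 * 17 + 1 = 273 by
# default), and a cropped image adds
# (num_queries * num_width_tiles + 1) * (num_queries * num_height_tiles) more
# (420 for a 2x2 grid), mirroring get_num_image_tokens() in the model file.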
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""add the bos and eos tokens"""
if bos:
tokenized_str = [self.bos_id] + tokenized_str
images_seq_mask = [False] + images_seq_mask
if eos:
tokenized_str = tokenized_str + [self.eos_id]
images_seq_mask = images_seq_mask + [False]
assert len(tokenized_str) == len(images_seq_mask), (
f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} "
f"is not equal to images_seq_mask's length {len(images_seq_mask)}."
)
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, "
f"input_ids' length {len(masked_tokenized_str)}, "
f"images_seq_mask's length {len(images_seq_mask)}, are not equal."
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return (
input_ids,
pixel_values,
images_crop,
images_seq_mask,
images_spatial_crop,
num_image_tokens,
image_shapes,
)
AutoProcessor.register("DeepseekOCRProcessor", DeepseekOCRProcessor)
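As a sanity check, here is a standalone usage sketch of the processor itself (not part of the commit; the tokenizer class and image path are assumptions, the output shapes follow the code above):

from PIL import Image
from transformers import AutoTokenizer
from vllm.transformers_utils.processors.deepseek_ocr import DeepseekOCRProcessor

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-OCR")
processor = DeepseekOCRProcessor(tokenizer=tokenizer)
out = processor(prompt="<image>\nFree OCR.", images=[Image.open("page.png")])
# out.input_ids:           (1, seq_len) prompt ids with expanded image tokens
# out.pixel_values:        (n_images, 3, 1024, 1024) padded global views
# out.images_crop:         (1, n_tiles, 3, 640, 640) local tiles, zeros if uncropped
# out.images_spatial_crop: (n_images, 2) crop grid as (width_tiles, height_tiles)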