[Misc][qwen2_5_vl][torch.compile] Enable supports_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model (#23207)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
This commit is contained in:
Lucas Kabela
2025-10-28 15:36:43 -07:00
committed by GitHub
parent 4fe5895361
commit 94666612a9
6 changed files with 334 additions and 97 deletions

View File

@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.compilation.counter import compilation_counter
from vllm.config.compilation import CompilationMode
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
"""Test that Qwen2.5-VL vision submodules are compiled.
This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
for compilation by checking that num_models_seen increases by at least 3.
"""
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
# NOTE: Qwen2.5-VL has 35 models in total - the LLM backend
# Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
# (one for each layer) - in the future, we should fix vLLM compilation
# logic to handle this case and only compile the Vision submodules once
# and reuse the compiled code for all layers
# See https://github.com/vllm-project/vllm/issues/27590
compilation_counter.expect(num_models_seen=35),
vllm_runner(
"Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=2048,
gpu_memory_utilization=0.7,
compilation_config={"mode": CompilationMode.VLLM_COMPILE},
) as _,
):
pass

View File

@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`to_list` in xformers attn, or `.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage)
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
"""
import einops
import torch
from vllm.utils.torch_utils import direct_register_custom_op
def xformers_attn_seqlens_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer
def xformers_attn_seqlens_wrapper_fake(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op(
op_name="xformers_attn_seqlens_wrapper",
op_func=xformers_attn_seqlens_wrapper,
fake_impl=xformers_attn_seqlens_wrapper_fake,
)
def vit_xformers_attn_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen.item(),
max_seqlen_k=max_seqlen.item(),
dropout_p=0.0,
causal=False,
)
context_layer = einops.rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
return context_layer
def flash_attn_maxseqlen_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op(
op_name="flash_attn_maxseqlen_wrapper",
op_func=flash_attn_maxseqlen_wrapper,
fake_impl=flash_attn_maxseqlen_wrapper_fake,
)
def vit_flash_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
)

View File

@ -18,7 +18,12 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
from vllm.config import (
CompilationMode,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@ -74,6 +79,21 @@ def support_torch_compile(
) -> Callable[[_T], _T]: ...
@overload
def support_torch_compile(
*,
mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...
@overload
def support_torch_compile(
*,
dynamic_arg_dims: dict[str, int | list[int]] | None,
mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...
@overload
def support_torch_compile(cls: _T) -> _T: ...
@ -82,6 +102,7 @@ def support_torch_compile(
cls: _T | None = None,
*,
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T] | _T:
"""
@ -135,6 +156,11 @@ def support_torch_compile(
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
`mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo do not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
"""
def cls_decorator_helper(cls: _T) -> _T:
@ -172,7 +198,9 @@ def support_torch_compile(
raise ValueError(
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
return _support_torch_compile(
cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
)
if cls is not None:
# use `support_torch_compile` as a decorator without arguments
@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
) -> _T:
"""
@ -230,8 +259,22 @@ def _support_torch_compile(
setattr(cls, IGNORE_COMPILE_KEY, False)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
def __init__(
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
):
if vllm_config is None:
vllm_config = get_current_vllm_config()
# NOTE: to support multimodal models (such as encoder),
# we may not have vllm_config so we may need to patch
# it
sig = inspect.signature(old_init)
if "vllm_config" in sig.parameters:
kwargs["vllm_config"] = vllm_config
if "prefix" in sig.parameters:
kwargs["prefix"] = prefix
old_init(self, **kwargs)
self.vllm_config = vllm_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
@ -344,6 +387,15 @@ def _support_torch_compile(
"Unsupported dynamic dimensions"
f" {dims} for argument {k} with type {type(arg)}."
)
if mark_unbacked_dims:
for k, dims in mark_unbacked_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.decorators.mark_unbacked(arg, dims)
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
logger.debug("Start compiling function %s", self.original_code_object)

View File

@ -684,6 +684,8 @@ class CompilationConfig:
from vllm.compilation.backends import VllmBackend
# TODO[@lucaskabela]: See if we can forward prefix
# https://github.com/vllm-project/vllm/issues/27045
return VllmBackend(vllm_config)
def post_init_cudagraph_sizes(self) -> None:

View File

@ -45,6 +45,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
@ -759,7 +760,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
assert grid_thw.ndim == 2
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
with set_forward_context(None, self.vllm_config):
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
@ -779,7 +781,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
assert grid_thw.ndim == 2
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
with set_forward_context(None, self.vllm_config):
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
@ -839,6 +842,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
thinker_config: Qwen2_5OmniThinkerConfig = (
vllm_config.model_config.hf_config.thinker_config
)

View File

@ -31,10 +31,10 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import lru_cache, partial
from typing import Annotated, Any, Literal, TypeAlias
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
@ -47,9 +47,15 @@ from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
@ -392,8 +398,8 @@ class Qwen2_5_VisionAttention(nn.Module):
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@ -402,7 +408,7 @@ class Qwen2_5_VisionAttention(nn.Module):
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
if rotary_pos_emb is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
@ -410,31 +416,18 @@ class Qwen2_5_VisionAttention(nn.Module):
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = self.flash_attn_varlen_func(
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA,
self.use_upstream_fa,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
@ -443,34 +436,31 @@ class Qwen2_5_VisionAttention(nn.Module):
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer = einops.rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer)
return output
@support_torch_compile(
dynamic_arg_dims={
"x": 0,
"cu_seqlens": 0,
"rotary_pos_emb": 0,
"seqlens": 0,
},
mark_unbacked_dims={"seqlens": 0},
)
class Qwen2_5_VisionBlock(nn.Module):
def __init__(
self,
@ -515,8 +505,8 @@ class Qwen2_5_VisionBlock(nn.Module):
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@ -530,6 +520,11 @@ class Qwen2_5_VisionBlock(nn.Module):
return x
@support_torch_compile(
dynamic_arg_dims={
"x": 0,
}
)
class Qwen2_5_VisionPatchEmbed(nn.Module):
def __init__(
self,
@ -556,6 +551,11 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
return x
@support_torch_compile(
dynamic_arg_dims={
"x": 0,
}
)
class Qwen2_5_VisionPatchMerger(nn.Module):
def __init__(
self,
@ -665,13 +665,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
self.spatial_merge_size = vision_config.spatial_merge_size
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
self.spatial_merge_unit = self.spatial_merge_size**2
# TODO[@lucaskabela]: Investigate fixing this usage
# see https://github.com/vllm-project/vllm/issues/27044
# DO NOT MOVE THIS IMPORT
from vllm.compilation.backends import set_model_tag
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
hidden_size=self.hidden_size,
)
with set_model_tag("Qwen2_5_VisionPatchEmbed"):
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
hidden_size=self.hidden_size,
)
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
@ -701,32 +706,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
self.blocks = nn.ModuleList(
[
Qwen2_5_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
)
for layer_idx in range(depth)
]
)
self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
with set_model_tag("Qwen2_5_VisionBlock"):
self.blocks = nn.ModuleList(
[
Qwen2_5_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
)
for layer_idx in range(depth)
]
)
with set_model_tag("Qwen2_5_VisionPatchMerger"):
self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
@property
def dtype(self) -> torch.dtype:
@ -827,15 +835,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen, seqlens = (
torch.zeros(1, device=cu_seqlens.device),
torch.zeros(1, device=cu_seqlens.device),
)
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
@staticmethod
@ -1233,6 +1244,7 @@ class Qwen2_5_VLForConditionalGeneration(
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.vllm_config = vllm_config
self.multimodal_config = multimodal_config
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
@ -1248,7 +1260,7 @@ class Qwen2_5_VLForConditionalGeneration(
else None
)
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
@ -1336,13 +1348,13 @@ class Qwen2_5_VLForConditionalGeneration(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"]
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
with set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@ -1396,12 +1408,18 @@ class Qwen2_5_VLForConditionalGeneration(
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"]
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
)
else:
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
with set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d",
)
else:
video_embeds = self.visual(
pixel_values_videos, grid_thw=grid_thw_list
)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size