[Misc][qwen2_5_vl][torch.compile] Enable supports_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model (#23207)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
tests/compile/test_multimodal_compile.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.compilation.counter import compilation_counter
from vllm.config.compilation import CompilationMode


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
    """Test that Qwen2.5-VL vision submodules are compiled.

    This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
    for compilation by checking that num_models_seen increases by at least 3.
    """
    # Disable multiprocessing so that the counter is updated in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        # NOTE: Qwen2.5-VL has 35 models in total - the LLM backend,
        # Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
        # (one for each layer). In the future, we should fix vLLM's compilation
        # logic to handle this case, compile the vision submodules only once,
        # and reuse the compiled code for all layers.
        # See https://github.com/vllm-project/vllm/issues/27590
        compilation_counter.expect(num_models_seen=35),
        vllm_runner(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            max_model_len=2048,
            gpu_memory_utilization=0.7,
            compilation_config={"mode": CompilationMode.VLLM_COMPILE},
        ) as _,
    ):
        pass
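Note: the compilation path exercised by this test can also be enabled from the public entrypoint. The sketch below is a minimal, hypothetical usage example; it mirrors the arguments that vllm_runner forwards in the test above and assumes that `LLM` accepts the same `compilation_config` dictionary. It is not part of this change.

# Minimal sketch (assumption: LLM accepts compilation_config the same way
# vllm_runner forwards it in the test above; prompt and sampling settings
# here are arbitrary).
from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationMode

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    max_model_len=2048,
    gpu_memory_utilization=0.7,
    # Compiles the tagged vision submodules in addition to the LLM backend.
    compilation_config={"mode": CompilationMode.VLLM_COMPILE},
)
print(llm.generate("Describe this image.", SamplingParams(max_tokens=16)))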
vllm/attention/ops/vit_attn_wrappers.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile,
as there are operations here that torch.compile does not support (for
instance, `.tolist()` in the xFormers attention path, or `.item()` in flash
attention).

Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage).

To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0).
"""

import einops
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    attn_bias = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
    )
    context_layer = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=attn_bias, p=0, scale=None
    )
    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
    return context_layer


def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="xformers_attn_seqlens_wrapper",
    op_func=xformers_attn_seqlens_wrapper,
    fake_impl=xformers_attn_seqlens_wrapper_fake,
)


def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)


def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen.item(),
        max_seqlen_k=max_seqlen.item(),
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer


def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="flash_attn_maxseqlen_wrapper",
    op_func=flash_attn_maxseqlen_wrapper,
    fake_impl=flash_attn_maxseqlen_wrapper_fake,
)


def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
    )
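Note: the recurring pattern in this new file is to move data-dependent Python calls (`.item()`, `.tolist()`) inside a registered custom op, so Dynamo traces an opaque `torch.ops.vllm.*` call and uses the fake implementation for shape propagation instead of breaking the graph. Below is a minimal sketch of that pattern; the `toy_seqlens_sum` op is hypothetical and not part of this change.

# Minimal sketch of the custom-op pattern used above. The op name
# "toy_seqlens_sum" is hypothetical; in practice the op would run on the
# platform's device (e.g. CUDA tensors), like the wrappers in this file.
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def toy_seqlens_sum(seqlens: torch.Tensor) -> torch.Tensor:
    # `.tolist()` would normally cause a Dynamo graph break, but it is hidden
    # inside the op body, which torch.compile treats as opaque.
    total = sum(seqlens.tolist())
    return torch.full((1,), total, dtype=seqlens.dtype, device=seqlens.device)


def toy_seqlens_sum_fake(seqlens: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing.
    return torch.empty((1,), dtype=seqlens.dtype, device=seqlens.device)


direct_register_custom_op(
    op_name="toy_seqlens_sum",
    op_func=toy_seqlens_sum,
    fake_impl=toy_seqlens_sum_fake,
)


@torch.compile(fullgraph=True)
def fn(seqlens: torch.Tensor) -> torch.Tensor:
    # Compiles without a graph break because the .tolist() lives inside the op.
    return torch.ops.vllm.toy_seqlens_sum(seqlens)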
vllm/compilation/decorators.py
@@ -18,7 +18,12 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    CompilationMode,
+    VllmConfig,
+    get_current_vllm_config,
+    set_current_vllm_config,
+)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -74,6 +79,21 @@ def support_torch_compile(
 ) -> Callable[[_T], _T]: ...


+@overload
+def support_torch_compile(
+    *,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
+@overload
+def support_torch_compile(
+    *,
+    dynamic_arg_dims: dict[str, int | list[int]] | None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
 @overload
 def support_torch_compile(cls: _T) -> _T: ...

@@ -82,6 +102,7 @@ def support_torch_compile(
     cls: _T | None = None,
     *,
     dynamic_arg_dims: dict[str, int | list[int]] | None = None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> Callable[[_T], _T] | _T:
     """
@@ -135,6 +156,11 @@
     returns a boolean value indicating whether to compile the model or not.
     This is useful if you want to compile the model only when certain
     conditions are met.
+
+    `mark_unbacked_dims` is a dictionary that maps argument names to the dynamic
+    dims that should be decorated with `mark_unbacked`. This is useful to enforce
+    that Dynamo does not specialize on 0/1 values for dummy inputs, such as
+    during vision model compilation.
     """

     def cls_decorator_helper(cls: _T) -> _T:
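Note: with the new keyword, a generic vision submodule can be decorated directly. The hypothetical ToyVisionBlock below mirrors the Qwen2.5-VL usage later in this diff and is only an illustration, not part of the change.

# Minimal sketch (hypothetical ToyVisionBlock; mirrors the Qwen2.5-VL usage
# further down in this diff). The decorated module no longer needs to accept
# vllm_config/prefix in its __init__.
import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile(
    # Sequence dimension of `x` and `seqlens` is dynamic across calls.
    dynamic_arg_dims={"x": 0, "seqlens": 0},
    # Do not let Dynamo specialize when dummy inputs have size 0/1 here.
    mark_unbacked_dims={"seqlens": 0},
)
class ToyVisionBlock(nn.Module):
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor, seqlens: torch.Tensor) -> torch.Tensor:
        return self.proj(x)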
@@ -172,7 +198,9 @@
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}"
                 )
-        return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
+        return _support_torch_compile(
+            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+        )

     if cls is not None:
         # use `support_torch_compile` as a decorator without arguments
@@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
 def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, int | list[int]],
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> _T:
     """
@@ -230,8 +259,22 @@

     setattr(cls, IGNORE_COMPILE_KEY, False)

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
-        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+    def __init__(
+        self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
+    ):
+        if vllm_config is None:
+            vllm_config = get_current_vllm_config()
+
+        # NOTE: to support multimodal submodules (such as vision encoders),
+        # the wrapped class may not accept vllm_config/prefix, so only
+        # forward the arguments that appear in its __init__ signature.
+        sig = inspect.signature(old_init)
+        if "vllm_config" in sig.parameters:
+            kwargs["vllm_config"] = vllm_config
+        if "prefix" in sig.parameters:
+            kwargs["prefix"] = prefix
+        old_init(self, **kwargs)

         self.vllm_config = vllm_config
         enable_compile = enable_if is None or enable_if(vllm_config)
         # for CompilationMode.STOCK_TORCH_COMPILE, the upper level model runner
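Note: the signature inspection above is what lets the decorator wrap modules whose `__init__` does not accept `vllm_config`/`prefix`. A standalone sketch of the same idiom follows; the `ToyEncoder` class and `construct` helper are hypothetical and only for illustration.

# Standalone sketch of the kwarg-filtering idiom used in __init__ above:
# only forward keyword arguments that the wrapped __init__ actually declares.
import inspect


class ToyEncoder:
    # Note: no vllm_config/prefix parameters here.
    def __init__(self, hidden_size: int = 64):
        self.hidden_size = hidden_size


def construct(cls, vllm_config=None, prefix: str = "", **kwargs):
    sig = inspect.signature(cls.__init__)
    if "vllm_config" in sig.parameters:
        kwargs["vllm_config"] = vllm_config
    if "prefix" in sig.parameters:
        kwargs["prefix"] = prefix
    return cls(**kwargs)


encoder = construct(ToyEncoder, hidden_size=128)  # works without vllm_config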
@@ -344,6 +387,15 @@
                         "Unsupported dynamic dimensions"
                         f" {dims} for argument {k} with type {type(arg)}."
                     )
+            if mark_unbacked_dims:
+                for k, dims in mark_unbacked_dims.items():
+                    arg = bound_args.arguments.get(k)
+                    if arg is not None:
+                        dims = [dims] if isinstance(dims, int) else dims
+                        if isinstance(arg, torch.Tensor):
+                            # In case dims is specified with negative indexing
+                            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                            torch._dynamo.decorators.mark_unbacked(arg, dims)
             # here, it is the starting point of the `torch.compile` process
             start_monitoring_torch_compile(self.vllm_config)
             logger.debug("Start compiling function %s", self.original_code_object)
vllm/config/compilation.py
@@ -684,6 +684,8 @@ class CompilationConfig:

         from vllm.compilation.backends import VllmBackend

+        # TODO[@lucaskabela]: See if we can forward prefix
+        # https://github.com/vllm-project/vllm/issues/27045
         return VllmBackend(vllm_config)

     def post_init_cudagraph_sizes(self) -> None:
vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -45,6 +45,7 @@ from transformers.models.whisper import WhisperFeatureExtractor

 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2_5_vl import (
@@ -759,7 +760,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
         assert grid_thw.ndim == 2

         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+        with set_forward_context(None, self.vllm_config):
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
@@ -779,7 +781,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
         assert grid_thw.ndim == 2

         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+        with set_forward_context(None, self.vllm_config):
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
@@ -839,6 +842,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config
         thinker_config: Qwen2_5OmniThinkerConfig = (
             vllm_config.model_config.hf_config.thinker_config
         )
vllm/model_executor/models/qwen2_5_vl.py
@@ -31,10 +31,10 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import lru_cache, partial
 from typing import Annotated, Any, Literal, TypeAlias

+import einops
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 from transformers import BatchFeature, PretrainedConfig
 from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
@@ -47,9 +47,15 @@ from vllm.attention.layer import (
     check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
+from vllm.attention.ops.vit_attn_wrappers import (
+    vit_flash_attn_wrapper,
+    vit_xformers_attn_wrapper,
+)
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -392,8 +398,8 @@ class Qwen2_5_VisionAttention(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -402,7 +408,7 @@
         q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]

-        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
+        q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
         if rotary_pos_emb is not None:
             # [2 * b, s, heads, head_dim]
             qk_concat = torch.cat([q, k], dim=0)
@@ -410,31 +416,18 @@
         q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
-            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-
-            output = self.flash_attn_varlen_func(
+            context_layer = vit_flash_attn_wrapper(
                 q,
                 k,
                 v,
-                cu_seqlens_q=cu_seqlens,
-                cu_seqlens_k=cu_seqlens,
-                max_seqlen_q=max_seqlen,
-                max_seqlen_k=max_seqlen,
-                dropout_p=0.0,
-                causal=False,
+                cu_seqlens,
+                max_seqlen,
+                batch_size,
+                self.attn_backend == _Backend.ROCM_AITER_FA,
+                self.use_upstream_fa,
             )
-
-            context_layer = rearrange(
-                output, "(b s) h d -> s b (h d)", b=batch_size
-            ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             from vllm.platforms import current_platform

             if current_platform.is_rocm():
                 q = q.contiguous()
                 k = k.contiguous()
                 v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
@@ -443,34 +436,31 @@
                 k_i = k[:, start_idx:end_idx]
                 v_i = v[:, start_idx:end_idx]
                 q_i, k_i, v_i = (
-                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
+                    einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )
                 output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
                 outputs.append(output_i)
             context_layer = torch.cat(outputs, dim=1)
-            context_layer = rearrange(
+            context_layer = einops.rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
         elif self.attn_backend == _Backend.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
+            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)

         output, _ = self.proj(context_layer)
         return output


+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+        "cu_seqlens": 0,
+        "rotary_pos_emb": 0,
+        "seqlens": 0,
+    },
+    mark_unbacked_dims={"seqlens": 0},
+)
 class Qwen2_5_VisionBlock(nn.Module):
     def __init__(
         self,
@@ -515,8 +505,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
@@ -530,6 +520,11 @@
         return x


+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+    }
+)
 class Qwen2_5_VisionPatchEmbed(nn.Module):
     def __init__(
         self,
@@ -556,6 +551,11 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
         return x


+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+    }
+)
 class Qwen2_5_VisionPatchMerger(nn.Module):
     def __init__(
         self,
@@ -665,13 +665,18 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.spatial_merge_size = vision_config.spatial_merge_size
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.spatial_merge_unit = self.spatial_merge_size**2
+        # TODO[@lucaskabela]: Investigate fixing this usage
+        # see https://github.com/vllm-project/vllm/issues/27044
+        # DO NOT MOVE THIS IMPORT
+        from vllm.compilation.backends import set_model_tag

-        self.patch_embed = Qwen2_5_VisionPatchEmbed(
-            patch_size=patch_size,
-            temporal_patch_size=temporal_patch_size,
-            in_channels=in_channels,
-            hidden_size=self.hidden_size,
-        )
+        with set_model_tag("Qwen2_5_VisionPatchEmbed"):
+            self.patch_embed = Qwen2_5_VisionPatchEmbed(
+                patch_size=patch_size,
+                temporal_patch_size=temporal_patch_size,
+                in_channels=in_channels,
+                hidden_size=self.hidden_size,
+            )

         norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
@@ -701,32 +706,35 @@
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
             )

-        self.blocks = nn.ModuleList(
-            [
-                Qwen2_5_VisionBlock(
-                    dim=self.hidden_size,
-                    num_heads=self.num_heads,
-                    mlp_hidden_dim=vision_config.intermediate_size,
-                    act_fn=get_act_and_mul_fn(vision_config.hidden_act),
-                    norm_layer=norm_layer,
-                    quant_config=quant_config,
-                    prefix=f"{prefix}.blocks.{layer_idx}",
-                    use_data_parallel=use_data_parallel,
-                    attn_backend=self.attn_backend,
-                    use_upstream_fa=use_upstream_fa,
-                )
-                for layer_idx in range(depth)
-            ]
-        )
-        self.merger = Qwen2_5_VisionPatchMerger(
-            d_model=vision_config.out_hidden_size,
-            context_dim=self.hidden_size,
-            norm_layer=norm_layer,
-            spatial_merge_size=self.spatial_merge_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.merger",
-            use_data_parallel=use_data_parallel,
-        )
+        with set_model_tag("Qwen2_5_VisionBlock"):
+            self.blocks = nn.ModuleList(
+                [
+                    Qwen2_5_VisionBlock(
+                        dim=self.hidden_size,
+                        num_heads=self.num_heads,
+                        mlp_hidden_dim=vision_config.intermediate_size,
+                        act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                        norm_layer=norm_layer,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.blocks.{layer_idx}",
+                        use_data_parallel=use_data_parallel,
+                        attn_backend=self.attn_backend,
+                        use_upstream_fa=use_upstream_fa,
+                    )
+                    for layer_idx in range(depth)
+                ]
+            )
+
+        with set_model_tag("Qwen2_5_VisionPatchMerger"):
+            self.merger = Qwen2_5_VisionPatchMerger(
+                d_model=vision_config.out_hidden_size,
+                context_dim=self.hidden_size,
+                norm_layer=norm_layer,
+                spatial_merge_size=self.spatial_merge_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.merger",
+                use_data_parallel=use_data_parallel,
+            )

     @property
     def dtype(self) -> torch.dtype:
@@ -827,15 +835,18 @@
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen, seqlens = (
+            torch.zeros(1, device=cu_seqlens.device),
+            torch.zeros(1, device=cu_seqlens.device),
+        )
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
         ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens

     @staticmethod
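Note: returning tensors here (instead of calling `.item()`/`.tolist()`) is what keeps the compiled vision blocks free of graph breaks; the scalar conversion now happens inside the custom ops. A minimal sketch of the underlying issue, assuming default Dynamo settings:

# Sketch of the motivation for the tensor return values above: `.item()` is a
# data-dependent operation that Dynamo cannot trace with fullgraph=True under
# the default config (capture_scalar_outputs=False), while a pure tensor max
# stays inside the compiled graph.
import torch


def naive_max_seqlen(cu_seqlens: torch.Tensor) -> int:
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # graph break / error


def tensor_max_seqlen(cu_seqlens: torch.Tensor) -> torch.Tensor:
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # stays in the graph


cu = torch.tensor([0, 3, 7, 12])
torch.compile(tensor_max_seqlen, fullgraph=True)(cu)  # compiles cleanly
# torch.compile(naive_max_seqlen, fullgraph=True)(cu) would fail with an
# "Unsupported" error about Tensor.item() under the default Dynamo config.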
@@ -1233,6 +1244,7 @@ class Qwen2_5_VLForConditionalGeneration(

         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self.config = config
+        self.vllm_config = vllm_config
         self.multimodal_config = multimodal_config
         self.video_pruning_rate = multimodal_config.video_pruning_rate
         self.is_multimodal_pruning_enabled = (
@@ -1248,7 +1260,7 @@ class Qwen2_5_VLForConditionalGeneration(
             else None
         )
         self.visual = Qwen2_5_VisionTransformer(
-            config.vision_config,
+            vision_config=config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "visual"),
@@ -1336,13 +1348,13 @@ class Qwen2_5_VLForConditionalGeneration(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"]

-            if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
-                )
-            else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+            with set_forward_context(None, self.vllm_config):
+                if self.use_data_parallel:
+                    return run_dp_sharded_mrope_vision_model(
+                        self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    )
+                else:
+                    image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -1396,12 +1408,18 @@ class Qwen2_5_VLForConditionalGeneration(
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
-            if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
-                )
-            else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+            with set_forward_context(None, self.vllm_config):
+                if self.use_data_parallel:
+                    return run_dp_sharded_mrope_vision_model(
+                        self.visual,
+                        pixel_values_videos,
+                        grid_thw_list,
+                        rope_type="rope_3d",
+                    )
+                else:
+                    video_embeds = self.visual(
+                        pixel_values_videos, grid_thw=grid_thw_list
+                    )

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size