[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
WeiQing Chen
2025-09-02 00:56:56 +08:00
committed by GitHub
parent cf91a89dd2
commit a0e0efd6bd
6 changed files with 156 additions and 61 deletions

View File

@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
Known supported models:
- Kimi-VL (<gh-pr:23817>)
- Llama4 (<gh-pr:18368>)
- MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
- Qwen2.5-VL (<gh-pr:22742>)

View File

@ -636,8 +636,10 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_mrope_vision_model(
vision_model, pixel_values, grid_thw_list)
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
sharded_output = torch.cat(sharded_output, dim=0)
# Check that the world size is setup correctly
@ -691,8 +693,10 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
# Should handle empty input gracefully
with torch.inference_mode():
output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
grid_thw_list)
output = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
assert len(output) == 0
@ -745,8 +749,10 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
# Should handle uneven distribution without errors
with torch.inference_mode():
output_tuple = run_dp_sharded_mrope_vision_model(
vision_model, pixel_values, grid_thw_list)
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
# Verify output shape is reasonable
merge_factor = vision_model.spatial_merge_size**2

View File

@ -56,6 +56,7 @@ from transformers.activations import GELUActivation
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
@ -76,6 +77,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
@ -93,8 +95,10 @@ class MaxImageTokenMeta:
class KimiVLMultiModalProjector(nn.Module):
def __init__(self, config: KimiVLConfig):
def __init__(self, config: KimiVLConfig, \
use_data_parallel: bool = False, prefix: str = ""):
super().__init__()
self.use_data_parallel = use_data_parallel
self.hidden_size = (config.vision_config.hidden_size *
config.vision_config.merge_kernel_size[0] *
@ -102,20 +106,24 @@ class KimiVLMultiModalProjector(nn.Module):
self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
eps=1e-5)
self.linear_1 = nn.Linear(self.hidden_size,
self.hidden_size,
bias=True)
self.linear_1 = ReplicatedLinear(self.hidden_size,
self.hidden_size,
bias=True,
prefix=maybe_prefix(
prefix, "linear_1"))
self.linear_2 = ReplicatedLinear(self.hidden_size,
config.text_config.hidden_size,
bias=True,
prefix=maybe_prefix(
prefix, "linear_2"))
self.act = GELUActivation()
self.linear_2 = nn.Linear(self.hidden_size,
config.text_config.hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(image_features).view(
-1, self.hidden_size)
hidden_states = self.linear_1(hidden_states)
hidden_states, _ = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states, _ = self.linear_2(hidden_states)
return hidden_states
@ -273,6 +281,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
supports_encoder_tp_data = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@ -292,10 +302,17 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
quant_config = vllm_config.quant_config
assert isinstance(config.vision_config, MoonViTConfig)
self.use_data_parallel = model_config.multimodal_config.mm_encoder_tp_mode == "data"
self.hidden_size = config.text_config.hidden_size
self.vision_tower = MoonVitPretrainedModel(config.vision_config,
self.use_data_parallel,
prefix=maybe_prefix(
prefix, "vision_tower"))
self.vision_tower = MoonVitPretrainedModel(config.vision_config)
self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
self.multi_modal_projector = KimiVLMultiModalProjector(
config=config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
self.quant_config = quant_config
sub_vllm_config = copy.deepcopy(vllm_config)
@ -376,13 +393,19 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values = inputs["pixel_values"]
image_grid_hws = inputs["image_grid_hws"]
return self.vision_tower(pixel_values, image_grid_hws)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(self.vision_tower,
pixel_values,
image_grid_hws.tolist(),
rope_type="rope_2d")
else:
return self.vision_tower(pixel_values, image_grid_hws)
def _process_image_input(self,
image_input: KimiVLImageInputs) -> torch.Tensor:
assert image_input["type"] == "pixel_values"
image_features = self._process_image_pixels(image_input)
assert isinstance(image_features, list)
assert isinstance(image_features, (list, tuple))
lengths = [x.shape[0] for x in image_features]
return self.multi_modal_projector(
torch.cat(image_features)).split(lengths)
@ -496,6 +519,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
expert_params_mapping = []
params_dict = dict(self.named_parameters())
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}

View File

@ -42,7 +42,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from collections.abc import Sequence
from copy import deepcopy
from functools import cached_property
@ -55,6 +54,8 @@ from transformers.activations import ACT2FN, PytorchGELUTanh
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
if is_flash_attn_2_available():
@ -383,21 +384,30 @@ class MLP2(nn.Module):
bias: whether to use bias in linear layer.
"""
def __init__(self, dims: list[int], activation, bias=True):
def __init__(self,
dims: list[int],
activation,
bias=True,
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()
assert len(dims) == 3
self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
self.use_data_parallel = use_data_parallel
self.fc0 = ReplicatedLinear(dims[0],
dims[1],
bias=bias,
prefix=maybe_prefix(prefix, "fc0"))
self.fc1 = ReplicatedLinear(dims[1],
dims[2],
bias=bias,
prefix=maybe_prefix(prefix, "fc1"))
self.activation = activation
for m in [self.fc0, self.fc1]:
nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc0(x)
x, _ = self.fc0(x)
x = self.activation(x)
return self.fc1(x)
x, _ = self.fc1(x)
return x
class MoonVitEncoderLayer(nn.Module):
@ -407,6 +417,8 @@ class MoonVitEncoderLayer(nn.Module):
num_heads: int,
hidden_dim: int,
mlp_dim: int,
prefix: str = "",
use_data_parallel: bool = False,
*,
attn_implementation: str = "sdpa",
activation=F.gelu,
@ -423,9 +435,19 @@ class MoonVitEncoderLayer(nn.Module):
self.norm0 = nn.LayerNorm(hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
self.use_data_parallel = use_data_parallel
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim],
activation,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
self.wqkv = ReplicatedLinear(hidden_dim,
hidden_dim * 3,
bias=attn_bias,
prefix=f"{prefix}.wqkv")
self.wo = ReplicatedLinear(hidden_dim,
hidden_dim,
bias=attn_bias,
prefix=f"{prefix}.wo")
def attention_qkvpacked(
self,
@ -438,7 +460,7 @@ class MoonVitEncoderLayer(nn.Module):
x (torch.Tensor): (batch_size, seqlen, hidden_dim)
cu_seqlens (torch.Tensor):
"""
xqkv = self.wqkv(x)
xqkv, _ = self.wqkv(x)
qkv_shape = xqkv.size()[:-1] + (
3,
@ -457,8 +479,7 @@ class MoonVitEncoderLayer(nn.Module):
xv,
q_cu_seqlens=cu_seqlens,
k_cu_seqlens=cu_seqlens)
attn_out = self.wo(attn_out)
attn_out, _ = self.wo(attn_out)
return attn_out
def forward(
@ -494,13 +515,17 @@ class MoonVitEncoder(nn.Module):
hidden_dim: int,
num_layers: int,
block_cfg: dict,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.rope_2d = Rope2DPosEmb(
block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
self.blocks = nn.ModuleList(
[MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
[MoonVitEncoderLayer(use_data_parallel=use_data_parallel, \
prefix=f"{prefix}.blocks.{layer_idx}", \
**block_cfg) for layer_idx in range(num_layers)])
self.final_layernorm = nn.LayerNorm(hidden_dim)
def forward(self, hidden_states: torch.Tensor,
@ -508,10 +533,9 @@ class MoonVitEncoder(nn.Module):
rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
grid_hws=grid_hw)
lengths = torch.cat((
torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
grid_hw[:, 0] * grid_hw[:, 1],
))
lengths = torch.cat(
(torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
(grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device)))
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
for _, block in enumerate(self.blocks):
@ -587,11 +611,19 @@ class MoonVitPretrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
def __init__(self,
config: MoonViTConfig,
use_data_parallel: bool = False,
prefix: str = "",
*inputs,
**kwargs):
super().__init__(config, *inputs, **kwargs)
config = deepcopy(config)
self.use_data_parallel = use_data_parallel
self.merge_kernel_size = config.merge_kernel_size
self.hidden_size = config.hidden_size
self.patch_size = config.patch_size
self.vit_processing_type = "rope_2d"
self.patch_embed = MoonVisionPatchEmbed(
out_dim=config.hidden_size,
patch_size=config.patch_size,
@ -610,6 +642,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
"attn_bias": True,
"attn_implementation": config._attn_implementation,
},
prefix=f"{prefix}.encoder",
)
def forward(self, pixel_values: torch.Tensor,

View File

@ -1021,8 +1021,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values = image_input["pixel_values"]
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list)
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
else:
image_embeds = self.visual(pixel_values,
grid_thw=grid_thw_list)
@ -1048,8 +1050,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
else:
pixel_values_videos = video_input["pixel_values_videos"]
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values_videos, grid_thw_list)
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d")
else:
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw_list)

View File

@ -9,7 +9,7 @@ from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
from urllib.request import url2pathname
@ -444,7 +444,6 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
Args:
image_input (torch.Tensor): Image input tensor.
vision_model (torch.nn.Module): Vision model.
Returns:
torch.Tensor: Output image embeddings
"""
@ -542,6 +541,8 @@ def run_dp_sharded_mrope_vision_model(
vision_model: torch.nn.Module,
pixel_values: torch.Tensor,
grid_thw_list: list[list[int]],
*,
rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
"""Run a vision model with data parallelism (DP) sharding.
The function will shard the input image tensor on the
@ -552,6 +553,10 @@ def run_dp_sharded_mrope_vision_model(
vision_model (torch.nn.Module): Vision model.
pixel_values (torch.Tensor): Image/Video input tensor.
grid_thw_list: List of grid dimensions for each image
rope_type: Type of rope used in the vision model.
Different rope types have different dimension to do ViT.
"rope_3d" for 3D rope (e.g., Qwen2.5-VL)
"rope_2d" for 2D rope (e.g., Kimi-VL)
Returns:
torch.Tensor: Output image embeddings
@ -605,8 +610,12 @@ def run_dp_sharded_mrope_vision_model(
device=pixel_values.device,
dtype=pixel_values.dtype)
# embed_dim_reduction_factor = 2 * 2
embed_dim_reduction_factor = (vision_model.spatial_merge_size *
vision_model.spatial_merge_size)
if rope_type == "rope_2d":
embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
vision_model.merge_kernel_size[1])
else:
embed_dim_reduction_factor = (vision_model.spatial_merge_size *
vision_model.spatial_merge_size)
# Find the max length across all ranks
# The output embedding of every DP rank has to be
@ -617,23 +626,42 @@ def run_dp_sharded_mrope_vision_model(
local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]
# Run the vision model on the local pixel_values_local
if pixel_values_local.shape[0] > 0:
image_embeds_local = vision_model(pixel_values_local,
local_grid_thw_list)
if rope_type == "rope_2d":
if pixel_values_local.shape[0] > 0:
image_embeds_local = vision_model(
pixel_values_local, torch.tensor(local_grid_thw_list))
if isinstance(image_embeds_local, list):
image_embeds_local = torch.cat(image_embeds_local, dim=0)
else:
out_dim = getattr(vision_model.config, "hidden_size", None)
image_embeds_local = torch.empty(
(0, embed_dim_reduction_factor, out_dim),
device=pixel_values.device,
dtype=pixel_values.dtype)
else:
# Handle empty case
image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
device=pixel_values.device,
dtype=pixel_values.dtype)
if pixel_values_local.shape[0] > 0:
image_embeds_local = vision_model(pixel_values_local,
local_grid_thw_list)
else:
# Handle empty case
image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
device=pixel_values.device,
dtype=pixel_values.dtype)
# Pad the output based on max_len_per_rank
# for tensor_model_parallel_all_gather to work
current_len = image_embeds_local.shape[0]
if current_len < max_len_per_rank:
padding_size = max_len_per_rank - current_len
padding = torch.empty((padding_size, image_embeds_local.shape[1]),
dtype=image_embeds_local.dtype,
device=image_embeds_local.device)
if rope_type == "rope_2d":
padding = torch.empty((padding_size, image_embeds_local.shape[1],
image_embeds_local.shape[2]),
dtype=image_embeds_local.dtype,
device=image_embeds_local.device)
else:
padding = torch.empty((padding_size, image_embeds_local.shape[1]),
dtype=image_embeds_local.dtype,
device=image_embeds_local.device)
image_embeds_local_padded = torch.cat([image_embeds_local, padding],
dim=0)
else:
@ -674,7 +702,6 @@ def run_dp_sharded_mrope_vision_model(
embed_start:embed_start + img_patches]
embed_start += img_patches
current_idx += count
out_embeddings = tuple(embed for embed in original_order_embeddings
if embed is not None)
assert len(out_embeddings) == len(