[VLM] Support HF format Phi-4-MM model (#17121)
Signed-off-by: Isotr0py <2037008807@qq.com>
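Below is a minimal, hedged sketch of how the HF-format checkpoint added by this commit is meant to be used for offline inference, assembled from the examples and tests in this diff (the `refs/pr/70` revision, the `vision-lora` adapter path, LoRA rank 320, and the `<|user|>…<|end|><|assistant|>` prompt format all come from the changes below). The image path `example.jpg` and the question text are placeholders; treat this as a sketch, not the canonical example shipped in `examples/`:

```python
# Hedged sketch based on the examples added in this commit; not an official snippet.
import os

from huggingface_hub import snapshot_download
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# HF-format weights are published on the refs/pr/70 revision of the original repo.
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct",
                               revision="refs/pr/70")
# The vision adapter ships alongside the base weights.
vision_lora_path = os.path.join(model_path, "vision-lora")

llm = LLM(
    model=model_path,
    max_model_len=5120,
    max_num_seqs=2,
    enable_lora=True,
    max_lora_rank=320,
    limit_mm_per_prompt={"image": 1},
)

prompt = "<|user|><|image|>What is shown in this image?<|end|><|assistant|>"
image = Image.open("example.jpg")  # placeholder image path

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=128),
    lora_request=LoRARequest("vision", 1, vision_lora_path),
)
print(outputs[0].outputs[0].text)
```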
@@ -614,6 +614,7 @@ Specified using `--task generate`.
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | ✅︎ |
| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | ✅︎ |
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ |
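The two Phi-4-multimodal rows point at the same repository but are loaded differently: `Phi4MMForCausalLM` is the original format and needs remote code, while `Phi4MultimodalForCausalLM` is the HF Transformers format selected via the `refs/pr/70` revision. A hedged sketch of the corresponding engine configurations, mirroring the registry and test entries changed later in this commit:

```python
# Hedged sketch; values mirror the registry and test entries in this commit.
from vllm import EngineArgs

# Original Phi-4-MM format (Phi4MMForCausalLM): requires remote code.
original_format = EngineArgs(
    model="microsoft/Phi-4-multimodal-instruct",
    trust_remote_code=True,
)

# HF Transformers format (Phi4MultimodalForCausalLM): selected via revision.
hf_format = EngineArgs(
    model="microsoft/Phi-4-multimodal-instruct",
    revision="refs/pr/70",
)
```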
@@ -190,6 +190,37 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    )


def run_phi4_multimodal(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.
    """
    model_path = snapshot_download(
        "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
    )
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    speech_lora_path = os.path.join(model_path, "speech-lora")
    placeholders = "<|audio|>" * audio_count

    prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompts,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"
@@ -303,6 +334,7 @@ model_example_map = {
    "granite_speech": run_granite_speech,
    "minicpmo": run_minicpmo,
    "phi4_mm": run_phi4mm,
    "phi4_multimodal": run_phi4_multimodal,
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "ultravox": run_ultravox,
@@ -1097,6 +1097,41 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
    )


# HF format Phi-4-multimodal-instruct
def run_phi4_multimodal(questions: list[str], modality: str) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process image inputs.
    """
    assert modality == "image"
    model_path = snapshot_download(
        "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
    )
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    vision_lora_path = os.path.join(model_path, "vision-lora")
    prompts = [
        f"<|user|><|image|>{question}<|end|><|assistant|>" for question in questions
    ]
    engine_args = EngineArgs(
        model=model_path,
        max_model_len=5120,
        max_num_seqs=2,
        max_num_batched_tokens=12800,
        enable_lora=True,
        max_lora_rank=320,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={"dynamic_hd": 16},
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
    )


# Pixtral HF-format
def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
@@ -1356,6 +1391,7 @@ model_example_map = {
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "phi4_mm": run_phi4mm,
    "phi4_multimodal": run_phi4_multimodal,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
@@ -760,6 +760,40 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
    )


def load_phi4_multimodal(question: str, image_urls: list[str]) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process multi-image inputs.
    """

    model_path = snapshot_download(
        "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
    )
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    vision_lora_path = os.path.join(model_path, "vision-lora")
    engine_args = EngineArgs(
        model=model_path,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        enable_lora=True,
        max_lora_rank=320,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={"dynamic_hd": 4},
    )

    placeholders = "<|image|>" * len(image_urls)
    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
        lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
    )


def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "Qwen/Qwen-VL-Chat"
    engine_args = EngineArgs(
@@ -988,6 +1022,7 @@ model_example_map = {
    "ovis": load_ovis,
    "phi3_v": load_phi3v,
    "phi4_mm": load_phi4mm,
    "phi4_multimodal": load_phi4_multimodal,
    "pixtral_hf": load_pixtral_hf,
    "qwen_vl_chat": load_qwen_vl_chat,
    "qwen2_vl": load_qwen2_vl,
tests/models/multimodal/generation/test_phi4_multimodal.py (new file, 252 lines)
@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from collections.abc import Sequence
from typing import Optional

import librosa
import pytest
from huggingface_hub import snapshot_download

from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import rescale_image_size
from vllm.platforms import current_platform

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
                          PromptImageInput, VllmRunner)
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
    "cherry_blossom":
    "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n",  # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501

model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct",
                               revision="refs/pr/70")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
speech_question = os.path.join(model_path, "examples",
                               "what_is_shown_in_this_image.wav")
models = [model_path]

target_dtype = "half"

# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if current_platform.is_rocm():
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: Sequence[tuple[list[str], PromptImageInput,
                           Optional[PromptAudioInput]]],
    model: str,
    *,
    max_model_len: int,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    # max_model_len should be greater than image_feature_size
    with vllm_runner(
            model,
            task="generate",
            max_model_len=max_model_len,
            max_num_seqs=2,
            dtype=dtype,
            limit_mm_per_prompt={"image": mm_limit},
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enable_lora=True,
            max_lora_rank=320,
            gpu_memory_utilization=0.8,  # set to 0.8 to avoid OOM in CI
            enforce_eager=True,
            trust_remote_code=False,
    ) as vllm_model:
        lora_request = LoRARequest("vision", 1, vision_lora_path)
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images,
                                                audios=audios,
                                                lora_request=lora_request)
            for prompts, images, audios in inputs
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_model.model.load_adapter(
            vision_lora_path,
            adapter_name="vision",
        )
        hf_processor = hf_model.processor
        eos_token_id = hf_processor.tokenizer.eos_token_id
        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images,
                                                    audios=audios,
                                                    eos_token_id=eos_token_id)
            for prompts, images, audios in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_model_len: int, max_tokens: int,
                num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
        None,
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        # [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [25600])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                             size_factors, dtype: str, max_model_len: int,
                             max_tokens: int, num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case = [
        (
            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
            [[rescale_image_size(image, factor) for image in images]
             for factor in size_factors],
            None,
        ),
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
                              max_model_len: int, max_tokens: int,
                              num_logprobs: int) -> None:

    # use the example speech question so that the model outputs are reasonable
    audio = librosa.load(speech_question, sr=16000)
    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

    inputs_vision_speech = [
        (
            ["<|user|><|image|><|audio|><|end|><|assistant|>"],
            [image],
            [audio],
        ),
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_vision_speech,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )
@@ -41,12 +41,18 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:


def _test_processing_correctness(
    model_id: str,
    model_id_or_arch: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    if model_id_or_arch in HF_EXAMPLE_MODELS.get_supported_archs():
        # Use model architecture to get the default model id
        model_info = HF_EXAMPLE_MODELS.get_hf_info(model_id_or_arch)
        model_id = model_info.default
    else:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id_or_arch)
        model_id = model_id_or_arch
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")
@@ -58,7 +64,7 @@ def _test_processing_correctness(
        trust_remote_code=model_info.trust_remote_code,
        seed=0,
        dtype="auto",
        revision=None,
        revision=model_info.revision,
        hf_overrides=model_info.hf_overrides,
    )
@@ -331,6 +337,28 @@ def test_processing_correctness(
    )


# Phi4MultimodalForCausalLM shares the same model repo as the original-format
# Phi4MMForCausalLM, so we add it as a separate test case.
# Remove this test after the conversion PR is merged:
# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70
@pytest.mark.parametrize("model_arch", ["Phi4MultimodalForCausalLM"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
def test_processing_correctness_phi4_multimodal(
    model_arch: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    _test_processing_correctness(
        model_arch,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )


def _assert_inputs_equal(
    a: MultiModalInputs,
    b: MultiModalInputs,
@@ -433,6 +433,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                              "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
    "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                         trust_remote_code=True),
    "Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",  # noqa: E501
                                                 revision="refs/pr/70"),
    "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
                                                       tokenizer_mode="mistral"),
    "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
vllm/model_executor/models/phi4_multimodal.py (new file, 1455 lines; diff suppressed because it is too large)
@@ -223,6 +223,8 @@ _MULTIMODAL_MODELS = {
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
@@ -231,7 +233,6 @@ _MULTIMODAL_MODELS = {
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
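For readers unfamiliar with this registry, each entry maps an architecture name (as it appears in a checkpoint's `config.json`) to a `(module, class)` pair under `vllm/model_executor/models/`; the new `Phi4MultimodalForCausalLM` entry therefore resolves to the `phi4_multimodal` module added in this commit. A simplified sketch of what that resolution amounts to (a hypothetical helper; the real registry does lazy, cached imports with additional checks):

```python
# Simplified, hypothetical resolver illustrating the (module, class) convention.
import importlib


def resolve_multimodal_arch(arch: str, mapping: dict[str, tuple[str, str]]):
    module_name, class_name = mapping[arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)


# e.g. resolve_multimodal_arch("Phi4MultimodalForCausalLM", _MULTIMODAL_MODELS)
# -> vllm.model_executor.models.phi4_multimodal.Phi4MultimodalForCausalLM
```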
@@ -295,7 +295,7 @@ def cached_tokenizer_from_config(
    return cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        tokenizer_revision=model_config.tokenizer_revision,
        revision=model_config.tokenizer_revision,
        trust_remote_code=model_config.trust_remote_code,
        **kwargs,
    )
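This one-line rename forwards the model's tokenizer revision (e.g. `refs/pr/70` for the HF-format Phi-4-MM checkpoint) under the keyword that the call site evidently expects. A hedged usage sketch, assuming `get_tokenizer` accepts a `revision` argument as this call implies:

```python
# Hedged sketch; the keyword name follows the call site changed above.
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer(
    "microsoft/Phi-4-multimodal-instruct",
    revision="refs/pr/70",
    trust_remote_code=False,
)
```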