[Model] Support Skywork-R1V (#15397)
Signed-off-by: jiacai.liu <932997367@qq.com> Co-authored-by: jiacai.liu <932997367@qq.com>
This commit is contained in:
@ -474,6 +474,20 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
|
||||
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
|
||||
),
|
||||
"skywork_r1v": VLMTestInfo(
|
||||
models=["Skywork/Skywork-R1V-38B"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
"stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501
|
||||
"cherry_blossom": "<image>\nWhat is the season?",
|
||||
}),
|
||||
multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501
|
||||
max_model_len=4096,
|
||||
use_tokenizer_eos=True,
|
||||
patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=80)],
|
||||
),
|
||||
### Tensor parallel / multi-gpu broadcast tests
|
||||
"chameleon-broadcast": VLMTestInfo(
|
||||
models=["facebook/chameleon-7b"],
|
||||
|
||||
@ -376,6 +376,63 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
return hf_model
|
||||
|
||||
|
||||
def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
"""Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
|
||||
|
||||
class SkyworkR1VProcessor:
|
||||
"""A simple processor for SkyworkR1V."""
|
||||
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
|
||||
trust_remote_code=True)
|
||||
self.vision_config = self.config.vision_config
|
||||
self.use_thumbnail = self.config.use_thumbnail
|
||||
self.min_num = self.config.min_dynamic_patch
|
||||
self.max_num = self.config.max_dynamic_patch
|
||||
self.image_size = self.vision_config.image_size
|
||||
|
||||
def __call__(self, text: str, images: Union[Image, list[Image]],
|
||||
**kwargs):
|
||||
from vllm.model_executor.models.skyworkr1v import (
|
||||
IMG_CONTEXT, IMG_END, IMG_START,
|
||||
image_to_pixel_values_skyworkr1v)
|
||||
images = [images] if isinstance(images, Image) else images
|
||||
pixel_values = [
|
||||
image_to_pixel_values_skyworkr1v(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=self.min_num,
|
||||
max_num=self.max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
) for image in images
|
||||
]
|
||||
num_patches_list = [
|
||||
pixel_value.shape[0] for pixel_value in pixel_values
|
||||
]
|
||||
pixel_values = torch.cat(pixel_values, dim=0)
|
||||
for num_patches in num_patches_list:
|
||||
context_tokens = IMG_CONTEXT * self.num_image_token \
|
||||
* num_patches
|
||||
image_tokens = IMG_START + context_tokens + IMG_END
|
||||
text = text.replace('<image>', image_tokens, 1)
|
||||
prompt = self.tokenizer(text, return_tensors="pt")
|
||||
prompt.update({"pixel_values": pixel_values})
|
||||
return prompt
|
||||
|
||||
img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
|
||||
"<IMG_CONTEXT>")
|
||||
hf_model.model.img_context_token_id = img_context_token_id
|
||||
hf_model.processor = SkyworkR1VProcessor(hf_model)
|
||||
hf_model.model.get_output_embeddings = lambda: \
|
||||
hf_model.model.language_model.get_output_embeddings()
|
||||
hf_model.model.generate = types.MethodType(_internvl_generate,
|
||||
hf_model.model)
|
||||
return hf_model
|
||||
|
||||
|
||||
def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
"""Patches and returns an instance of the HfRunner to use for InternVL."""
|
||||
|
||||
|
||||
@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
"meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"mistral-community/pixtral-12b",
|
||||
"openbmb/MiniCPM-Llama3-V-2_5",
|
||||
"openbmb/MiniCPM-o-2_6",
|
||||
"openbmb/MiniCPM-V-2_6",
|
||||
"allenai/Molmo-7B-D-0924",
|
||||
"allenai/Molmo-7B-O-0924",
|
||||
"nvidia/NVLM-D-72B",
|
||||
"google/paligemma-3b-mix-224",
|
||||
"google/paligemma2-3b-ft-docci-448",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"mistral-community/pixtral-12b",
|
||||
"Qwen/Qwen-VL-Chat",
|
||||
"Qwen/Qwen2-VL-2B-Instruct",
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
"Skywork/Skywork-R1V-38B",
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
"openai/whisper-large-v3",
|
||||
"google/paligemma-3b-mix-224",
|
||||
"google/paligemma2-3b-ft-docci-448",
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
|
||||
@ -294,6 +294,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
|
||||
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
|
||||
min_transformers_version="4.49"), # noqa: E501
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
# [Encoder-decoder]
|
||||
|
||||
Reference in New Issue
Block a user