[Model][VLM] Add Qwen2-VL model support (#7905)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Yang Fan
Date: 2024-09-12 00:31:19 +08:00
Committed by: GitHub
Parent: cea95dfb94
Commit: 3b7fea770f

14 changed files with 1531 additions and 31 deletions

examples/offline_inference_vision_language.py

@@ -179,6 +179,23 @@ def run_qwen_vl(question):
     return llm, prompt, stop_token_ids


+# Qwen2-VL
+def run_qwen2_vl(question):
+    model_name = "Qwen/Qwen2-VL-7B-Instruct"
+
+    llm = LLM(
+        model=model_name,
+        max_num_seqs=5,
+    )
+
+    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+              f"{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,

@@ -191,6 +208,7 @@ model_example_map = {
     "blip-2": run_blip2,
     "internvl_chat": run_internvl,
     "qwen_vl": run_qwen_vl,
+    "qwen2_vl": run_qwen2_vl,
 }
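
A minimal sketch (not part of this diff) of how the script consumes these return values, following the same `multi_modal_data` pattern used in the multi-image example below; the image URL is a placeholder:

    from vllm import SamplingParams
    from vllm.multimodal.utils import fetch_image

    # run_qwen2_vl returns the engine, a prompt with one <|image_pad|>
    # placeholder, and stop_token_ids (None for Qwen2-VL).
    llm, prompt, stop_token_ids = run_qwen2_vl("What is in this image?")

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

    outputs = llm.generate(
        {
            "prompt": prompt,
            # One image to fill the single vision placeholder in the prompt.
            "multi_modal_data": {"image": fetch_image("https://example.com/demo.jpg")},
        },
        sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)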

examples/offline_inference_vision_language_multi_image.py

@@ -6,7 +6,7 @@ by the model.
 from argparse import Namespace
 from typing import List

-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer

 from vllm import LLM, SamplingParams
 from vllm.multimodal.utils import fetch_image
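
(Aside, not part of this diff: `AutoProcessor` is pulled in because the Qwen2-VL loader below renders its prompt through the processor's chat template instead of hand-writing special tokens. For one image plus the system message used below, the template should render the same layout as the hand-written prompt in the single-image example above:

    <|im_start|>system
    You are a helpful assistant.<|im_end|>
    <|im_start|>user
    <|vision_start|><|image_pad|><|vision_end|>{question}<|im_end|>
    <|im_start|>assistant
)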
@@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]):
                              for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompt, stop_token_ids, None


 def load_internvl(question, image_urls: List[str]):
@@ -60,18 +60,72 @@ def load_internvl(question, image_urls: List[str]):
     # https://huggingface.co/OpenGVLab/InternVL2-2B#service
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-    return llm, prompt, stop_token_ids
+    return llm, prompt, stop_token_ids, None
+
+
+def load_qwen2_vl(question, image_urls: List[str]):
+    try:
+        from qwen_vl_utils import process_vision_info
+    except ModuleNotFoundError:
+        print('WARNING: `qwen-vl-utils` not installed, input images will not '
+              'be automatically resized. You can enable this functionality by '
+              '`pip install qwen-vl-utils`.')
+        process_vision_info = None
+
+    model_name = "Qwen/Qwen2-VL-7B-Instruct"
+
+    llm = LLM(
+        model=model_name,
+        max_num_seqs=5,
+        max_model_len=32768 if process_vision_info is None else 4096,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    stop_token_ids = None
+
+    if process_vision_info is None:
+        image_data = [fetch_image(url) for url in image_urls]
+    else:
+        image_data, _ = process_vision_info(messages)
+
+    return llm, prompt, stop_token_ids, image_data
+
+
 model_example_map = {
     "phi3_v": load_phi3v,
     "internvl_chat": load_internvl,
+    "qwen2_vl": load_qwen2_vl,
 }


 def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids = model_example_map[model](question,
-                                                           image_urls)
+    llm, prompt, stop_token_ids, image_data = model_example_map[model](
+        question, image_urls)
+    if image_data is None:
+        image_data = [fetch_image(url) for url in image_urls]

     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
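
A condensed restatement (not part of this diff; the helper name is illustrative) of the fallback contract in `load_qwen2_vl` above: `qwen_vl_utils.process_vision_info` returns an `(images, videos)` pair, and both branches yield a list of PIL images, so callers never need to know which path produced `image_data`. The larger `max_model_len` in the no-resize branch exists because unresized images can expand into many more vision tokens.

    from typing import List

    from vllm.multimodal.utils import fetch_image


    def resolve_image_data(messages, image_urls: List[str]):
        """Load images for Qwen2-VL, preferring qwen-vl-utils' resizing."""
        try:
            from qwen_vl_utils import process_vision_info
        except ModuleNotFoundError:
            # Fallback: plain download; images keep their original size.
            return [fetch_image(url) for url in image_urls]
        image_data, _ = process_vision_info(messages)  # (images, videos)
        return image_data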
@@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]):
         {
             "prompt": prompt,
             "multi_modal_data": {
-                "image": [fetch_image(url) for url in image_urls]
+                "image": image_data
             },
         },
         sampling_params=sampling_params)
@@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]):


 def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids = model_example_map[model](question, image_urls)
+    llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)

     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
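
A minimal end-to-end sketch (not part of this diff) of the multi-image path, using the "qwen2_vl" key registered above; the question and image URLs are placeholders:

    QUESTION = "What is the content of each image?"
    IMAGE_URLS = [
        "https://example.com/duck.jpg",
        "https://example.com/lion.jpg",
    ]

    # Loads the model, builds the chat-template prompt for two images, and
    # resolves image_data via qwen-vl-utils when it is installed.
    run_generate("qwen2_vl", QUESTION, IMAGE_URLS)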