[Core] Support image processor (#4197)
@@ -3,33 +3,36 @@ import os
 import subprocess

 import torch
+from PIL import Image

 from vllm import LLM
-from vllm.sequence import MultiModalData
+from vllm.multimodal.image import ImageFeatureData, ImagePixelData

 # The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
 # You can use `.buildkite/download-images.sh` to download them


-def run_llava_pixel_values():
+def run_llava_pixel_values(*, disable_image_processor: bool = False):
     llm = LLM(
         model="llava-hf/llava-1.5-7b-hf",
         image_input_type="pixel_values",
         image_token_id=32000,
         image_input_shape="1,3,336,336",
         image_feature_size=576,
+        disable_image_processor=disable_image_processor,
     )

     prompt = "<image>" * 576 + (
         "\nUSER: What is the content of this image?\nASSISTANT:")

     # This should be provided by another online or offline component.
-    image = torch.load("images/stop_sign_pixel_values.pt")
+    if disable_image_processor:
+        image = torch.load("images/stop_sign_pixel_values.pt")
+    else:
+        image = Image.open("images/stop_sign.jpg")

     outputs = llm.generate({
-        "prompt":
-        prompt,
-        "multi_modal_data":
-        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
+        "prompt": prompt,
+        "multi_modal_data": ImagePixelData(image),
     })

     for o in outputs:
@@ -49,15 +52,13 @@ def run_llava_image_features():
     prompt = "<image>" * 576 + (
         "\nUSER: What is the content of this image?\nASSISTANT:")

     # This should be provided by another online or offline component.
-    image = torch.load("images/stop_sign_image_features.pt")
+    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

     outputs = llm.generate({
-        "prompt":
-        prompt,
-        "multi_modal_data":
-        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
+        "prompt": prompt,
+        "multi_modal_data": ImageFeatureData(image),
     })

     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
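For context, a minimal sketch of how the updated example might be driven. The `__main__` driver below is illustrative and not part of this diff; it assumes the image assets from `.buildkite/download-images.sh` are present locally.

# Illustrative driver (not part of this commit): exercises the code paths
# introduced above.
if __name__ == "__main__":
    # Default path: pass a raw PIL image and let the model's HuggingFace
    # image processor convert it into pixel values inside vLLM.
    run_llava_pixel_values()

    # Opt-out path: supply precomputed pixel values and skip preprocessing.
    run_llava_pixel_values(disable_image_processor=True)

    # Feature path: precomputed image features wrapped in ImageFeatureData.
    run_llava_image_features()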