[VLM] Remove image_input_type from VLM config (#5852)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
xwjiang2010
2024-07-02 00:57:09 -07:00
committed by GitHub
parent 2c37540aa6
commit 98d6682cd1
35 changed files with 329 additions and 751 deletions

View File

@ -17,20 +17,18 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding)
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.config import TokenizerPoolConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.multimodal import MultiModalData
else:
# it will call torch.cuda.device_count()
MultiModalData = None
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
if TYPE_CHECKING:
# it will call torch.cuda.device_count()
from vllm.multimodal import MultiModalDataDict
logger = init_logger(__name__)
_TEST_DIR = os.path.dirname(__file__)
@ -51,14 +49,6 @@ def _read_prompts(filename: str) -> List[str]:
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom"]
@cached_property
def pixel_values(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")
@cached_property
def image_features(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")
@cached_property
def pil_image(self) -> Image.Image:
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
@ -66,20 +56,8 @@ class ImageAsset:
def for_hf(self) -> Image.Image:
return self.pil_image
def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from vllm.multimodal.image import ImageFeatureData # noqa: F401
from vllm.multimodal.image import ImagePixelData
image_input_type = vision_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if image_input_type == ImageInputType.IMAGE_FEATURES:
return ImageFeatureData(self.image_features)
if image_input_type == ImageInputType.PIXEL_VALUES:
return ImagePixelData(self.pil_image)
raise NotImplementedError
def for_vllm(self) -> Dict[str, Any]:
return {"image": self.pil_image}
class _ImageAssetPrompts(TypedDict):
@ -453,7 +431,7 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
@ -502,7 +480,7 @@ class VllmRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)