[VLM] Remove image_input_type from VLM config (#5852)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@ -17,20 +17,18 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
|
||||
AutoTokenizer, BatchEncoding)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
|
||||
from vllm.config import TokenizerPoolConfig
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal import MultiModalData
|
||||
else:
|
||||
# it will call torch.cuda.device_count()
|
||||
MultiModalData = None
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import cuda_device_count_stateless, is_cpu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# it will call torch.cuda.device_count()
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_TEST_DIR = os.path.dirname(__file__)
|
||||
@ -51,14 +49,6 @@ def _read_prompts(filename: str) -> List[str]:
|
||||
class ImageAsset:
|
||||
name: Literal["stop_sign", "cherry_blossom"]
|
||||
|
||||
@cached_property
|
||||
def pixel_values(self) -> torch.Tensor:
|
||||
return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")
|
||||
|
||||
@cached_property
|
||||
def image_features(self) -> torch.Tensor:
|
||||
return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")
|
||||
|
||||
@cached_property
|
||||
def pil_image(self) -> Image.Image:
|
||||
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
|
||||
@ -66,20 +56,8 @@ class ImageAsset:
|
||||
def for_hf(self) -> Image.Image:
|
||||
return self.pil_image
|
||||
|
||||
def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from vllm.multimodal.image import ImageFeatureData # noqa: F401
|
||||
from vllm.multimodal.image import ImagePixelData
|
||||
image_input_type = vision_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
|
||||
if image_input_type == ImageInputType.IMAGE_FEATURES:
|
||||
return ImageFeatureData(self.image_features)
|
||||
if image_input_type == ImageInputType.PIXEL_VALUES:
|
||||
return ImagePixelData(self.pil_image)
|
||||
|
||||
raise NotImplementedError
|
||||
def for_vllm(self) -> Dict[str, Any]:
|
||||
return {"image": self.pil_image}
|
||||
|
||||
|
||||
class _ImageAssetPrompts(TypedDict):
|
||||
@ -453,7 +431,7 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[List[MultiModalData]] = None,
|
||||
images: Optional[List["MultiModalDataDict"]] = None,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
@ -502,7 +480,7 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
images: Optional[List[MultiModalData]] = None,
|
||||
images: Optional[List["MultiModalDataDict"]] = None,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts, greedy_params, images=images)
|
||||
|
||||
Reference in New Issue
Block a user