[Model] Multi-input support for LLaVA (#8238)

This commit is contained in:
Cyrus Leung
2024-09-07 10:57:24 +08:00
committed by GitHub
parent 41e95c5247
commit 2f707fcb35
10 changed files with 176 additions and 45 deletions

View File

@ -278,7 +278,7 @@ class HfRunner:
def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
@ -314,7 +314,7 @@ class HfRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
@ -351,7 +351,7 @@ class HfRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_logprobs: List[List[torch.Tensor]] = []
@ -433,8 +433,8 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
@ -671,7 +671,7 @@ class VllmRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)