Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
committed by
GitHub
parent
49628fe13e
commit
214efc2c3c
@ -265,6 +265,7 @@ class HfRunner:
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
is_embedding_model: bool = False,
|
||||
is_sentence_transformer: bool = False,
|
||||
is_cross_encoder: bool = False,
|
||||
skip_tokenizer_init: bool = False,
|
||||
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
||||
postprocess_inputs: Callable[..., BatchEncoding] = identity,
|
||||
@ -282,6 +283,14 @@ class HfRunner:
|
||||
device="cpu",
|
||||
trust_remote_code=True,
|
||||
).to(dtype=torch_dtype))
|
||||
elif is_cross_encoder:
|
||||
# Lazy init required for AMD CI
|
||||
from sentence_transformers import CrossEncoder
|
||||
self.model = CrossEncoder(model_name,
|
||||
device="cpu",
|
||||
trust_remote_code=True)
|
||||
self.model.model = self.wrap_device(self.model.model)\
|
||||
.to(dtype=torch_dtype)
|
||||
else:
|
||||
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
||||
self.model = self.wrap_device(
|
||||
@ -625,6 +634,9 @@ class HfRunner:
|
||||
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
|
||||
return self.model.encode(prompts)
|
||||
|
||||
def predict(self, prompts: List[List[str]]) -> torch.Tensor:
|
||||
return self.model.predict(prompts, convert_to_tensor=True)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
@ -898,6 +910,14 @@ class VllmRunner:
|
||||
req_outputs = self.model.encode(inputs)
|
||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||
|
||||
def score(
|
||||
self,
|
||||
text_1: Union[str, List[str]],
|
||||
text_2: Union[str, List[str]],
|
||||
) -> List[List[float]]:
|
||||
req_outputs = self.model.score(text_1, text_2)
|
||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
||||
Reference in New Issue
Block a user