feat: add xinference embedding model support (#930)

This commit is contained in:
takatost
2023-08-20 19:35:07 +08:00
committed by GitHub
parent 18dd0d569d
commit 25264e7852
3 changed files with 94 additions and 0 deletions

View File

@ -0,0 +1,26 @@
from langchain.embeddings import XinferenceEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class XinferenceEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = XinferenceEmbeddings(
**credentials,
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
else:
return ex

View File

@ -4,6 +4,7 @@ from typing import Type
from langchain.llms import Xinference
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@ -32,6 +33,8 @@ class XinferenceProvider(BaseModelProvider):
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = XinferenceModel
elif model_type == ModelType.EMBEDDINGS:
model_class = XinferenceEmbedding
else:
raise NotImplementedError