feat: add jina embedding (#1647)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
zxhlyh
2023-11-29 14:58:11 +08:00
committed by GitHub
parent 454577c6b1
commit 451af66be0
22 changed files with 662 additions and 4 deletions

View File

@ -0,0 +1,141 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
from models.provider import ProviderType
class JinaProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'jina'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'jina-embeddings-v2-base-en',
'name': 'jina-embeddings-v2-base-en',
},
{
'id': 'jina-embeddings-v2-small-en',
'name': 'jina-embeddings-v2-small-en',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.EMBEDDINGS:
model_class = JinaEmbedding
else:
raise NotImplementedError
return model_class
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Jina API Key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
}
embedding = JinaEmbeddings(
model='jina-embeddings-v2-small-en',
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
def _get_text_generation_model_mode(self, model_name) -> str:
raise NotImplementedError
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
raise NotImplementedError