feat: support openllm embedding (#1293)

This commit is contained in:
takatost
2023-10-10 12:09:35 +08:00
committed by GitHub
parent 1d4f019de4
commit 4ab4bcc074
5 changed files with 171 additions and 12 deletions

View File

@ -2,11 +2,13 @@ import json
from typing import Type
from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType
@ -31,6 +33,8 @@ class OpenLLMProvider(BaseModelProvider):
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenLLMModel
elif model_type== ModelType.EMBEDDINGS:
model_class = OpenLLMEmbedding
else:
raise NotImplementedError
@ -69,14 +73,21 @@ class OpenLLMProvider(BaseModelProvider):
'server_url': credentials['server_url']
}
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
llm("ping")
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = OpenLLMEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))