Mirror of https://github.com/langgenius/dify.git, last synced 2026-05-06 02:18:08 +08:00.
feat: support openllm embedding (#1293)
This commit is contained in:
@ -2,11 +2,13 @@ import json
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
|
||||
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
@ -31,6 +33,8 @@ class OpenLLMProvider(BaseModelProvider):
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = OpenLLMModel
|
||||
elif model_type== ModelType.EMBEDDINGS:
|
||||
model_class = OpenLLMEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -69,14 +73,21 @@ class OpenLLMProvider(BaseModelProvider):
|
||||
'server_url': credentials['server_url']
|
||||
}
|
||||
|
||||
llm = OpenLLM(
|
||||
llm_kwargs={
|
||||
'max_new_tokens': 10
|
||||
},
|
||||
**credential_kwargs
|
||||
)
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
llm = OpenLLM(
|
||||
llm_kwargs={
|
||||
'max_new_tokens': 10
|
||||
},
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
llm("ping")
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
embedding = OpenLLMEmbeddings(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user