mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import IO
|
||||
from typing import IO, Any
|
||||
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
|
||||
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class PluginModelClient(BasePluginClient):
|
||||
@staticmethod
|
||||
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"data": data}
|
||||
if user_id is not None:
|
||||
payload["user_id"] = user_id
|
||||
return payload
|
||||
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_model_schema(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/schema",
|
||||
PluginModelSchemaEntity,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
data=self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
|
||||
return None
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
|
||||
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
data=self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
data=self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_llm(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
|
||||
type_=LLMResultChunk,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "llm",
|
||||
"model": model,
|
||||
@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
|
||||
type_=PluginLLMNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"prompt_messages": prompt_messages,
|
||||
"tools": tools,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||
type_=EmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"texts": texts,
|
||||
"input_type": input_type,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
|
||||
type_=EmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"documents": documents,
|
||||
"input_type": input_type,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
|
||||
type_=PluginTextEmbeddingNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_tts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"content_text": content_text,
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type_=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"language": language,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "speech2text",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"file": binascii.hexlify(file.read()).decode(),
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_moderation(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
|
||||
type_=PluginBasicBooleanResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "moderation",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
|
||||
Reference in New Issue
Block a user