refactor(api): continue decoupling dify_graph from API concerns (#33580)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
-LAN-
2026-03-25 20:32:24 +08:00
committed by GitHub
parent b7b9b003c9
commit 56593f20b0
487 changed files with 17999 additions and 9186 deletions

View File

@ -1,6 +1,6 @@
import binascii
from collections.abc import Generator, Sequence
from typing import IO
from typing import IO, Any
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
class PluginModelClient(BasePluginClient):
@staticmethod
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {"data": data}
if user_id is not None:
payload["user_id"] = user_id
return payload
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
"""
Fetch model providers for the given tenant.
@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
def get_model_schema(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/schema",
PluginModelSchemaEntity,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
) -> bool:
"""
validate the credentials of the provider
@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
def validate_model_credentials(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
def invoke_llm(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "llm",
"model": model,
@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
"stop": stop,
"stream": stream,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
def get_llm_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
"prompt_messages": prompt_messages,
"tools": tools,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
def invoke_text_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
"texts": texts,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
"documents": documents,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
def get_text_embedding_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"texts": texts,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
def invoke_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
def invoke_tts(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
"content_text": content_text,
"voice": voice,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
def get_tts_model_voices(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
"credentials": credentials,
"language": language,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
def invoke_speech_to_text(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "speech2text",
"model": model,
"credentials": credentials,
"file": binascii.hexlify(file.read()).decode(),
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
def invoke_moderation(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "moderation",
"model": model,
"credentials": credentials,
"text": text,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,