refactor(api): continue decoupling dify_graph from API concerns (#33580)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
-LAN-
2026-03-25 20:32:24 +08:00
committed by GitHub
parent b7b9b003c9
commit 56593f20b0
487 changed files with 17999 additions and 9186 deletions

View File

@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Tenant
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
@staticmethod
def _get_bound_model_instance(
    *,
    tenant_id: str,
    user_id: str | None,
    provider: str,
    model_type: ModelType,
    model: str,
):
    """Resolve a model instance for the given tenant/provider/model triple.

    First binds a ``ModelManager`` to the tenant (and the optional user),
    then asks that tenant-scoped manager for the concrete model instance.
    All parameters are keyword-only to keep call sites explicit.
    """
    # NOTE(review): tenant_id is passed both to for_tenant and to
    # get_model_instance — presumably required by the manager API; confirm.
    tenant_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
    return tenant_manager.get_model_instance(
        tenant_id=tenant_id,
        provider=provider,
        model_type=model_type,
        model=model,
    )
@classmethod
def invoke_llm(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
)
if isinstance(response, Generator):
@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm with structured output
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
)
@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke text embedding
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_text_embedding(
texts=payload.texts,
user=user_id,
)
response = model_instance.invoke_text_embedding(texts=payload.texts)
return response
@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke rerank
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
docs=payload.docs,
score_threshold=payload.score_threshold,
top_n=payload.top_n,
user=user_id,
)
return response
@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke tts
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_tts(
content_text=payload.content_text,
tenant_id=tenant.id,
voice=payload.voice,
user=user_id,
)
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
def handle() -> Generator[dict, None, None]:
for chunk in response:
@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke speech2text
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
temp.flush()
temp.seek(0)
response = model_instance.invoke_speech2text(
file=temp,
user=user_id,
)
response = model_instance.invoke_speech2text(file=temp)
return {
"result": response,
@ -252,36 +261,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke moderation
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_moderation(
text=payload.text,
user=user_id,
)
response = model_instance.invoke_moderation(text=payload.text)
return {
"result": response,
}
@classmethod
def get_system_model_max_tokens(cls, tenant_id: str) -> int:
def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int:
"""
get system model max tokens
"""
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id)
@classmethod
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
"""
get prompt tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
return ModelInvocationUtils.calculate_tokens(
tenant_id=tenant_id,
prompt_messages=prompt_messages,
user_id=user_id,
)
@classmethod
def invoke_system_model(
@ -299,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tool_type=ToolProviderType.PLUGIN,
tool_name="plugin",
prompt_messages=prompt_messages,
caller_user_id=user_id,
)
@classmethod
@ -306,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke summary
"""
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id)
content = payload.text
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
@ -325,6 +337,7 @@ Here is the extra instruction you need to follow:
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=content)],
user_id=user_id,
)
< max_tokens * 0.6
):
@ -337,6 +350,7 @@ Here is the extra instruction you need to follow:
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
UserPromptMessage(content=content),
],
user_id=user_id,
)
def summarize(content: str) -> str:
@ -394,6 +408,7 @@ Here is the extra instruction you need to follow:
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=result)],
user_id=user_id,
)
> max_tokens * 0.7
):

View File

@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
tool_type,
tenant_id,
provider,
tool_name,
tool_parameters,
user_id=user_id,
credential_id=credential_id,
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1