refactor(api): continue decoupling dify_graph from API concerns (#33580)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
-LAN-
2026-03-25 20:32:24 +08:00
committed by GitHub
parent b7b9b003c9
commit 56593f20b0
487 changed files with 17999 additions and 9186 deletions

View File

@ -3,9 +3,6 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.model_manager import ModelInstance
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
@ -14,7 +11,7 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from dify_graph.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult
@ -27,10 +24,11 @@ from dify_graph.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
from dify_graph.template_rendering import Jinja2TemplateRenderer
from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
@ -49,17 +47,22 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
class _PassthroughPromptMessageSerializer:
def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
_ = model_mode
return list(prompt_messages)
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_prompt_message_serializer: PromptMessageSerializerProtocol
_model_instance: PreparedLLMProtocol
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
credentials_provider: object | None = None,
model_factory: object | None = None,
model_instance: PreparedLLMProtocol,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
template_renderer: Jinja2TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
llm_file_saver: LLMFileSaver,
prompt_message_serializer: PromptMessageSerializerProtocol | None = None,
):
super().__init__(
id=id,
@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._credentials_provider = credentials_provider
self._model_factory = model_factory
_ = credentials_provider, model_factory, http_client
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver
self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer()
@classmethod
def version(cls):
@ -173,7 +170,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.require_dify_context().user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
@ -209,7 +205,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
category_id = category_id_result
process_data = {
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
"prompts": self._prompt_message_serializer.serialize(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
@ -251,7 +247,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
)
@property
def model_instance(self) -> ModelInstance:
def model_instance(self) -> PreparedLLMProtocol:
return self._model_instance
@classmethod
@ -289,7 +285,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
context: str | None,
) -> int:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
@ -299,7 +295,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
prompt_template=prompt_template,
sys_query="",
sys_files=[],
context=context,
context=context or "",
memory=None,
model_instance=model_instance,
stop=model_instance.stop,
@ -338,7 +334,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
model_mode = LLMMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:
@ -354,7 +350,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
if model_mode == LLMMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
@ -385,7 +381,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
elif model_mode == LLMMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
histories=memory_str,