Mirror of https://github.com/langgenius/dify.git (synced 2026-04-26 13:45:57 +08:00)
refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
@@ -3,9 +3,6 @@ import re
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
-from core.model_manager import ModelInstance
-from core.prompt.simple_prompt_transform import ModelMode
-from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from dify_graph.entities import GraphInitParams
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import (
@@ -14,7 +11,7 @@ from dify_graph.enums import (
     WorkflowNodeExecutionMetadataKey,
     WorkflowNodeExecutionStatus,
 )
-from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
+from dify_graph.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole
 from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult
@@ -27,10 +24,11 @@ from dify_graph.nodes.llm import (
     LLMNodeCompletionModelPromptTemplate,
     llm_utils,
 )
-from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
+from dify_graph.nodes.llm.file_saver import LLMFileSaver
+from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
 from dify_graph.nodes.protocols import HttpClientProtocol
-from libs.json_in_md_parser import parse_and_check_json_markdown
+from dify_graph.template_rendering import Jinja2TemplateRenderer
+from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown
 
 from .entities import QuestionClassifierNodeData
 from .exc import InvalidModelTypeError
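Note: the import changes above swap concrete API-layer collaborators (ModelInstance, CredentialsProvider, ModelFactory, TemplateRenderer) for protocols owned by dify_graph. The protocol definitions themselves are not part of this diff; a minimal sketch of the shapes this file actually relies on, inferred only from the call sites below (model_instance.stop and serializer.serialize(...)), could look like:

    # Sketch only: inferred from call sites in this file, not the real
    # definitions in dify_graph.nodes.llm.runtime_protocols.
    from collections.abc import Sequence
    from typing import Any, Protocol

    class PreparedLLMProtocol(Protocol):
        @property
        def stop(self) -> Sequence[str] | None:  # read when building prompts
            ...

    class PromptMessageSerializerProtocol(Protocol):
        def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
            ...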
@@ -49,17 +47,22 @@ if TYPE_CHECKING:
     from dify_graph.runtime import GraphRuntimeState
 
 
+class _PassthroughPromptMessageSerializer:
+    def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
+        _ = model_mode
+        return list(prompt_messages)
+
+
 class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
     node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
     execution_type = NodeExecutionType.BRANCH
 
     _file_outputs: list["File"]
     _llm_file_saver: LLMFileSaver
-    _credentials_provider: "CredentialsProvider"
-    _model_factory: "ModelFactory"
-    _model_instance: ModelInstance
+    _prompt_message_serializer: PromptMessageSerializerProtocol
+    _model_instance: PreparedLLMProtocol
     _memory: PromptMessageMemory | None
-    _template_renderer: TemplateRenderer
+    _template_renderer: Jinja2TemplateRenderer
 
     def __init__(
         self,
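Note: _PassthroughPromptMessageSerializer declares no base class. Protocols match structurally, so the default serializer satisfies PromptMessageSerializerProtocol purely by signature, without dify_graph importing any API-layer serialization code. A small usage sketch of the passthrough behavior defined above:

    # Assumes the _PassthroughPromptMessageSerializer class from the hunk
    # above; it just materializes the messages unchanged.
    serializer = _PassthroughPromptMessageSerializer()
    messages = ["msg-a", "msg-b"]
    assert serializer.serialize(model_mode="chat", prompt_messages=messages) == ["msg-a", "msg-b"]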
@@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
         *,
-        credentials_provider: "CredentialsProvider",
-        model_factory: "ModelFactory",
-        model_instance: ModelInstance,
+        credentials_provider: object | None = None,
+        model_factory: object | None = None,
+        model_instance: PreparedLLMProtocol,
         http_client: HttpClientProtocol,
-        template_renderer: TemplateRenderer,
+        template_renderer: Jinja2TemplateRenderer,
         memory: PromptMessageMemory | None = None,
-        llm_file_saver: LLMFileSaver | None = None,
+        llm_file_saver: LLMFileSaver,
+        prompt_message_serializer: PromptMessageSerializerProtocol | None = None,
     ):
         super().__init__(
             id=id,
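Note: the signature keeps credentials_provider and model_factory as object | None = None even though the node no longer uses them, so existing keyword call sites keep working while new ones can simply drop the arguments (all parameters after * are keyword-only, so defaults may sit before required arguments). A minimal, self-contained illustration of this keep-but-ignore pattern; the names here are illustrative, not from this commit:

    # Illustrative sketch of the transitional-signature pattern above.
    class Widget:
        def __init__(self, *, renderer: str, legacy_helper: object | None = None):
            _ = legacy_helper  # accepted for compatibility, deliberately unused
            self._renderer = renderer

    Widget(renderer="jinja2", legacy_helper=object())  # old call site still works
    Widget(renderer="jinja2")                          # new call site drops it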
@@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         # LLM file outputs, used for MultiModal outputs.
         self._file_outputs = []
 
-        self._credentials_provider = credentials_provider
-        self._model_factory = model_factory
+        _ = credentials_provider, model_factory, http_client
         self._model_instance = model_instance
         self._memory = memory
         self._template_renderer = template_renderer
 
-        if llm_file_saver is None:
-            dify_ctx = self.require_dify_context()
-            llm_file_saver = FileSaverImpl(
-                user_id=dify_ctx.user_id,
-                tenant_id=dify_ctx.tenant_id,
-                http_client=http_client,
-            )
         self._llm_file_saver = llm_file_saver
+        self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer()
 
     @classmethod
     def version(cls):
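Note: three things happen in the body above. `_ = credentials_provider, model_factory, http_client` marks the transitional parameters as intentionally unused (a zero-cost way to satisfy unused-argument linters); the FileSaverImpl fallback is gone, so the caller must now inject an LLMFileSaver instead of the node building one from Dify context; and the serializer defaults to the passthrough, so process_data output only changes when a caller injects a richer implementation, for example:

    # Hypothetical serializer injection; RedactingSerializer is illustrative
    # and only has to match PromptMessageSerializerProtocol structurally.
    class RedactingSerializer:
        def serialize(self, *, model_mode, prompt_messages):
            # Example policy: persist roles only, never message contents.
            return [{"role": str(getattr(m, "role", ""))} for m in prompt_messages]

    # Wired in at construction:
    #   QuestionClassifierNode(..., prompt_message_serializer=RedactingSerializer())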
@@ -173,7 +170,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
-            user_id=self.require_dify_context().user_id,
             structured_output_enabled=False,
             structured_output=None,
             file_saver=self._llm_file_saver,
@@ -209,7 +205,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         category_id = category_id_result
         process_data = {
             "model_mode": node_data.model.mode,
-            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+            "prompts": self._prompt_message_serializer.serialize(
                 model_mode=node_data.model.mode, prompt_messages=prompt_messages
             ),
             "usage": jsonable_encoder(usage),
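Note: this is where the decoupling pays off. The old code called PromptMessageUtil.prompt_messages_to_prompt_for_saving directly, pulling an API concern into the graph package. The API layer can preserve the previous process_data format by injecting an adapter around that same utility; a sketch, assuming the core.prompt import path remains valid on the API side:

    # API-side adapter sketch; wraps the utility the old code called inline.
    from core.prompt.utils.prompt_message_util import PromptMessageUtil

    class SavingPromptSerializer:
        def serialize(self, *, model_mode, prompt_messages):
            return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=model_mode, prompt_messages=prompt_messages
            )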
@@ -251,7 +247,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         )
 
     @property
-    def model_instance(self) -> ModelInstance:
+    def model_instance(self) -> PreparedLLMProtocol:
         return self._model_instance
 
     @classmethod
@@ -289,7 +285,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         context: str | None,
     ) -> int:
         model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
@@ -299,7 +295,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             prompt_template=prompt_template,
             sys_query="",
             sys_files=[],
-            context=context,
+            context=context or "",
             memory=None,
             model_instance=model_instance,
             stop=model_instance.stop,
@@ -338,7 +334,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         memory: PromptMessageMemory | None,
         max_token_limit: int = 2000,
     ):
-        model_mode = ModelMode(node_data.model.mode)
+        model_mode = LLMMode(node_data.model.mode)
         classes = node_data.classes
         categories = []
        for class_ in classes:
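Note: swapping ModelMode (from core.prompt.simple_prompt_transform) for LLMMode (from dify_graph.model_runtime.entities) removes another core.* dependency from this file. The change is only safe if LLMMode accepts the same raw strings node_data.model.mode already carries; a sketch of that assumption:

    # Assumption: LLMMode covers the same raw values ModelMode did.
    # LLMModeSketch is an illustrative stand-in, not the real enum.
    from enum import Enum

    class LLMModeSketch(str, Enum):
        CHAT = "chat"
        COMPLETION = "completion"

    assert LLMModeSketch("chat") is LLMModeSketch.CHAT  # mirrors LLMMode(node_data.model.mode)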
@@ -354,7 +350,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
         )
         prompt_messages: list[LLMNodeChatModelMessage] = []
-        if model_mode == ModelMode.CHAT:
+        if model_mode == LLMMode.CHAT:
             system_prompt_messages = LLMNodeChatModelMessage(
                 role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
             )
@@ -385,7 +381,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             )
             prompt_messages.append(user_prompt_message_3)
             return prompt_messages
-        elif model_mode == ModelMode.COMPLETION:
+        elif model_mode == LLMMode.COMPLETION:
             return LLMNodeCompletionModelPromptTemplate(
                 text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
                     histories=memory_str,