refactor: move workflow package to dify_graph (#32844)

-LAN-
2026-03-02 18:42:30 +08:00
committed by GitHub
parent 9c33923985
commit c917838f9c
613 changed files with 2008 additions and 2012 deletions

dify_graph/nodes/question_classifier/__init__.py

@@ -0,0 +1,4 @@
from .entities import QuestionClassifierNodeData
from .question_classifier_node import QuestionClassifierNode
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]
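The package `__init__` re-exports the node class and its data model so consumers can import from the package root. A minimal usage sketch, assuming the post-refactor `dify_graph` package path introduced by this commit:

```python
# Hypothetical consumer code; the import path assumes this commit's layout.
from dify_graph.nodes.question_classifier import (
    QuestionClassifierNode,
    QuestionClassifierNodeData,
)

print(QuestionClassifierNode.version())  # -> "1"
```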

dify_graph/nodes/question_classifier/entities.py

@@ -0,0 +1,28 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.llm import ModelConfig, VisionConfig
class ClassConfig(BaseModel):
id: str
name: str
class QuestionClassifierNodeData(BaseNodeData):
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]
instruction: str | None = None
memory: MemoryConfig | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property
def structured_output_enabled(self) -> bool:
# NOTE(QuantumGhost): Temporary workaround for issue #20725
# (https://github.com/langgenius/dify/issues/20725).
#
# The proper fix would be to make `QuestionClassifierNode` inherit
# from `BaseNode` instead of `LLMNode`.
return False
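Since `QuestionClassifierNodeData` is a pydantic model, raw node configs can be validated into typed objects: `instruction` and `memory` stay optional and `vision` is filled in by its `default_factory`. A minimal sketch with hypothetical values (the exact required fields of `BaseNodeData` and `ModelConfig` follow their definitions elsewhere in the package):

```python
# All field values here are made up for illustration.
data = QuestionClassifierNodeData.model_validate(
    {
        "title": "Classify query",  # assumed BaseNodeData field
        "query_variable_selector": ["sys", "query"],
        "model": {
            "provider": "openai",  # assumed ModelConfig shape
            "name": "gpt-4o-mini",
            "mode": "chat",
            "completion_params": {},
        },
        "classes": [
            {"id": "c1", "name": "Billing"},
            {"id": "c2", "name": "Support"},
        ],
    }
)
assert data.instruction is None and data.memory is None
assert data.structured_output_enabled is False
```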

dify_graph/nodes/question_classifier/exc.py

@@ -0,0 +1,6 @@
class QuestionClassifierNodeError(ValueError):
"""Base class for QuestionClassifierNode errors."""
class InvalidModelTypeError(QuestionClassifierNodeError):
"""Raised when the model is not a Large Language Model."""

dify_graph/nodes/question_classifier/question_classifier_node.py

@@ -0,0 +1,388 @@
import json
import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
from dify_graph.enums import (
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
from .template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
QUESTION_CLASSIFIER_SYSTEM_PROMPT,
QUESTION_CLASSIFIER_USER_PROMPT_1,
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
if TYPE_CHECKING:
from dify_graph.file.models import File
from dify_graph.runtime import GraphRuntimeState
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = NodeType.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# LLM file outputs, used for multimodal outputs.
self._file_outputs = []
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
@classmethod
def version(cls):
return "1"
def _run(self):
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
query = variable.value if variable else None
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
else []
)
# fetch prompt messages
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query or "",
model_instance=model_instance,
context="",
)
prompt_template = self._get_prompt_template(
node_data=node_data,
query=query or "",
memory=memory,
max_token_limit=rest_token,
)
# Some models (e.g. Gemma, Mistral) enforce strict role alternation (user/assistant/user/assistant...).
# If both self._get_prompt_template and LLMNode.fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing a model error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
model_instance=model_instance,
stop=model_instance.stop,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
)
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
try:
# handle invoke result
generator = LLMNode.invoke_llm(
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
rendered_classes = [
c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes
]
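# Default to the first class; replaced below only if the model returns a valid category_id.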
category_name = rendered_classes[0].name
category_id = rendered_classes[0].id
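# Reasoning models may wrap chain-of-thought in <think>...</think> tags;
# strip them so that only the JSON answer is parsed below.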
if "<think>" in result_text:
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
category_id_result = result_text_json["category_id"]
classes = rendered_classes
classes_map = {class_.id: class_.name for class_ in classes}
category_ids = [_class.id for _class in classes]
if category_id_result in category_ids:
category_name = classes_map[category_id_result]
category_id = category_id_result
process_data = {
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_instance.provider,
"model_name": model_instance.model_name,
}
outputs = {
"class_name": category_name,
"class_id": category_id,
"usage": jsonable_encoder(usage),
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=process_data,
outputs=outputs,
edge_source_handle=category_id,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
except ValueError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_selectors: list[VariableSelector] = []
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
return variable_mapping
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters (not used in this implementation).
:return:
"""
# filters parameter is not used in this node type
return {"type": "question-classifier", "config": {"instructions": ""}}
def _calculate_rest_token(
self,
node_data: QuestionClassifierNodeData,
query: str,
model_instance: ModelInstance,
context: str | None,
) -> int:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages, _ = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
sys_files=[],
context=context,
memory=None,
model_instance=model_instance,
stop=model_instance.stop,
memory_config=node_data.memory,
vision_enabled=False,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
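# Fall back to a 2000-token budget when the model schema does not expose a context size.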
rest_tokens = 2000
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_instance.parameters.get(parameter_rule.name)
or model_instance.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _get_prompt_template(
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:
category = {"category_id": class_.id, "category_name": class_.name}
categories.append(category)
instruction = node_data.instruction or ""
input_text = query
memory_str = ""
if memory:
memory_str = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
input_text=input_text,
categories=json.dumps(categories, ensure_ascii=False),
classification_instructions=instruction,
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
histories=memory_str,
input_text=input_text,
categories=json.dumps(categories, ensure_ascii=False),
classification_instructions=instruction,
)
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.")
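For reference, `_extract_variable_selector_to_variable_mapping` namespaces every selector under the node id, which is how the engine learns which upstream values the node reads. A sketch of the expected shape, using a hypothetical node id and config:

```python
# Hypothetical node config; only query_variable_selector matters here.
mapping = QuestionClassifierNode._extract_variable_selector_to_variable_mapping(
    graph_config={},
    node_id="qc_1",
    node_data={
        "title": "Classify query",
        "query_variable_selector": ["sys", "query"],
        "model": {"provider": "openai", "name": "gpt-4o-mini", "mode": "chat", "completion_params": {}},
        "classes": [{"id": "c1", "name": "Billing"}],
    },
)
# -> {"qc_1.query": ["sys", "query"]}
```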

dify_graph/nodes/question_classifier/template_prompts.py

@@ -0,0 +1,76 @@
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign ONLY one category to the input text, and only one category may be returned in the output. Additionally, you need to extract the keywords from the text that are related to the classification.
### Format
The input text is in the variable input_text. Categories are specified as a category list with the two fields category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON object in your response.
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
""" # noqa: E501
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
"categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}
""" # noqa: E501
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
```json
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
"category_id": "f5660049-284f-41a7-b301-fd24176a711c",
"category_name": "Customer Service"}
```
"""
QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"],
"categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}],
"classification_instructions": []}
""" # noqa: E501
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
```json
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f",
"category_name": "Experience"}
```
"""
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
{{"input_text": ["{input_text}"],
"categories": {categories},
"classification_instructions": ["{classification_instructions}"]}}
"""
QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign ONLY one category to the input text, and only one category may be returned in the output. Additionally, you need to extract the keywords from the text that are related to the classification.
### Format
The input text is in the variable input_text. Categories are specified as a category list with the two fields category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON object in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
Assistant:{{"keywords": ["bad service", "slow", "food"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
</example>
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### User Input
{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}}
### Assistant Output
""" # noqa: E501