Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -1,3 +1,3 @@
from .enums import NodeType
from core.workflow.enums import NodeType
__all__ = ["NodeType"]

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
@ -9,16 +9,12 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
@ -29,17 +25,25 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
@ -57,19 +61,23 @@ from .exc import (
ToolFileNotFoundError,
)
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(BaseNode):
class AgentNode(Node):
"""
Agent Node
"""
_node_type = NodeType.AGENT
node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -78,7 +86,7 @@ class AgentNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -91,7 +99,9 @@ class AgentNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
@ -99,12 +109,12 @@ class AgentNode(BaseNode):
agent_strategy_name=self._node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
)
),
)
return
@ -139,8 +149,8 @@ class AgentNode(BaseNode):
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
@ -158,16 +168,16 @@ class AgentNode(BaseNode):
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.type_,
node_id=self.node_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
@ -181,7 +191,7 @@ class AgentNode(BaseNode):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: PluginAgentStrategy,
strategy: "PluginAgentStrategy",
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -320,7 +330,7 @@ class AgentNode(BaseNode):
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
@ -339,10 +349,11 @@ class AgentNode(BaseNode):
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
) -> "InvokeCredentials":
"""
Generate credentials based on the given agent parameters.
"""
from core.plugin.entities.request import InvokeCredentials
credentials = InvokeCredentials()
@ -388,6 +399,8 @@ class AgentNode(BaseNode):
Get agent strategy icon
:return:
"""
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
try:
@ -401,7 +414,7 @@ class AgentNode(BaseNode):
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
@ -450,7 +463,9 @@ class AgentNode(BaseNode):
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _filter_mcp_type_tool(
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
@ -473,11 +488,13 @@ class AgentNode(BaseNode):
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
@ -491,7 +508,7 @@ class AgentNode(BaseNode):
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
@ -553,7 +570,11 @@ class AgentNode(BaseNode):
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
@ -564,13 +585,17 @@ class AgentNode(BaseNode):
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None:
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
@ -587,8 +612,10 @@ class AgentNode(BaseNode):
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
@ -639,7 +666,7 @@ class AgentNode(BaseNode):
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
@ -652,7 +679,7 @@ class AgentNode(BaseNode):
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
if log.message_id == agent_log.message_id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
@ -673,7 +700,7 @@ class AgentNode(BaseNode):
for log in agent_logs:
json_output.append(
{
"id": log.id,
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
@ -689,8 +716,24 @@ class AgentNode(BaseNode):
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,

View File

@ -1,4 +1,4 @@
from enum import Enum, StrEnum
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
class AgentOldVersionModelFeatures(StrEnum):
@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

View File

@ -1,6 +1,3 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
def __init__(self, strategy_name: str, provider_name: str | None = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
def __init__(self, message: str, parameter_name: str | None = None):
self.parameter_name = parameter_name
super().__init__(message)
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
def __init__(self, message: str, variable_name: str | None = None):
self.variable_name = variable_name
super().__init__(message)
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
def __init__(self, message: str, file_id: str | None = None):
self.file_id = file_id
super().__init__(message)
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
def __init__(self, message: str, conversation_id: str | None = None):
self.conversation_id = conversation_id
super().__init__(message)
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
variable_name: str | None = None,
expected_type: str | None = None,
actual_type: str | None = None,
):
self.variable_name = variable_name
self.expected_type = expected_type

View File

@ -1,4 +0,0 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]

View File

@ -1,31 +1,26 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
class AnswerNode(Node):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -34,7 +29,7 @@ class AnswerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -48,35 +43,29 @@ class AnswerNode(BaseNode):
return "1"
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = ""
files = []
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
variable = self.graph_runtime_state.variable_pool.get(value_selector)
if variable:
if isinstance(variable, FileSegment):
files.append(variable.value)
elif isinstance(variable, ArrayFileSegment):
files.extend(variable.value)
answer += variable.markdown
else:
part = cast(TextGenerateRouteChunk, part)
answer += part.text
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
files = self._extract_files_from_segments(segments.value)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
)
def _extract_files_from_segments(self, segments: Sequence[Segment]):
"""Extract all files from segments containing FileSegment or ArrayFileSegment instances.
FileSegment contains a single file, while ArrayFileSegment contains multiple files.
This method flattens all files into a single list.
"""
files = []
for segment in segments:
if isinstance(segment, FileSegment):
# Single file - wrap in list for consistency
files.append(segment.value)
elif isinstance(segment, ArrayFileSegment):
# Multiple files - extend the list
files.extend(segment.value)
return files
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -96,3 +85,12 @@ class AnswerNode(BaseNode):
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
def get_streaming_template(self) -> Template:
"""
Get the template for streaming.
Returns:
Template instance for this Answer node
"""
return Template.from_answer_template(self._node_data.answer)

View File

@ -1,174 +0,0 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping,
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
generate_routes: list[GenerateRouteChunk] = []
for part in template.split("Ω"):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
else:
generate_routes.append(TextGenerateRouteChunk(text=part))
return generate_routes
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace("{{", "").replace("}}", "")
return part.startswith("{{") and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(
cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
answer_dependencies: dict[str, list[str]] = {}
for answer_node_id in answer_node_ids:
if answer_dependencies.get(answer_node_id) is None:
answer_dependencies[answer_node_id] = []
cls._recursive_fetch_answer_dependencies(
current_node_id=answer_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies,
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(
cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (
source_node_type
in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.LOOP,
NodeType.VARIABLE_ASSIGNER,
}
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies,
)

View File

@ -1,199 +0,0 @@
import logging
from collections.abc import Generator
from typing import cast
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id in self.generate_routes.answer_generate_route:
self.route_position[answer_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id or event.in_loop_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_answer_node_ids
)
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
else:
yield event
def reset(self) -> None:
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id in self.route_position:
# all depends on answer node id not in rest node ids
if event.route_node_state.node_id != answer_node_id and (
answer_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
)
):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get(value_selector)
if value is None:
break
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=list(value_selector),
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[answer_node_id] += 1
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
stream_out_answer_node_ids = []
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies
edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
continue
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids

View File

@ -1,109 +0,0 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
    """
    Base class for graph-event stream processors.

    Tracks the set of nodes that have not finished yet (``rest_node_ids``)
    and prunes branches that become unreachable once a branch decision is
    made by a finished node.
    """

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        # Nodes not yet finished; shrinks as nodes complete or branches are pruned.
        self.rest_node_ids = graph.node_ids.copy()

    @abstractmethod
    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
        """Transform the engine event stream; implemented by subclasses."""
        raise NotImplementedError

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
        """
        Drop the finished node from ``rest_node_ids`` and, when the node made a
        branch decision (``edge_source_handle``), prune every node reachable
        only through the branches that were not taken.
        """
        finished_node_id = event.route_node_state.node_id
        if finished_node_id not in self.rest_node_ids:
            return

        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids: list[str] = []
            unreachable_first_node_ids: list[str] = []
            if finished_node_id not in self.graph.edge_mapping:
                logger.warning("node %s has no edge mapping", finished_node_id)
                return
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (
                    edge.run_condition
                    and edge.run_condition.branch_identify
                    and run_result.edge_source_handle == edge.run_condition.branch_identify
                ):
                    # remove unreachable nodes
                    # FIXME: because of the code branch can combine directly, so for answer node
                    # we remove the node maybe shortcut the answer node, so comment this code for now
                    # there is not effect on the answer node and the workflow, when we have a better solution
                    # we can open this code. Issues: #11542 #9560 #10638 #10564
                    # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
                    # if "answer" in ids:
                    #     continue
                    # else:
                    #     reachable_node_ids.extend(ids)
                    # The branch_identify parameter is added to ensure that
                    # only nodes in the correct logical branch are included.
                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
                    reachable_node_ids.extend(ids)
                else:
                    # if the condition edge in parallel, and the target node is not in parallel, we should not remove it
                    # Issues: #13626
                    if (
                        finished_node_id in self.graph.node_parallel_mapping
                        and edge.target_node_id not in self.graph.node_parallel_mapping
                    ):
                        continue
                    unreachable_first_node_ids.append(edge.target_node_id)
            # A node reachable from the taken branch must never be pruned even
            # if it is also the first node of an untaken branch.
            unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
        """
        Recursively collect node ids reachable from ``node_id`` following only
        unconditioned edges or edges matching ``branch_identify``.
        """
        # Re-register the entry node: it stays runnable even if previously removed.
        if node_id not in self.rest_node_ids:
            self.rest_node_ids.append(node_id)
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            # Only follow edges that match the branch_identify or have no run_condition
            if edge.run_condition and edge.run_condition.branch_identify:
                if not branch_identify or edge.run_condition.branch_identify != branch_identify:
                    continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        remove target node ids until merge

        Walks the untaken branch from ``node_id``, removing nodes from
        ``rest_node_ids`` until it reaches a node that is also reachable from
        the taken branch (a merge point).
        """
        if node_id not in self.rest_node_ids:
            return
        if node_id in reachable_node_ids:
            return

        self.rest_node_ids.remove(node_id)
        # NOTE(review): this re-adds reachable nodes that are missing from
        # rest_node_ids — presumably to restore merge-point nodes removed by an
        # earlier pruning pass; confirm against callers.
        self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue
            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto
from pydantic import BaseModel, Field
@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
Generate Route Chunk.
"""
class ChunkType(Enum):
VAR = "var"
TEXT = "text"
class ChunkType(StrEnum):
VAR = auto()
TEXT = auto()
type: ChunkType = Field(..., description="generate route chunk type")

View File

@ -1,11 +1,9 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .node import BaseNode
__all__ = [
"BaseIterationNodeData",
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNode",
"BaseNodeData",
]

View File

@ -1,12 +1,37 @@
import json
from abc import ABC
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Optional, Union
from typing import Any, Union
from pydantic import BaseModel, model_validator
from core.workflow.nodes.base.exc import DefaultValueTypeError
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.enums import ErrorStrategy
from .exc import DefaultValueTypeError
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
    """node retry config"""

    max_retries: int = 0  # max retry times
    retry_interval: int = 0  # retry interval in milliseconds
    retry_enabled: bool = False  # whether retry is enabled

    @property
    def retry_interval_seconds(self) -> float:
        """Retry interval converted from milliseconds to seconds."""
        return self.retry_interval / 1000
class VariableSelector(BaseModel):
    """
    Variable Selector.

    Pairs a template variable key with the path used to resolve its value.
    """

    # Template variable key, e.g. "#node_id.output#".
    variable: str
    # Resolution path split on dots, e.g. ["node_id", "output"].
    value_selector: Sequence[str]
class DefaultValueType(StrEnum):
@ -19,16 +44,13 @@ class DefaultValueType(StrEnum):
ARRAY_FILES = "array[file]"
NumberType = Union[int, float]
class DefaultValue(BaseModel):
value: Any
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str) -> Any:
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
@ -51,9 +73,6 @@ class DefaultValue(BaseModel):
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
if self.type is None:
raise DefaultValueTypeError("type field is required")
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {
@ -61,7 +80,7 @@ class DefaultValue(BaseModel):
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": NumberType,
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
@ -70,7 +89,7 @@ class DefaultValue(BaseModel):
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": NumberType,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
@ -107,24 +126,12 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
desc: str | None = None
version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
@property
@ -135,7 +142,7 @@ class BaseNodeData(ABC, BaseModel):
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseIterationState(BaseModel):
@ -150,7 +157,7 @@ class BaseIterationState(BaseModel):
class BaseLoopNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseLoopState(BaseModel):

View File

@ -1,81 +1,175 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from functools import singledispatchmethod
from typing import Any, ClassVar
from uuid import uuid4
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import (
AgentLogEvent,
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
from .entities import BaseNodeData, RetryConfig
logger = logging.getLogger(__name__)
class BaseNode:
_node_type: ClassVar[NodeType]
class Node:
node_type: ClassVar["NodeType"]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.user_from = UserFrom(graph_init_params.user_from)
self.invoke_from = InvokeFrom(graph_init_params.invoke_from)
self.workflow_call_depth = graph_init_params.call_depth
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
self.state: NodeState = NodeState.UNKNOWN # node execution state
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
self.node_id = node_id
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
Run node
:return:
"""
raise NotImplementedError
def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
def run(self) -> Generator[GraphNodeEventBase, None, None]:
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
self._start_at = naive_utc_now()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.title,
in_iteration_id=None,
start_at=self._start_at,
)
# === FIXME(-LAN-): Needs to refactor.
from core.workflow.nodes.tool.tool_node import ToolNode
if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
if isinstance(self, DatasourceNode):
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from typing import cast
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.agent.entities import AgentNodeData
if isinstance(self, AgentNode):
start_event.agent_strategy = AgentNodeStrategyInit(
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
icon=self.agent_strategy_icon,
)
# ===
yield start_event
try:
result = self._run()
# Handle NodeRunResult
if isinstance(result, NodeRunResult):
yield self._convert_node_run_result_to_graph_node_event(result)
return
# Handle event stream
for event in result:
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self._node_execution_id
yield event
else:
yield event
except Exception as e:
logger.exception("Node %s failed to run", self.node_id)
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type="WorkflowNodeError",
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
else:
yield from result
yield NodeRunFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=result,
error=str(e),
)
@classmethod
def extract_variable_selector_to_variable_mapping(
@ -140,13 +234,21 @@ class BaseNode:
) -> Mapping[str, Sequence[str]]:
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {}
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this node blocks the output of specific variables.
@property
def type_(self) -> NodeType:
return self._node_type
This method is used to determine if a node must complete execution before
the specified variables can be used in streaming output.
:param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str'))
:return: True if this node blocks output of any of the specified variables, False otherwise
"""
return False
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {}
@classmethod
@abstractmethod
@ -158,10 +260,6 @@ class BaseNode:
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property
def continue_on_error(self) -> bool:
return False
@property
def retry(self) -> bool:
return False
@ -170,7 +268,7 @@ class BaseNode:
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
...
@ -185,7 +283,7 @@ class BaseNode:
...
@abstractmethod
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
"""Get the node description."""
...
@ -201,7 +299,7 @@ class BaseNode:
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional[ErrorStrategy]:
def error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@ -216,7 +314,7 @@ class BaseNode:
return self._get_title()
@property
def description(self) -> Optional[str]:
def description(self) -> str | None:
"""Get the node description."""
return self._get_description()
@ -224,3 +322,198 @@ class BaseNode:
def default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()
    def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
        """
        Map a terminal NodeRunResult to the corresponding graph-level event.

        Only SUCCEEDED and FAILED results are expected here; any other status
        is a programming error and raises.
        """
        match result.status:
            case WorkflowNodeExecutionStatus.FAILED:
                return NodeRunFailedEvent(
                    id=self._node_execution_id,
                    # NOTE(review): uses self.id while the _dispatch handlers use
                    # self._node_id — confirm this difference is intentional.
                    node_id=self.id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=result,
                    error=result.error,
                )
            case WorkflowNodeExecutionStatus.SUCCEEDED:
                return NodeRunSucceededEvent(
                    id=self._node_execution_id,
                    node_id=self.id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=result,
                )
            case _:
                raise Exception(f"result status {result.status} not supported")
    @singledispatchmethod
    def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
        """
        Convert a node-level event into its graph-level counterpart,
        dispatching on the concrete event type. Unregistered types raise.
        """
        raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")

    # Stream chunk: forwarded with this node's execution identity attached.
    @_dispatch.register
    def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
        return NodeRunStreamChunkEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            selector=event.selector,
            chunk=event.chunk,
            is_final=event.is_final,
        )

    # Stream completion: terminal event, mapped to success or failure.
    @_dispatch.register
    def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
        match event.node_run_result.status:
            case WorkflowNodeExecutionStatus.SUCCEEDED:
                return NodeRunSucceededEvent(
                    id=self._node_execution_id,
                    node_id=self._node_id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=event.node_run_result,
                )
            case WorkflowNodeExecutionStatus.FAILED:
                return NodeRunFailedEvent(
                    id=self._node_execution_id,
                    node_id=self._node_id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=event.node_run_result,
                    error=event.node_run_result.error,
                )
            case _:
                raise NotImplementedError(
                    f"Node {self._node_id} does not support status {event.node_run_result.status}"
                )

    # Agent log: passed through with the event's own execution id preserved.
    @_dispatch.register
    def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
        return NodeRunAgentLogEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            message_id=event.message_id,
            label=event.label,
            node_execution_id=event.node_execution_id,
            parent_id=event.parent_id,
            error=event.error,
            status=event.status,
            data=event.data,
            metadata=event.metadata,
        )

    # Loop lifecycle events: enriched with the node title from node data.
    @_dispatch.register
    def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
        return NodeRunLoopStartedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            metadata=event.metadata,
            predecessor_node_id=event.predecessor_node_id,
        )

    @_dispatch.register
    def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
        return NodeRunLoopNextEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            index=event.index,
            pre_loop_output=event.pre_loop_output,
        )

    @_dispatch.register
    def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
        return NodeRunLoopSucceededEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
        )

    @_dispatch.register
    def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
        return NodeRunLoopFailedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
            error=event.error,
        )

    # Iteration lifecycle events: mirror the loop handlers above.
    @_dispatch.register
    def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
        return NodeRunIterationStartedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            metadata=event.metadata,
            predecessor_node_id=event.predecessor_node_id,
        )

    @_dispatch.register
    def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
        return NodeRunIterationNextEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            index=event.index,
            pre_iteration_output=event.pre_iteration_output,
        )

    @_dispatch.register
    def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
        return NodeRunIterationSucceededEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
        )

    @_dispatch.register
    def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
        return NodeRunIterationFailedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
            error=event.error,
        )

    # Retriever resource: stamped with this node's declared version.
    @_dispatch.register
    def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
        return NodeRunRetrieverResourceEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            retriever_resources=event.retriever_resources,
            context=event.context,
            node_version=self.version(),
        )

View File

@ -0,0 +1,148 @@
"""Template structures for Response nodes (Answer and End).
This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Union
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
@dataclass(frozen=True)
class TemplateSegment(ABC):
"""Base class for template segments."""
@abstractmethod
def __str__(self) -> str:
"""String representation of the segment."""
pass
@dataclass(frozen=True)
class TextSegment(TemplateSegment):
"""A text segment in a template."""
text: str
def __str__(self) -> str:
return self.text
@dataclass(frozen=True)
class VariableSegment(TemplateSegment):
"""A variable reference segment in a template."""
selector: Sequence[str]
variable_name: str | None = None # Optional variable name for End nodes
def __str__(self) -> str:
return "{{#" + ".".join(self.selector) + "#}}"
# Type alias for segments
TemplateSegmentUnion = Union[TextSegment, VariableSegment]
@dataclass(frozen=True)
class Template:
"""Unified template structure for Response nodes.
Similar to SegmentGroup, but represents the template structure
without variable values - only marking variable selectors.
"""
segments: list[TemplateSegmentUnion]
@classmethod
def from_answer_template(cls, template_str: str) -> "Template":
"""Create a Template from an Answer node template string.
Example:
"Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])]
Args:
template_str: The answer template string
Returns:
Template instance
"""
parser = VariableTemplateParser(template_str)
segments: list[TemplateSegmentUnion] = []
# Extract variable selectors to find all variables
variable_selectors = parser.extract_variable_selectors()
var_map = {var.variable: var.value_selector for var in variable_selectors}
# Parse template to get ordered segments
# We need to split the template by variable placeholders while preserving order
import re
# Create a regex pattern that matches variable placeholders
pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}"
# Split template while keeping the delimiters (variable placeholders)
parts = re.split(pattern, template_str)
for i, part in enumerate(parts):
if not part:
continue
# Check if this part is a variable reference (odd indices after split)
if i % 2 == 1: # Odd indices are variable keys
# Remove the # symbols from the variable key
var_key = part
if var_key in var_map:
segments.append(VariableSegment(selector=list(var_map[var_key])))
else:
# This shouldn't happen with valid templates
segments.append(TextSegment(text="{{" + part + "}}"))
else:
# Even indices are text segments
segments.append(TextSegment(text=part))
return cls(segments=segments)
@classmethod
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.
Example:
[{"variable": "text", "value_selector": ["node1", "text"]},
{"variable": "result", "value_selector": ["node2", "result"]}]
->
[VariableSegment(["node1", "text"]),
TextSegment("\n"),
VariableSegment(["node2", "result"])]
Args:
outputs_config: List of output configurations with variable and value_selector
Returns:
Template instance
"""
segments: list[TemplateSegmentUnion] = []
for i, output in enumerate(outputs_config):
if i > 0:
# Add newline separator between variables
segments.append(TextSegment(text="\n"))
value_selector = output.get("value_selector", [])
variable_name = output.get("variable", "")
if value_selector:
segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name))
if len(segments) > 0 and isinstance(segments[-1], TextSegment):
segments = segments[:-1]
return cls(segments=segments)
def __str__(self) -> str:
"""String representation of the template."""
return "".join(str(segment) for segment in self.segments)

View File

@ -0,0 +1,130 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any
from .entities import VariableSelector
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
parts = SELECTOR_PATTERN.split(template)
selectors = []
for part in filter(lambda x: x, parts):
if "." in part and part[0] == "#" and part[-1] == "#":
selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split(".")))
return selectors
class VariableTemplateParser:
"""
!NOTE: Consider to use the new `segments` module instead of this class.
A class for parsing and manipulating template variables in a string.
Rules:
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only be: #node_id.var1.var2#.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
Example usage:
template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
parser = VariableTemplateParser(template)
# Extract template variable keys
variable_keys = parser.extract()
print(variable_keys)
# Output: ['#node_id.query.name#', '#node_id.query.age#']
# Extract variable selectors
variable_selectors = parser.extract_variable_selectors()
print(variable_selectors)
# Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
# VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]
# Format the template string
inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}}
formatted_string = parser.format(inputs)
print(formatted_string)
# Output: "Hello, John! Your age is 25."
"""
def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()
def extract(self):
"""
Extracts all the template variable keys from the template string.
Returns:
A list of template variable keys.
"""
# Regular expression to match the template rules
matches = re.findall(REGEX, self.template)
first_group_matches = [match[0] for match in matches]
return list(set(first_group_matches))
def extract_variable_selectors(self) -> list[VariableSelector]:
"""
Extracts the variable selectors from the template variable keys.
Returns:
A list of VariableSelector objects representing the variable selectors.
"""
variable_selectors = []
for variable_key in self.variable_keys:
remove_hash = variable_key.replace("#", "")
split_result = remove_hash.split(".")
if len(split_result) < 2:
continue
variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
return variable_selectors
def format(self, inputs: Mapping[str, Any]) -> str:
"""
Formats the template string by replacing the template variables with their corresponding values.
Args:
inputs: A dictionary containing the values for the template variables.
Returns:
The formatted string with template variables replaced by their values.
"""
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if value is None:
value = ""
# convert the value to string
if isinstance(value, list | dict | bool | int | float):
value = str(value)
# remove template variables if required
return VariableTemplateParser.remove_template_variables(value)
prompt = re.sub(REGEX, replacer, self.template)
return re.sub(r"<\|.*?\|>", "", prompt)
@classmethod
def remove_template_variables(cls, text: str):
"""
Removes the template variables from the given text.
Args:
text: The text from which to remove the template variables.
Returns:
The text with template variables removed.
"""
return re.sub(REGEX, r"{\1}", text)

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, Optional
from typing import Any, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -9,12 +9,11 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .exc import (
CodeNodeError,
@ -23,15 +22,15 @@ from .exc import (
)
class CodeNode(BaseNode):
_node_type = NodeType.CODE
class CodeNode(Node):
node_type = NodeType.CODE
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -40,7 +39,7 @@ class CodeNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -50,7 +49,7 @@ class CodeNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -58,7 +57,7 @@ class CodeNode(BaseNode):
"""
code_language = CodeLanguage.PYTHON3
if filters:
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
@ -109,8 +108,6 @@ class CodeNode(BaseNode):
"""
if value is None:
return None
if not isinstance(value, str):
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
@ -123,8 +120,6 @@ class CodeNode(BaseNode):
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
if value is None:
return None
if not isinstance(value, bool):
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
return value
@ -137,8 +132,6 @@ class CodeNode(BaseNode):
"""
if value is None:
return None
if not isinstance(value, int | float):
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
@ -161,7 +154,7 @@ class CodeNode(BaseNode):
def _transform_result(
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
output_schema: dict[str, CodeNodeData.Output] | None,
prefix: str = "",
depth: int = 1,
):
@ -262,7 +255,13 @@ class CodeNode(BaseNode):
)
elif output_config.type == SegmentType.NUMBER:
# check if number available
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
value = result.get(output_name)
if value is not None and not isinstance(value, (int, float)):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not a number,"
f" got {type(result.get(output_name))} instead."
)
checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}")
# If the output is a boolean and the output schema specifies a NUMBER type,
# convert the boolean value to an integer.
#
@ -272,8 +271,13 @@ class CodeNode(BaseNode):
elif output_config.type == SegmentType.STRING:
# check if string available
value = result.get(output_name)
if value is not None and not isinstance(value, str):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead"
)
transformed_result[output_name] = self._check_string(
value=result[output_name],
value=value,
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.BOOLEAN:
@ -283,31 +287,36 @@ class CodeNode(BaseNode):
)
elif output_config.type == SegmentType.ARRAY_NUMBER:
# check if array of number available
if not isinstance(result[output_name], list):
if result[output_name] is None:
value = result[output_name]
if not isinstance(value, list):
if value is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
)
for i, inner_value in enumerate(value):
if not isinstance(inner_value, (int, float)):
raise OutputValidationError(
f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be"
f" a number."
)
_ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
transformed_result[output_name] = [
# If the element is a boolean and the output schema specifies a `array[number]` type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
self._convert_boolean_to_int(
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
)
for i, value in enumerate(result[output_name])
self._convert_boolean_to_int(v)
for v in value
]
elif output_config.type == SegmentType.ARRAY_STRING:
# check if array of string available
@ -370,8 +379,9 @@ class CodeNode(BaseNode):
]
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
value = result[output_name]
if not isinstance(value, list):
if value is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
@ -379,10 +389,14 @@ class CodeNode(BaseNode):
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = [
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
for i, inner_value in enumerate(value):
if inner_value is not None and not isinstance(inner_value, bool):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name}[{i}] is not a boolean,"
f" got {type(inner_value)} instead."
)
_ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
transformed_result[output_name] = value
else:
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
@ -403,6 +417,7 @@ class CodeNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
@ -411,10 +426,6 @@ class CodeNode(BaseNode):
for variable_selector in typed_node_data.variables
}
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -1,11 +1,11 @@
from typing import Annotated, Literal, Optional
from typing import Annotated, Literal, Self
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: Optional[dict[str, "CodeNodeData.Output"]] = None
children: dict[str, Self] | None = None
class Dependency(BaseModel):
name: str
@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData):
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
code: str
outputs: dict[str, Output]
dependencies: Optional[list[Dependency]] = None
dependencies: list[Dependency] | None = None

View File

@ -0,0 +1,3 @@
from .datasource_node import DatasourceNode
__all__ = ["DatasourceNode"]

View File

@ -0,0 +1,502 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceParameter,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(Node):
    """Root workflow node that fetches content from a configured datasource plugin.

    Dispatches on the runtime datasource type (read from the system variables
    ``sys.datasource_type`` / ``sys.datasource_info``) and supports
    online-document, online-drive, website-crawl and local-file providers.
    Results are emitted as streaming node events (StreamChunkEvent /
    StreamCompletedEvent) rather than returned directly.
    """

    # Typed node configuration; populated by init_node_data().
    _node_data: DatasourceNodeData

    node_type = NodeType.DATASOURCE
    # Datasource nodes start a pipeline, so they execute as ROOT nodes.
    execution_type = NodeExecutionType.ROOT

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        """Validate and store the raw node config as DatasourceNodeData."""
        self._node_data = DatasourceNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        """Return the configured error strategy, or None if not set."""
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        """Return the node's retry configuration."""
        return self._node_data.retry_config

    def _get_title(self) -> str:
        """Return the node title shown in the workflow UI."""
        return self._node_data.title

    def _get_description(self) -> str | None:
        """Return the optional node description."""
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        """Return default output values used by the error strategy."""
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        """Expose the typed node data as the generic BaseNodeData."""
        return self._node_data

    def _run(self) -> Generator:
        """Run the datasource node.

        Reads the datasource type and info from the variable pool, resolves the
        plugin runtime and credentials, then dispatches per provider type.
        Yields streaming events; invocation failures are reported as FAILED
        StreamCompletedEvents instead of raising out of the generator.
        """
        node_data = self._node_data
        variable_pool = self.graph_runtime_state.variable_pool
        # Both the provider type and the provider-specific payload are injected
        # into the variable pool as system variables before this node runs.
        datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
        if not datasource_type_segement:
            raise DatasourceNodeError("Datasource type is not set")
        datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
        datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
        if not datasource_info_segement:
            raise DatasourceNodeError("Datasource info is not set")
        datasource_info_value = datasource_info_segement.value
        if not isinstance(datasource_info_value, dict):
            raise DatasourceNodeError("Invalid datasource info format")
        datasource_info: dict[str, Any] = datasource_info_value
        # get datasource runtime (local import, presumably to avoid a circular
        # import at module load time — TODO confirm)
        from core.datasource.datasource_manager import DatasourceManager

        if datasource_type is None:
            raise DatasourceNodeError("Datasource type is not set")
        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
            datasource_name=node_data.datasource_name or "",
            tenant_id=self.tenant_id,
            datasource_type=DatasourceProviderType.value_of(datasource_type),
        )
        datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
        # NOTE(review): parameters_for_log aliases datasource_info (no copy),
        # so later mutations of datasource_info also appear in the logged inputs.
        parameters_for_log = datasource_info

        try:
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=self.tenant_id,
                provider=node_data.provider_name,
                plugin_id=node_data.plugin_id,
                credential_id=datasource_info.get("credential_id", ""),
            )
            # datasource_type is a str here; matching against enum members
            # presumably relies on DatasourceProviderType being a str-valued
            # enum — TODO confirm against its definition.
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    if credentials:
                        datasource_runtime.runtime.credentials = credentials
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=self.user_id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=datasource_info.get("workspace_id", ""),
                                page_id=datasource_info.get("page", {}).get("page_id", ""),
                                type=datasource_info.get("page", {}).get("type", ""),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    yield from self._transform_message(
                        messages=online_document_result,
                        parameters_for_log=parameters_for_log,
                        datasource_info=datasource_info,
                    )
                case DatasourceProviderType.ONLINE_DRIVE:
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    if credentials:
                        datasource_runtime.runtime.credentials = credentials
                    online_drive_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.online_drive_download_file(
                            user_id=self.user_id,
                            request=OnlineDriveDownloadFileRequest(
                                id=datasource_info.get("id", ""),
                                bucket=datasource_info.get("bucket"),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    yield from self._transform_datasource_file_message(
                        messages=online_drive_result,
                        parameters_for_log=parameters_for_log,
                        datasource_info=datasource_info,
                        variable_pool=variable_pool,
                        datasource_type=datasource_type,
                    )
                case DatasourceProviderType.WEBSITE_CRAWL:
                    # Crawl results arrive fully materialized in datasource_info;
                    # just echo them as the node outputs.
                    yield StreamCompletedEvent(
                        node_run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                **datasource_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
                case DatasourceProviderType.LOCAL_FILE:
                    related_id = datasource_info.get("related_id")
                    if not related_id:
                        raise DatasourceNodeError("File is not exist")
                    upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
                    if not upload_file:
                        raise ValueError("Invalid upload file Info")
                    # Wrap the DB row in a File value object for downstream nodes.
                    file_info = File(
                        id=upload_file.id,
                        filename=upload_file.name,
                        extension="." + upload_file.extension,
                        mime_type=upload_file.mime_type,
                        tenant_id=self.tenant_id,
                        type=FileType.CUSTOM,
                        transfer_method=FileTransferMethod.LOCAL_FILE,
                        remote_url=upload_file.source_url,
                        related_id=upload_file.id,
                        size=upload_file.size,
                        storage_key=upload_file.key,
                        url=upload_file.source_url,
                    )
                    # Expose the file under this node's scope for selector access.
                    variable_pool.add([self._node_id, "file"], file_info)
                    yield StreamCompletedEvent(
                        node_run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                "file": file_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
                case _:
                    raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
        except PluginDaemonClientSideError as e:
            # Plugin-daemon client errors are surfaced as a FAILED result, not re-raised.
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
        except DatasourceNodeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to invoke datasource: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

    def _generate_parameters(
        self,
        *,
        datasource_parameters: Sequence[DatasourceParameter],
        variable_pool: VariablePool,
        node_data: DatasourceNodeData,
        for_log: bool = False,
    ) -> dict[str, Any]:
        """Resolve configured datasource parameters against the variable pool.

        Args:
            datasource_parameters: Parameter declarations from the datasource plugin.
            variable_pool: The variable pool containing workflow variables.
            node_data: The node configuration holding per-parameter inputs.
            for_log: When True, use the log-safe rendering of templates instead
                of the raw text.

        Returns:
            Mapping of parameter name to resolved value; parameters not declared
            by the plugin resolve to None.

        Raises:
            DatasourceParameterError: If a referenced variable is missing or an
                input type is unknown.
        """
        datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}

        result: dict[str, Any] = {}
        if node_data.datasource_parameters:
            for parameter_name in node_data.datasource_parameters:
                parameter = datasource_parameters_dictionary.get(parameter_name)
                if not parameter:
                    # Configured input has no matching plugin declaration.
                    result[parameter_name] = None
                    continue
                datasource_input = node_data.datasource_parameters[parameter_name]
                if datasource_input.type == "variable":
                    # value is a variable selector path here.
                    variable = variable_pool.get(datasource_input.value)
                    if variable is None:
                        raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
                    parameter_value = variable.value
                elif datasource_input.type in {"mixed", "constant"}:
                    # Render the template ({{#...#}} references) against the pool.
                    segment_group = variable_pool.convert_template(str(datasource_input.value))
                    parameter_value = segment_group.log if for_log else segment_group.text
                else:
                    raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
                result[parameter_name] = parameter_value

        return result

    def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
        """Return the files attached via the sys.files system variable."""
        variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
        return list(variable.value) if variable else []

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Extract variable selector to variable mapping.

        :param graph_config: graph config (unused here)
        :param node_id: node id, used to namespace the returned keys
        :param node_data: raw node data to validate and inspect
        :return: mapping of "<node_id>.<variable>" to its value selector
        """
        # Create typed NodeData from dict
        typed_node_data = DatasourceNodeData.model_validate(node_data)
        result: dict[str, Any] = {}
        if typed_node_data.datasource_parameters:
            for parameter_name in typed_node_data.datasource_parameters:
                input = typed_node_data.datasource_parameters[parameter_name]
                if input.type == "mixed":
                    # Template string: pull out every {{#...#}} selector it references.
                    assert isinstance(input.value, str)
                    selectors = VariableTemplateParser(input.value).extract_variable_selectors()
                    for selector in selectors:
                        result[selector.variable] = selector.value_selector
                elif input.type == "variable":
                    result[parameter_name] = input.value
                elif input.type == "constant":
                    # Constants reference no variables.
                    pass

        result = {node_id + "." + key: value for key, value in result.items()}

        return result

    def _transform_message(
        self,
        messages: Generator[DatasourceMessage, None, None],
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
    ) -> Generator:
        """Convert datasource invoke messages into streaming node events.

        Accumulates text/files/json/variables from the message stream, yielding
        StreamChunkEvents for streamable content and a final SUCCEEDED
        StreamCompletedEvent.

        NOTE(review): only ``variables`` end up in the final outputs; the
        accumulated ``text``, ``files`` and ``json`` lists are collected but not
        emitted — confirm this is intended.
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        text = ""
        files: list[File] = []
        # NOTE(review): local name `json` shadows the stdlib module name within
        # this method.
        json: list[dict] = []
        variables: dict[str, Any] = {}
        for message in message_stream:
            if message.type in {
                DatasourceMessage.MessageType.IMAGE_LINK,
                DatasourceMessage.MessageType.BINARY_LINK,
                DatasourceMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, DatasourceMessage.TextMessage)

                url = message.message.text
                transfer_method = FileTransferMethod.TOOL_FILE

                # The tool-file id is encoded in the URL's final path segment.
                datasource_file_id = str(url).split("/")[-1].split(".")[0]

                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")

                mapping = {
                    "tool_file_id": datasource_file_id,
                    "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=self.tenant_id,
                )
                files.append(file)
            elif message.type == DatasourceMessage.MessageType.BLOB:
                # get tool file id
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                assert message.meta

                datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)

                    if datasource_file is None:
                        raise ToolFileError(f"datasource file {datasource_file_id} not exists")

                mapping = {
                    "tool_file_id": datasource_file_id,
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }

                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=self.tenant_id,
                    )
                )
            elif message.type == DatasourceMessage.MessageType.TEXT:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                text += message.message.text
                # Stream the text fragment under this node's "text" selector.
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=message.message.text,
                    is_final=False,
                )
            elif message.type == DatasourceMessage.MessageType.JSON:
                assert isinstance(message.message, DatasourceMessage.JsonMessage)
                json.append(message.message.json_object)
            elif message.type == DatasourceMessage.MessageType.LINK:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=stream_text,
                    is_final=False,
                )
            elif message.type == DatasourceMessage.MessageType.VARIABLE:
                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    # Streamed variables arrive as string fragments to append.
                    if not isinstance(variable_value, str):
                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value

                    yield StreamChunkEvent(
                        selector=[self._node_id, variable_name],
                        chunk=variable_value,
                        is_final=False,
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == DatasourceMessage.MessageType.FILE:
                assert message.meta is not None
                files.append(message.meta["file"])

        # mark the end of the stream
        yield StreamChunkEvent(
            selector=[self._node_id, "text"],
            chunk="",
            is_final=True,
        )
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={**variables},
                metadata={
                    WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
                },
                inputs=parameters_for_log,
            )
        )

    @classmethod
    def version(cls) -> str:
        """Return the node implementation version string."""
        return "1"

    def _transform_datasource_file_message(
        self,
        messages: Generator[DatasourceMessage, None, None],
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
        variable_pool: VariablePool,
        datasource_type: DatasourceProviderType,
    ) -> Generator:
        """Convert a file-download message stream into a single-file node result.

        Only BINARY_LINK messages are consumed; if several arrive, the last one
        wins. The resulting File (or None) is registered in the variable pool
        under this node's "file" selector and emitted in the outputs.
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        file = None
        for message in message_stream:
            if message.type == DatasourceMessage.MessageType.BINARY_LINK:
                assert isinstance(message.message, DatasourceMessage.TextMessage)

                url = message.message.text
                transfer_method = FileTransferMethod.TOOL_FILE

                # The tool-file id is encoded in the URL's final path segment.
                datasource_file_id = str(url).split("/")[-1].split(".")[0]

                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")

                mapping = {
                    "tool_file_id": datasource_file_id,
                    "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=self.tenant_id,
                )
        if file:
            variable_pool.add([self._node_id, "file"], file)
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                outputs={
                    "file": file,
                    "datasource_type": datasource_type,
                },
            )
        )

View File

@ -0,0 +1,41 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel):
    """Identity of a datasource plugin provider as configured on a node."""

    plugin_id: str
    provider_name: str  # redundant with plugin_id's provider part; kept for convenience
    provider_type: str
    # Defaults to the built-in local-file datasource when none is configured.
    datasource_name: str | None = "local_file"
    datasource_configurations: dict[str, Any] | None = None
    plugin_unique_identifier: str | None = None  # redundant plugin identifier; kept for convenience
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
    """Node configuration for a datasource node: provider identity plus inputs."""

    class DatasourceInput(BaseModel):
        """A single configured parameter input.

        ``type`` controls how ``value`` is interpreted:
        - "mixed": a template string that may reference variables
        - "variable": a variable selector (list of path segments)
        - "constant": a literal scalar value
        """

        # TODO: check this type — Union[Any, ...] collapses to Any for validation.
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant"] | None = None

        @field_validator("type", mode="before")
        @classmethod
        def check_type(cls, value, validation_info: ValidationInfo):
            # Cross-field check: validate `value` against the declared `type`.
            # Relies on `value` being declared before `type`, so it is already
            # present in validation_info.data when this validator runs.
            typ = value
            value = validation_info.data.get("value")
            if typ == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            elif typ == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
                for val in value:
                    if not isinstance(val, str):
                        raise ValueError("value must be a list of strings")
            elif typ == "constant" and not isinstance(value, str | int | float | bool):
                raise ValueError("value must be a string, int, float, or bool")
            return typ

    # Per-parameter inputs keyed by the plugin's parameter name.
    datasource_parameters: dict[str, DatasourceInput] | None = None

View File

@ -0,0 +1,16 @@
class DatasourceNodeError(ValueError):
    """Root of the datasource-node exception hierarchy.

    Derives from ValueError so callers that catch ValueError keep working.
    """
class DatasourceParameterError(DatasourceNodeError):
    """Raised when a configured datasource parameter is invalid or unresolvable."""
class DatasourceFileError(DatasourceNodeError):
    """Raised for failures while handling files produced by a datasource."""

View File

@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any
import chardet
import docx
@ -25,11 +25,10 @@ from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
@ -37,20 +36,20 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode):
class DocumentExtractorNode(Node):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
"""
_node_type = NodeType.DOCUMENT_EXTRACTOR
node_type = NodeType.DOCUMENT_EXTRACTOR
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -59,7 +58,7 @@ class DocumentExtractorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -428,9 +427,9 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError("Missing URL for remote file")
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return cast(bytes, response.content)
return response.content
else:
return cast(bytes, file_manager.download(file))
return file_manager.download(file)
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e

View File

@ -1,4 +0,0 @@
from .end_node import EndNode
from .entities import EndStreamParam
__all__ = ["EndNode", "EndStreamParam"]

View File

@ -1,23 +1,24 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import ErrorStrategy, NodeType
class EndNode(BaseNode):
_node_type = NodeType.END
class EndNode(Node):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -26,7 +27,7 @@ class EndNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -41,8 +42,10 @@ class EndNode(BaseNode):
def _run(self) -> NodeRunResult:
"""
Run node
:return:
Run node - collect all outputs at once.
This method runs after streaming is complete (if streaming was enabled).
It collects all output variables and returns them.
"""
output_variables = self._node_data.outputs
@ -57,3 +60,15 @@ class EndNode(BaseNode):
inputs=outputs,
outputs=outputs,
)
def get_streaming_template(self) -> Template:
"""
Get the template for streaming.
Returns:
Template instance for this End node
"""
outputs_config = [
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
]
return Template.from_end_outputs(outputs_config)

View File

@ -1,152 +0,0 @@
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
from core.workflow.nodes.enums import NodeType
class EndStreamGeneratorRouter:
    """Computes stream-output routing metadata for End nodes.

    For every End node that is not inside a parallel branch, this router
    collects the ordered value selectors that can be streamed (LLM ``text``
    outputs) and the branch nodes (if-else / question-classifier) each End
    node transitively depends on.
    """

    @classmethod
    def init(
        cls,
        node_id_config_mapping: dict[str, dict],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
        node_parallel_mapping: dict[str, str],
    ) -> EndStreamParam:
        """
        Get stream generate routes.

        :param node_id_config_mapping: node id -> node config
        :param reverse_edge_mapping: target node id -> incoming edges
        :param node_parallel_mapping: node id -> parallel id (only nodes inside a parallel)
        :return: stream routing parameters for all eligible End nodes
        """
        # parse stream output node value selector of end nodes
        end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
        for end_node_id, node_config in node_id_config_mapping.items():
            if node_config.get("data", {}).get("type") != NodeType.END.value:
                continue

            # skip end node in parallel
            if end_node_id in node_parallel_mapping:
                continue

            # get generate route for stream output
            stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
            end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors

        # fetch end dependencies
        end_node_ids = list(end_stream_variable_selectors_mapping.keys())
        end_dependencies = cls._fetch_ends_dependencies(
            end_node_ids=end_node_ids,
            reverse_edge_mapping=reverse_edge_mapping,
            node_id_config_mapping=node_id_config_mapping,
        )

        return EndStreamParam(
            end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
            end_dependencies=end_dependencies,
        )

    @classmethod
    def extract_stream_variable_selector_from_node_data(
        cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData
    ) -> list[list[str]]:
        """
        Extract stream variable selector from node data.

        Only selectors pointing at an LLM node's ``text`` output are
        streamable; system ("sys") selectors and duplicates are skipped.

        :param node_id_config_mapping: node id config mapping
        :param node_data: node data object
        :return: ordered list of streamable value selectors
        """
        variable_selectors = node_data.outputs

        value_selectors = []
        for variable_selector in variable_selectors:
            if not variable_selector.value_selector:
                continue

            node_id = variable_selector.value_selector[0]
            if node_id != "sys" and node_id in node_id_config_mapping:
                node = node_id_config_mapping[node_id]
                node_type = node.get("data", {}).get("type")
                if (
                    variable_selector.value_selector not in value_selectors
                    and node_type == NodeType.LLM.value
                    and variable_selector.value_selector[1] == "text"
                ):
                    value_selectors.append(list(variable_selector.value_selector))

        return value_selectors

    @classmethod
    def _extract_stream_variable_selector(
        cls, node_id_config_mapping: dict[str, dict], config: dict
    ) -> list[list[str]]:
        """
        Extract stream variable selector from a raw node config dict.

        :param node_id_config_mapping: node id config mapping
        :param config: node config
        :return: ordered list of streamable value selectors
        """
        node_data = EndNodeData(**config.get("data", {}))
        return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)

    @classmethod
    def _fetch_ends_dependencies(
        cls,
        end_node_ids: list[str],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
        node_id_config_mapping: dict[str, dict],
    ) -> dict[str, list[str]]:
        """
        Fetch the branch-node dependencies of every End node.

        :param end_node_ids: end node ids
        :param reverse_edge_mapping: reverse edge mapping
        :param node_id_config_mapping: node id config mapping
        :return: end node id -> list of branch node ids it depends on
        """
        end_dependencies: dict[str, list[str]] = {}
        for end_node_id in end_node_ids:
            if end_dependencies.get(end_node_id) is None:
                end_dependencies[end_node_id] = []

            cls._recursive_fetch_end_dependencies(
                current_node_id=end_node_id,
                end_node_id=end_node_id,
                node_id_config_mapping=node_id_config_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                end_dependencies=end_dependencies,
            )

        return end_dependencies

    @classmethod
    def _recursive_fetch_end_dependencies(
        cls,
        current_node_id: str,
        end_node_id: str,
        node_id_config_mapping: dict[str, dict],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
        end_dependencies: dict[str, list[str]],
    ) -> None:
        """
        Walk the reverse edges from ``current_node_id`` and record every
        branch node (if-else / question-classifier) reached; recursion stops
        at branch nodes and continues through all other node types.

        :param current_node_id: current node id
        :param end_node_id: end node id
        :param node_id_config_mapping: node id config mapping
        :param reverse_edge_mapping: reverse edge mapping
        :param end_dependencies: accumulator, mutated in place
        :return:
        """
        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
        for edge in reverse_edges:
            source_node_id = edge.source_node_id
            if source_node_id not in node_id_config_mapping:
                continue
            source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
            # Compare against plain string values consistently; previously
            # NodeType.QUESTION_CLASSIFIER (the enum member) was mixed with
            # NodeType.IF_ELSE.value (a str), which only worked because
            # NodeType is a StrEnum.
            if source_node_type in {
                NodeType.IF_ELSE.value,
                NodeType.QUESTION_CLASSIFIER.value,
            }:
                end_dependencies[end_node_id].append(source_node_id)
            else:
                cls._recursive_fetch_end_dependencies(
                    current_node_id=source_node_id,
                    end_node_id=end_node_id,
                    node_id_config_mapping=node_id_config_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    end_dependencies=end_dependencies,
                )

View File

@ -1,188 +0,0 @@
import logging
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
    """Stream processor for End nodes.

    Forwards chunk events only when they belong to the stream route an End
    node is currently expecting, and emits any remaining (non-streamed)
    output values once the producing node has finished.
    """

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        super().__init__(graph, variable_pool)
        self.end_stream_param = graph.end_stream_param
        # Per-End-node cursor into its ordered list of stream variable selectors.
        self.route_position = {}
        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
            self.route_position[end_node_id] = 0
        # Cache: producing node id -> End node ids currently streaming its chunks.
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
        # Whether anything has been streamed yet; used to insert "\n" separators
        # between outputs that come from different nodes.
        self.has_output = False
        # Node ids whose output has already contributed at least one chunk.
        self.output_node_ids: set[str] = set()

    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
        """Filter/augment engine events; yields every event it receives, plus
        synthesized chunk events for finished, non-streamed End outputs."""
        for event in generator:
            if isinstance(event, NodeRunStartedEvent):
                # A fresh run from the root node resets all routing state.
                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
                    self.reset()

                yield event
            elif isinstance(event, NodeRunStreamChunkEvent):
                # Chunks emitted inside an iteration/loop bypass route matching.
                if event.in_iteration_id or event.in_loop_id:
                    if self.has_output and event.node_id not in self.output_node_ids:
                        event.chunk_content = "\n" + event.chunk_content

                    self.output_node_ids.add(event.node_id)
                    self.has_output = True
                    yield event
                    continue

                # Resolve (and cache) which End nodes this producer streams to.
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ]
                else:
                    stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
                    self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
                        stream_out_end_node_ids
                    )

                # Only forward the chunk if at least one End node expects it.
                if stream_out_end_node_ids:
                    if self.has_output and event.node_id not in self.output_node_ids:
                        event.chunk_content = "\n" + event.chunk_content

                    self.output_node_ids.add(event.node_id)
                    self.has_output = True
                    yield event
            elif isinstance(event, NodeRunSucceededEvent):
                yield event
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    # update self.route_position after all stream event finished
                    for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
                        self.route_position[end_node_id] += 1

                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]

                # remove unreachable nodes
                self._remove_unreachable_nodes(event)

                # generate stream outputs
                yield from self._generate_stream_outputs_when_node_finished(event)
            else:
                yield event

    def reset(self) -> None:
        """Reset all routing state for a new run of the graph."""
        self.route_position = {}
        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
            self.route_position[end_node_id] = 0
        self.rest_node_ids = self.graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids = {}

    def _generate_stream_outputs_when_node_finished(
        self, event: NodeRunSucceededEvent
    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Generate stream outputs.

        Emits chunk events for every not-yet-streamed End output whose value
        is already present in the variable pool.

        :param event: node run succeeded event
        :return:
        """
        for end_node_id, position in self.route_position.items():
            # Skip unless the finished node IS this End node, or the End node
            # is still reachable and all of its branch dependencies have
            # already executed.
            # all depends on end node id not in rest node ids
            if event.route_node_state.node_id != end_node_id and (
                end_node_id not in self.rest_node_ids
                or not all(
                    dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]
                )
            ):
                continue

            route_position = self.route_position[end_node_id]

            # NOTE(review): `position` from the loop header is deliberately
            # reused as a local counter below; the cursor value was saved
            # into `route_position` first.
            position = 0
            value_selectors = []
            for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
                if position >= route_position:
                    value_selectors.append(current_value_selectors)

                position += 1

            for value_selector in value_selectors:
                if not value_selector:
                    continue

                value = self.variable_pool.get(value_selector)

                # A missing value means later selectors can't be ready either.
                if value is None:
                    break

                text = value.markdown

                if text:
                    current_node_id = value_selector[0]
                    # Separate outputs of different nodes with a newline.
                    if self.has_output and current_node_id not in self.output_node_ids:
                        text = "\n" + text

                    self.output_node_ids.add(current_node_id)
                    self.has_output = True
                    yield NodeRunStreamChunkEvent(
                        id=event.id,
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        chunk_content=text,
                        from_variable_selector=value_selector,
                        route_node_state=event.route_node_state,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        node_version=event.node_version,
                    )

                # Advance this End node's cursor past the handled selector.
                self.route_position[end_node_id] += 1

    def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
        """
        Is stream out support

        Returns the End node ids whose NEXT expected selector matches the
        chunk's source variable selector.

        :param event: queue text chunk event
        :return:
        """
        if not event.from_variable_selector:
            return []

        stream_output_value_selector = event.from_variable_selector
        if not stream_output_value_selector:
            return []

        stream_out_end_node_ids = []
        for end_node_id, route_position in self.route_position.items():
            if end_node_id not in self.rest_node_ids:
                continue

            # all depends on end node id not in rest node ids
            if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
                # Cursor already past the last selector: nothing left to stream.
                if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
                    continue

                position = 0
                value_selector = None
                for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
                    if position == route_position:
                        value_selector = current_value_selectors
                        break

                    position += 1

                if not value_selector:
                    continue

                # check chunk node id is before current node id or equal to current node id
                if value_selector != stream_output_value_selector:
                    continue

                stream_out_end_node_ids.append(end_node_id)

        return stream_out_end_node_ids

View File

@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class EndNodeData(BaseNodeData):

View File

@ -1,49 +0,0 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
TRIGGER_WEBHOOK = "trigger-webhook"
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
@property
def is_start_node(self) -> bool:
return self in [
NodeType.START,
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class ErrorStrategy(StrEnum):
    """How a node failure is handled: route down a fail branch, or substitute
    a preconfigured default value."""

    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
    """Source-handle identifiers for edges leaving a node that uses the
    fail-branch error strategy."""

    FAILED = "fail-branch"
    SUCCESS = "success-branch"

View File

@ -1,17 +0,0 @@
# Re-export the node event types so callers can import them from this package.
from .event import (
    ModelInvokeCompletedEvent,
    RunCompletedEvent,
    RunRetrieverResourceEvent,
    RunRetryEvent,
    RunStreamChunkEvent,
)
from .types import NodeEvent

# Public API of this package.
__all__ = [
    "ModelInvokeCompletedEvent",
    "NodeEvent",
    "RunCompletedEvent",
    "RunRetrieverResourceEvent",
    "RunRetryEvent",
    "RunStreamChunkEvent",
]

View File

@ -1,40 +0,0 @@
from collections.abc import Sequence
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
class RunCompletedEvent(BaseModel):
    """Emitted when a node run finishes; carries the final run result."""

    run_result: NodeRunResult = Field(..., description="run result")
class RunStreamChunkEvent(BaseModel):
    """A streamed chunk of output text, tagged with the variable selector
    that produced it."""

    chunk_content: str = Field(..., description="chunk content")
    from_variable_selector: list[str] = Field(..., description="from variable selector")
class RunRetrieverResourceEvent(BaseModel):
    """Carries retrieval citations plus the assembled context string."""

    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")
class ModelInvokeCompletedEvent(BaseModel):
    """
    Model invoke completed

    Final text, token usage, and (optional) finish reason of a model call.
    """

    text: str
    usage: LLMUsage
    # None when the provider reports no finish reason.
    finish_reason: str | None = None
class RunRetryEvent(BaseModel):
    """Node Run Retry event

    Emitted before a failed node run is retried.
    """

    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="Retry attempt number")
    start_at: datetime = Field(..., description="Retry start time")

View File

@ -1,3 +0,0 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
# Union of every event a node may emit while running.
NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent

View File

@ -1,7 +1,7 @@
import mimetypes
from collections.abc import Sequence
from email.message import Message
from typing import Any, Literal, Optional
from typing import Any, Literal
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel):
class HttpRequestNodeAuthorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[HttpRequestNodeAuthorizationConfig] = None
config: HttpRequestNodeAuthorizationConfig | None = None
@field_validator("config", mode="before")
@classmethod
@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData):
authorization: HttpRequestNodeAuthorization
headers: str
params: str
body: Optional[HttpRequestNodeBody] = None
timeout: Optional[HttpRequestNodeTimeout] = None
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
body: HttpRequestNodeBody | None = None
timeout: HttpRequestNodeTimeout | None = None
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
class Response:
@ -183,7 +183,7 @@ class Response:
return f"{(self.size / 1024 / 1024):.2f} MB"
@property
def parsed_content_disposition(self) -> Optional[Message]:
def parsed_content_disposition(self) -> Message | None:
content_disposition = self.headers.get("content-disposition", "")
if content_disposition:
msg = Message()

View File

@ -15,7 +15,7 @@ from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import VariablePool
from .entities import (
HttpRequestNodeAuthorization,
@ -263,9 +263,6 @@ class Executor:
if authorization.config is None:
raise AuthorizationConfigError("authorization config is required")
if self.auth.config.api_key is None:
raise AuthorizationConfigError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
@ -409,30 +406,25 @@ class Executor:
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
for file_entry in self.files:
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
if len(file_entry) != 2 or len(file_entry[1]) < 2:
continue # skip malformed entries
key = file_entry[0]
content = file_entry[1][1]
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
# decode content safely
if isinstance(content, bytes):
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
elif isinstance(content, str):
body_string += content
else:
body_string += f"[Unsupported content type: {type(content).__name__}]"
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
body_string += "\r\n"
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
body_string = self.content
elif isinstance(self.content, bytes):
if isinstance(self.content, bytes):
body_string = self.content.decode("utf-8", errors="replace")
else:
body_string = self.content
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body_string = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":

View File

@ -1,20 +1,18 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
from .entities import (
@ -33,15 +31,15 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST
class HttpRequestNode(Node):
node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -50,7 +48,7 @@ class HttpRequestNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -60,7 +58,7 @@ class HttpRequestNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "http-request",
"config": {
@ -101,7 +99,7 @@ class HttpRequestNode(BaseNode):
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and (self.continue_on_error or self.retry):
if not response.response.is_success and (self.error_strategy or self.retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@ -129,7 +127,7 @@ class HttpRequestNode(BaseNode):
},
)
except HttpRequestNodeError as e:
logger.warning("http request node %s failed to run: %s", self.node_id, e)
logger.warning("http request node %s failed to run: %s", self._node_id, e)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
@ -244,10 +242,6 @@ class HttpRequestNode(BaseNode):
return ArrayFileSegment(value=files)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData):
logical_operator: Literal["and", "or"]
conditions: list[Condition]
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
cases: Optional[list[Case]] = None
cases: list[Case] | None = None

View File

@ -1,28 +1,28 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.entities import VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE
class IfElseNode(Node):
node_type = NodeType.IF_ELSE
execution_type = NodeExecutionType.BRANCH
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -31,7 +31,7 @@ class IfElseNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -49,13 +49,13 @@ class IfElseNode(BaseNode):
Run node
:return:
"""
node_inputs: dict[str, list] = {"conditions": []}
node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []}
process_data: dict[str, list] = {"condition_results": []}
input_conditions = []
input_conditions: Sequence[Mapping[str, Any]] = []
final_result = False
selected_case_id = None
selected_case_id = "false"
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
@ -83,7 +83,7 @@ class IfElseNode(BaseNode):
else:
# TODO: Update database then remove this
# Fallback to old structure if cases are not defined
input_conditions, group_result, final_result = _should_not_use_old_function(
input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated]
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self._node_data.conditions or [],

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
parent_loop_id: Optional[str] = None # redundant field, not used currently
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any = None
class MetaData(BaseIterationState.MetaData):
"""
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
iterator_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@ -1,48 +1,40 @@
import contextvars
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
)
from core.workflow.node_events import (
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@ -54,23 +46,26 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(BaseNode):
class IterationNode(Node):
"""
Iteration Node.
"""
_node_type = NodeType.ITERATION
node_type = NodeType.ITERATION
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -79,7 +74,7 @@ class IterationNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -89,7 +84,7 @@ class IterationNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "iteration",
"config": {
@ -103,225 +98,266 @@ class IterationNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore
variable = self._get_iterator_variable()
if self._is_empty_iteration(variable):
yield from self._handle_empty_iteration(variable)
return
iterator_list_value = self._validate_and_get_iterator_list(variable)
inputs = {"iterator_selector": iterator_list_value}
self._validate_start_node()
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
outputs: list[object] = []
yield IterationStartedEvent(
start_at=started_at,
inputs=inputs,
metadata={"iteration_length": len(iterator_list_value)},
)
try:
yield from self._execute_iterations(
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
)
yield from self._handle_iteration_success(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
)
except IterationNodeError as e:
yield from self._handle_iteration_failure(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
error=e,
)
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
# Try our best to preserve the type informat.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
return
return variable
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
return isinstance(variable, NoneSegment) or len(variable.value) == 0
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
# Try our best to preserve the type information.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
return cast(list[object], iterator_list_value)
def _validate_start_node(self) -> None:
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
root_node_id = self._node_data.start_node_id
def _execute_iterations(
self,
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
# Parallel mode execution
yield from self._execute_parallel_iterations(
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
)
else:
# Sequential mode execution
for index, item in enumerate(iterator_list_value):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
yield IterationNextEvent(index=index)
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
graph_engine = self._create_graph_engine(index, item)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
variable_pool.add([self.node_id, "index"], 0)
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)
start_at = naive_utc_now()
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
predecessor_node_id=self.previous_node_id,
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
# Run the iteration
yield from self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs,
graph_engine=graph_engine,
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
flask_app=current_app._get_current_object(), # type: ignore
q=q,
context=contextvars.copy_context(),
iterator_list_value=iterator_list_value,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
index=index,
item=item,
iter_run_map=iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
if isinstance(event, IterationRunNextEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)
yield event
if isinstance(event, RunCompletedEvent):
q.put(None)
for f in futures:
if not f.done():
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations(
self,
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value))
# Determine the number of parallel workers
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
for index, item in enumerate(iterator_list_value):
yield IterationNextEvent(index=index)
future = executor.submit(
self._execute_single_iteration_parallel,
index=index,
item=item,
)
future_to_index[future] = index
# Process completed iterations as they finish
for future in as_completed(future_to_index):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used = result
# Update outputs at the correct index
outputs[index] = output_value
# Yield all events from this iteration
yield from events
# Update tokens and timing
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
except Exception as e:
# Handle errors based on error_handle_mode
match self._node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
# Cancel remaining futures and re-raise
for f in future_to_index:
if f != future:
f.cancel()
yield event
if isinstance(event, IterationRunFailedEvent):
q.put(None)
yield event
except Empty:
continue
raise IterationNodeError(str(e))
case ErrorHandleMode.CONTINUE_ON_ERROR:
outputs[index] = None
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[index] = None # Will be filtered later
# wait all threads
wait(futures)
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[:] = [output for output in outputs if output is not None]
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)
def _execute_single_iteration_parallel(
self,
index: int,
item: object,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
"""Execute a single iteration in parallel mode and return results."""
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
graph_engine = self._create_graph_engine(index, item)
# Collect events instead of yielding them directly
for event in self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs_temp,
graph_engine=graph_engine,
):
events.append(event)
# Get the output value from the temporary outputs list
output_value = outputs_temp[0] if outputs_temp else None
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
def _handle_iteration_success(
self,
started_at: datetime,
inputs: dict[str, Sequence[object]],
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
) -> Generator[NodeEventBase, None, None]:
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
# Yield final success event
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
)
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": output_segment},
metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
},
)
def _handle_iteration_failure(
self,
started_at: datetime,
inputs: dict[str, Sequence[object]],
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
yield IterationFailedEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(error),
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(error),
)
except IterationNodeError as e:
# iteration run failed
logger.warning("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, "index"])
variable_pool.remove([self.node_id, "item"])
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -337,14 +373,20 @@ class IterationNode(BaseNode):
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
iteration_node_ids = set()
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
continue
@ -376,303 +418,115 @@ class IterationNode(BaseNode):
variable_mapping.update(sub_node_variable_mapping)
# remove variable out from iteration
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
return variable_mapping
def _handle_event_metadata(
def _append_iteration_info_to_event(
self,
*,
event: BaseNodeEvent | InNodeEvent,
event: GraphNodeEventBase,
iter_run_index: int,
parallel_mode_run_id: str | None,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
"""
if not isinstance(event, BaseNodeEvent):
return event
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
):
event.in_iteration_id = self._node_id
iter_metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
}
if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
if event.route_node_state.node_run_result:
current_metadata = event.route_node_state.node_run_result.metadata or {}
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
return event
current_metadata = event.node_run_result.metadata
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **iter_metadata}
def _run_single_iter(
self,
*,
iterator_list_value: Sequence[str],
variable_pool: VariablePool,
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
outputs: list[object],
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
iter_start_at = naive_utc_now()
) -> Generator[GraphNodeEventBase, None, None]:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self._node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
current_index = index_variable.value
for event in rst:
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
continue
try:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector)
if result is None:
outputs.append(None)
else:
outputs.append(result.to_object())
return
elif isinstance(event, GraphRunFailedEvent):
match self._node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
raise IterationNodeError(event.error)
case ErrorHandleMode.CONTINUE_ON_ERROR:
outputs.append(None)
return
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
return
elif isinstance(event, InNodeEvent):
# event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs[current_index] = None
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs[current_index] = None
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
# clean nodes resources
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from.value,
invoke_from=self.invoke_from.value,
call_depth=self.workflow_call_depth,
)
# Create a deep copy of the variable pool for each iteration
variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
# iteration run failed
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
# append iteration variable (item, index) to variable pool
variable_pool_copy.add([self._node_id, "index"], index)
variable_pool_copy.add([self._node_id, "item"], item)
# stop the iterator
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
yield metadata_event
# Create a new GraphRuntimeState for this iteration
graph_runtime_state_copy = GraphRuntimeState(
variable_pool=variable_pool_copy,
start_at=self.graph_runtime_state.start_at,
total_tokens=0,
node_run_steps=0,
)
current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
outputs[current_index] = current_iteration_output
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# move to next iteration
variable_pool.add([self.node_id, "index"], next_index)
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
duration=duration,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
except IterationNodeError as e:
logger.warning("Iteration run failed:%s", str(e))
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
workflow_id=self.workflow_id,
graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
)
def _run_single_iter_parallel(
self,
*,
flask_app: Flask,
context: contextvars.Context,
q: Queue,
iterator_list_value: Sequence[str],
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
index: int,
item: Any,
iter_run_map: dict[str, float],
):
"""
run single iteration in parallel mode
"""
with preserve_flask_contexts(flask_app, context_vars=context):
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
variable_pool_copy.add([self.node_id, "index"], index)
variable_pool_copy.add([self.node_id, "item"], item)
for event in self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool_copy,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)
graph_engine.graph_runtime_state.total_tokens += graph_engine_copy.graph_runtime_state.total_tokens
return graph_engine

View File

@ -1,27 +1,26 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(BaseNode):
class IterationStartNode(Node):
"""
Iteration Start Node.
"""
_node_type = NodeType.ITERATION_START
node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +29,7 @@ class IterationStartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -0,0 +1,3 @@
from .knowledge_index_node import KnowledgeIndexNode
__all__ = ["KnowledgeIndexNode"]

View File

@ -0,0 +1,159 @@
from typing import Literal, Union
from pydantic import BaseModel
from core.workflow.nodes.base import BaseNodeData
class RerankingModelConfig(BaseModel):
    """Provider/model pair identifying the reranking model to use."""

    # Name of the model provider supplying the reranking model.
    reranking_provider_name: str
    # Name of the reranking model within that provider.
    reranking_model_name: str
class VectorSetting(BaseModel):
    """Vector-search side of a weighted-score configuration."""

    # Relative weight given to the vector (semantic) score when blending.
    vector_weight: float
    embedding_provider_name: str
    embedding_model_name: str
class KeywordSetting(BaseModel):
    """Keyword-search side of a weighted-score configuration."""

    # Relative weight given to the keyword score when blending.
    keyword_weight: float
class WeightedScoreConfig(BaseModel):
    """Weights used to blend vector and keyword scores during retrieval."""

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting
class EmbeddingSetting(BaseModel):
    """Provider/model pair identifying the embedding model to use."""

    embedding_provider_name: str
    embedding_model_name: str
class EconomySetting(BaseModel):
    """Settings for the "economy" indexing technique."""

    # Presumably the number of keywords kept per segment — confirm at call sites.
    keyword_number: int
class RetrievalSetting(BaseModel):
    """How chunks are retrieved from the dataset at query time."""

    # Retrieval strategy; one of the four supported search methods.
    search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
    # Maximum number of chunks returned per query.
    top_k: int
    # Minimum score a chunk must reach; only meaningful when
    # score_threshold_enabled is True.
    score_threshold: float | None = 0.5
    score_threshold_enabled: bool = False
    # NOTE(review): looks like this selects between model-based reranking and
    # weighted-score blending — confirm valid values against consumers.
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True
    # Required when reranking with a model; None otherwise.
    reranking_model: RerankingModelConfig | None = None
    # Required for weighted-score blending; None otherwise.
    weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
    """Indexing configuration: chosen technique plus per-technique settings."""

    indexing_technique: Literal["high_quality", "economy"]
    # Used when indexing_technique == "high_quality".
    embedding_setting: EmbeddingSetting
    # Used when indexing_technique == "economy".
    economy_setting: EconomySetting
class FileInfo(BaseModel):
    """Data-source info for content that originated from an uploaded file."""

    file_id: str
class OnlineDocumentIcon(BaseModel):
    """Icon metadata attached to an online document."""

    icon_url: str
    icon_type: str
    icon_emoji: str
class OnlineDocumentInfo(BaseModel):
    """Data-source info for content imported from an online document provider."""

    provider: str
    workspace_id: str | None = None
    page_id: str
    page_type: str
    icon: OnlineDocumentIcon | None = None
class WebsiteInfo(BaseModel):
    """Data-source info for content crawled from a website."""

    provider: str
    url: str
class GeneralStructureChunk(BaseModel):
    """Flat chunk structure: a list of chunk texts plus their data source."""

    general_chunks: list[str]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
class ParentChildChunk(BaseModel):
    """One parent chunk together with its derived child chunk texts."""

    parent_content: str
    child_contents: list[str]
class ParentChildStructureChunk(BaseModel):
    """Hierarchical chunk structure: parent/child pairs plus their data source."""

    parent_child_chunks: list[ParentChildChunk]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
class KnowledgeIndexNodeData(BaseNodeData):
    """Configuration of the knowledge-index workflow node."""

    type: str = "knowledge-index"
    # Identifier of the chunk structure to index (used to pick the index
    # processor); confirm valid values against IndexProcessorFactory.
    chunk_structure: str
    # Variable selector pointing at the chunks to be indexed.
    index_chunk_variable_selector: list[str]

View File

@ -0,0 +1,22 @@
class KnowledgeIndexNodeError(ValueError):
    """Base class for all errors raised by KnowledgeIndexNode."""


class ModelNotExistError(KnowledgeIndexNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
    """Raised when the model is not a Large Language Model."""

View File

@ -0,0 +1,209 @@
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
logger = logging.getLogger(__name__)

# Fallback retrieval settings — presumably applied when a dataset defines no
# retrieval model of its own; confirm against call sites.
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}
class KnowledgeIndexNode(Node):
_node_data: KnowledgeIndexNodeData
node_type = NodeType.KNOWLEDGE_INDEX
execution_type = NodeExecutionType.RESPONSE
def init_node_data(self, data: Mapping[str, Any]) -> None:
    """Validate the raw node config mapping and bind it to this node."""
    self._node_data = KnowledgeIndexNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
    """Return the configured error strategy, or None if unset."""
    return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
    """Return the node's retry configuration."""
    return self._node_data.retry_config
def _get_title(self) -> str:
    """Return the node's display title."""
    return self._node_data.title
def _get_description(self) -> str | None:
    """Return the node's description, or None if not provided."""
    return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
    """Return the mapping of default output values for error fallback."""
    return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
    """Expose the validated node data as its BaseNodeData interface."""
    return self._node_data
def _run(self) -> NodeRunResult:  # type: ignore
    """Index the configured chunks into the bound dataset.

    Reads the dataset id and the chunk variable from the variable pool.
    When invoked from the debugger, returns a preview of the would-be
    output instead of writing to the index.

    Returns:
        NodeRunResult: SUCCEEDED with the indexing (or preview) outputs,
        or FAILED with an error message when the chunks are empty or the
        indexing raises.

    Raises:
        KnowledgeIndexNodeError: if the dataset id, dataset record, or
            chunk variable is missing (raised before the try block, so it
            propagates to the caller).
    """
    node_data = cast(KnowledgeIndexNodeData, self._node_data)
    variable_pool = self.graph_runtime_state.variable_pool

    dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
    if not dataset_id:
        raise KnowledgeIndexNodeError("Dataset ID is required.")
    dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
    if not dataset:
        raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")

    # extract variables
    variable = variable_pool.get(node_data.index_chunk_variable_selector)
    if not variable:
        raise KnowledgeIndexNodeError("Index chunk variable is required.")

    # Debugger invocations must not write to the real index; preview instead.
    invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
    is_preview = bool(invoke_from) and invoke_from.value == InvokeFrom.DEBUGGER.value

    chunks = variable.value
    variables = {"chunks": chunks}
    if not chunks:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
        )
    # index knowledge
    try:
        if is_preview:
            outputs = self._get_preview_output(node_data.chunk_structure, chunks)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                outputs=outputs,
            )
        results = self._invoke_knowledge_index(
            dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
        )
        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
    except KnowledgeIndexNodeError as e:
        # logger.exception keeps the traceback (logger.warning dropped it),
        # so indexing failures remain diagnosable from the logs.
        logger.exception("Error when running knowledge index node")
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=variables,
            error=str(e),
            error_type=type(e).__name__,
        )
    # Temporary handle all exceptions from DatasetRetrieval class here.
    except Exception as e:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=variables,
            error=str(e),
            error_type=type(e).__name__,
        )
def _invoke_knowledge_index(
    self,
    dataset: Dataset,
    node_data: KnowledgeIndexNodeData,
    chunks: Mapping[str, Any],
    variable_pool: VariablePool,
) -> Any:
    """Index ``chunks`` into ``dataset`` and finalize the target document.

    Reads the document id and batch (and, when re-indexing, the original
    document id) from the system variables, removes the previous document's
    segments from the vector index and the database, runs the
    structure-specific index processor, and then updates document/segment
    bookkeeping (status, timing, word count).

    Returns:
        A summary dict with dataset/document identifiers, batch and status.

    Raises:
        KnowledgeIndexNodeError: when document id or batch is missing, or
            the target document cannot be found.
    """
    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
    if not document_id:
        raise KnowledgeIndexNodeError("Document ID is required.")
    # Present only when an existing document is being re-indexed.
    original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
    batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
    if not batch:
        raise KnowledgeIndexNodeError("Batch is required.")
    document = db.session.query(Document).filter_by(id=document_id.value).first()
    if not document:
        raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
    # chunk nodes by chunk size
    indexing_start_at = time.perf_counter()
    index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
    if original_document_id:
        # Re-indexing path: drop the previous document's segments from the
        # index and the relational store before writing the new chunks.
        segments = db.session.scalars(
            select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value)
        ).all()
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            # delete from vector index
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
            for segment in segments:
                db.session.delete(segment)
            # Commit the cleanup before the (potentially slow) indexing run.
            db.session.commit()
    index_processor.index(dataset, document, chunks)
    indexing_end_at = time.perf_counter()
    # Wall-clock latency of the index processor run, in seconds.
    document.indexing_latency = indexing_end_at - indexing_start_at
    # update document status
    document.indexing_status = "completed"
    document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    # Aggregate the word count over all segments now stored for this document.
    document.word_count = (
        db.session.query(func.sum(DocumentSegment.word_count))
        .where(
            DocumentSegment.document_id == document.id,
            DocumentSegment.dataset_id == dataset.id,
        )
        .scalar()
    )
    db.session.add(document)
    # update document segment status
    db.session.query(DocumentSegment).where(
        DocumentSegment.document_id == document.id,
        DocumentSegment.dataset_id == dataset.id,
    ).update(
        {
            DocumentSegment.status: "completed",
            DocumentSegment.enabled: True,
            DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
        }
    )
    db.session.commit()
    return {
        "dataset_id": dataset.id,
        "dataset_name": dataset.name,
        "batch": batch.value,
        "document_id": document.id,
        "document_name": document.name,
        "created_at": document.created_at.timestamp(),
        "display_status": document.indexing_status,
    }
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
    """Format ``chunks`` for preview with the processor matching ``chunk_structure``."""
    processor = IndexProcessorFactory(chunk_structure).init_index_processor()
    preview = processor.format_preview(chunks)
    return preview
@classmethod
def version(cls) -> str:
    """Return the schema version string for this node implementation."""
    return "1"
def get_streaming_template(self) -> Template:
    """Build the streaming template for this node.

    Knowledge-index nodes produce no streamed segments, so the template
    is always empty.

    Returns:
        Template instance for this knowledge index node
    """
    empty_template = Template(segments=[])
    return empty_template

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel):
"""
top_k: int
score_threshold: Optional[float] = None
score_threshold: float | None = None
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class SingleRetrievalConfig(BaseModel):
@ -91,7 +91,7 @@ SupportedComparisonOperator = Literal[
class Condition(BaseModel):
"""
Conditon detail
Condition detail
"""
name: str
@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel):
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
query_variable_selector: list[str]
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
multiple_retrieval_config: MultipleRetrievalConfig | None = None
single_retrieval_config: SingleRetrievalConfig | None = None
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -4,9 +4,9 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import sessionmaker
@ -32,14 +32,11 @@ from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
)
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
@ -70,7 +67,7 @@ from .exc import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
logger = logging.getLogger(__name__)
@ -83,8 +80,8 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
class KnowledgeRetrievalNode(Node):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
@ -99,21 +96,15 @@ class KnowledgeRetrievalNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
@ -125,10 +116,10 @@ class KnowledgeRetrievalNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -137,7 +128,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -197,7 +188,7 @@ class KnowledgeRetrievalNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
process_data={},
outputs=outputs, # type: ignore
)
@ -259,7 +250,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@ -291,7 +282,7 @@ class KnowledgeRetrievalNode(BaseNode):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
@ -367,15 +358,12 @@ class KnowledgeRetrievalNode(BaseNode):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
@ -422,14 +410,14 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
filters = [] # type: ignore
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
@ -443,7 +431,7 @@ class KnowledgeRetrievalNode(BaseNode):
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters, # type: ignore
filters,
)
conditions.append(
Condition(
@ -514,7 +502,8 @@ class KnowledgeRetrievalNode(BaseNode):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
@ -552,7 +541,8 @@ class KnowledgeRetrievalNode(BaseNode):
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
node_id=self._node_id,
node_type=self.node_type,
)
for event in generator:
@ -573,15 +563,15 @@ class KnowledgeRetrievalNode(BaseNode):
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
except Exception:
return []
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return
return filters
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
@ -666,6 +656,7 @@ class KnowledgeRetrievalNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)

View File

@ -1,14 +1,13 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Optional, TypeAlias, TypeVar
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
@ -36,15 +35,15 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
return wrapper
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
class ListOperatorNode(Node):
node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -53,7 +52,7 @@ class ListOperatorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -67,8 +66,8 @@ class ListOperatorNode(BaseNode):
return "1"
def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
inputs: dict[str, Sequence[object]] = {}
process_data: dict[str, Sequence[object]] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
@ -171,27 +170,23 @@ class ListOperatorNode(BaseNode):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayBooleanSegment):
else:
if not isinstance(condition.value, bool):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}")
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statment should be unreachable.")
return variable
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
else:
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statement should be unreachable")
return variable
@ -305,7 +300,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
if key in {"type", "transfer_method"} and isinstance(value, Sequence):
if key in {"type", "transfer_method"}:
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str):

View File

@ -1,12 +1,12 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class ModelConfig(BaseModel):
@ -18,7 +18,7 @@ class ModelConfig(BaseModel):
class ContextConfig(BaseModel):
enabled: bool
variable_selector: Optional[list[str]] = None
variable_selector: list[str] | None = None
class VisionConfigOptions(BaseModel):
@ -51,23 +51,40 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
# Keep tagged as default for backward compatibility
default="tagged",
description=(
"""
Strategy for handling model reasoning output.
separated: Return clean text (without <think> tags) + reasoning_content field.
Recommended for new workflows. Enables safe downstream parsing and
workflow variable access: {{#node_id.reasoning_content#}}
tagged : Return original text (with <think> tags) + reasoning_content field.
Maintains full backward compatibility while still providing reasoning_content
for workflow automation. Frontend thinking panels work as before.
"""
),
)
@field_validator("prompt_config", mode="before")
@classmethod

View File

@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
def __init__(self, *, type_name: str):
super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@ -8,7 +8,7 @@ from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
from extensions.ext_database import db as global_db
class LLMFileSaver(tp.Protocol):

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Optional, cast
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
@ -13,16 +13,16 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
@ -107,7 +107,7 @@ def fetch_memory(
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration

View File

@ -2,8 +2,9 @@ import base64
import io
import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Literal
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -50,22 +51,25 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
NodeEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunStreamChunkEvent,
from core.workflow.entities import GraphInitParams, VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.node_events import (
ModelInvokeCompletedEvent,
NodeEventBase,
NodeRunResult,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from . import llm_utils
from .entities import (
@ -88,17 +92,19 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.entities import GraphRuntimeState
logger = logging.getLogger(__name__)
class LLMNode(BaseNode):
_node_type = NodeType.LLM
class LLMNode(Node):
node_type = NodeType.LLM
_node_data: LLMNodeData
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@ -110,21 +116,15 @@ class LLMNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
@ -136,10 +136,10 @@ class LLMNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -148,7 +148,7 @@ class LLMNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -161,12 +161,13 @@ class LLMNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = None
variable_pool = self.graph_runtime_state.variable_pool
try:
@ -182,8 +183,6 @@ class LLMNode(BaseNode):
# merge inputs
inputs.update(jinja_inputs)
node_inputs = {}
# fetch files
files = (
llm_utils.fetch_files(
@ -201,9 +200,8 @@ class LLMNode(BaseNode):
generator = self._fetch_context(node_data=self._node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield event
context = event.context
yield event
if context:
node_inputs["#context#"] = context
@ -221,7 +219,7 @@ class LLMNode(BaseNode):
model_instance=model_instance,
)
query = None
query: str | None = None
if self._node_data.memory:
query = self._node_data.memory.query_prompt_template
if not query and (
@ -255,18 +253,31 @@ class LLMNode(BaseNode):
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self._node_data.reasoning_format,
)
structured_output: LLMStructuredOutput | None = None
for event in generator:
if isinstance(event, RunStreamChunkEvent):
if isinstance(event, StreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompletedEvent):
# Raw text
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
reasoning_content = event.reasoning_content or ""
# For downstream nodes, determine clean text based on reasoning_format
if self._node_data.reasoning_format == "tagged":
# Keep <think> tags for backward compatibility
clean_text = result_text
else:
# Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
@ -284,14 +295,26 @@ class LLMNode(BaseNode):
"model_name": model_config.model,
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs is not None:
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
yield RunCompletedEvent(
run_result=NodeRunResult(
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_data,
@ -305,8 +328,8 @@ class LLMNode(BaseNode):
)
)
except ValueError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
@ -316,8 +339,8 @@ class LLMNode(BaseNode):
)
except Exception as e:
logger.exception("error while executing llm node")
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
@ -331,14 +354,16 @@ class LLMNode(BaseNode):
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Optional[Mapping[str, Any]] = None,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
)
@ -374,6 +399,8 @@ class LLMNode(BaseNode):
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
node_type=node_type,
reasoning_format=reasoning_format,
)
@staticmethod
@ -383,13 +410,16 @@ class LLMNode(BaseNode):
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
event = LLMNode.handle_blocking_result(
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
reasoning_format=reasoning_format,
)
yield event
return
@ -414,7 +444,11 @@ class LLMNode(BaseNode):
file_outputs=file_outputs,
):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=text_part,
is_final=False,
)
# Update the whole metadata
if not model and result.model:
@ -430,13 +464,66 @@ class LLMNode(BaseNode):
except OutputParserError as e:
raise LLMNodeError(f"Failed to parse structured output: {e}")
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
# Extract reasoning content from <think> tags in the main text
full_text = full_text_buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
yield ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
usage=usage,
finish_reason=finish_reason,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
)
@staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@classmethod
def _split_reasoning(
    cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
) -> tuple[str, str]:
    """
    Separate <think>...</think> reasoning blocks from *text*.

    Args:
        text: Full model output that may embed <think> blocks.
        reasoning_format: "tagged" keeps the tags inline and yields no
            reasoning content; "separated" strips the tags and returns
            the reasoning separately.

    Returns:
        tuple of (clean_text, reasoning_content)
    """
    # Backward-compatible mode: leave the text untouched.
    if reasoning_format == "tagged":
        return text, ""

    # Collect the contents of every <think> block (pattern is case-insensitive).
    thoughts = cls._THINK_PATTERN.findall(text)
    reasoning = "\n".join(t.strip() for t in thoughts) if thoughts else ""

    # Drop the blocks themselves, then collapse the leftover blank-line runs.
    remainder = cls._THINK_PATTERN.sub("", text)
    remainder = re.sub(r"\n\s*\n", "\n\n", remainder).strip()

    return remainder, reasoning or ""
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@ -629,7 +716,7 @@ class LLMNode(BaseNode):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
@ -811,14 +898,14 @@ class LLMNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
):
if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt.text)
@ -872,7 +959,7 @@ class LLMNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "llm",
"config": {
@ -900,7 +987,7 @@ class LLMNode(BaseNode):
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
@ -964,6 +1051,7 @@ class LLMNode(BaseNode):
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
@ -973,10 +1061,24 @@ class LLMNode(BaseNode):
):
buffer.write(text_part)
# Extract reasoning content from <think> tags in the main text
full_text = buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
return ModelInvokeCompletedEvent(
text=buffer.getvalue(),
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
usage=invoke_result.usage,
finish_reason=None,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
)
@staticmethod
@ -1052,7 +1154,7 @@ class LLMNode(BaseNode):
return
if isinstance(contents, str):
yield contents
elif isinstance(contents, list):
else:
for item in contents:
if isinstance(item, TextPromptMessageContent):
yield item.data
@ -1066,13 +1168,6 @@ class LLMNode(BaseNode):
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
else:
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
@ -1080,7 +1175,7 @@ class LLMNode(BaseNode):
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
match role:
case PromptMessageRole.USER:
@ -1089,7 +1184,8 @@ def _combine_message_content_with_role(
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
raise NotImplementedError(f"Role {role} is not supported")
case _:
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
@ -1185,7 +1281,7 @@ def _handle_memory_completion_mode(
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:

View File

@ -1,7 +1,6 @@
from collections.abc import Mapping
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field
from pydantic import AfterValidator, BaseModel, Field, field_validator
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
@ -35,19 +34,22 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
value: Any | list[str] | None = None
class LoopNodeData(BaseLoopNodeData):
"""
Loop Node Data.
"""
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
outputs: Optional[Mapping[str, Any]] = None
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, v):
    """Normalize a null ``outputs`` payload to an empty dict before validation."""
    return {} if v is None else v
class LoopStartNodeData(BaseNodeData):
@ -72,7 +74,7 @@ class LoopState(BaseLoopState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any = None
class MetaData(BaseLoopState.MetaData):
"""
@ -81,7 +83,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@ -89,7 +91,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@ -1,27 +1,26 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode):
class LoopEndNode(Node):
"""
Loop End Node.
"""
_node_type = NodeType.LOOP_END
node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +29,7 @@ class LoopEndNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,63 +1,58 @@
import contextlib
import json
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import (
IntegerSegment,
Segment,
SegmentType,
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
InNodeEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.node_events import (
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
class LoopNode(BaseNode):
class LoopNode(Node):
"""
Loop Node.
"""
_node_type = NodeType.LOOP
node_type = NodeType.LOOP
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -66,7 +61,7 @@ class LoopNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -79,7 +74,7 @@ class LoopNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def _run(self) -> Generator:
"""Run the node."""
# Get inputs
loop_count = self._node_data.loop_count
@ -89,144 +84,130 @@ class LoopNode(BaseNode):
inputs = {"loop_count": loop_count}
if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
root_node_id = self._node_data.start_node_id
# Initialize variable pool
variable_pool = self.graph_runtime_state.variable_pool
variable_pool.add([self.node_id, "index"], 0)
# Initialize loop variables
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
if self._node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list)
else None,
}
for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
}
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
)
processed_segment = value_processor[loop_variable.value_type]()
processed_segment = value_processor[loop_variable.value_type](loop_variable)
if not processed_segment:
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self.node_id, loop_variable.label]
variable_pool.add(variable_selector, processed_segment.value)
variable_selector = [self._node_id, loop_variable.label]
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)
start_at = naive_utc_now()
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
# Start Loop event
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopStartedEvent(
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
predecessor_node_id=self.previous_node_id,
)
# yield LoopRunNextEvent(
# loop_id=self.id,
# loop_node_id=self.node_id,
# loop_node_type=self.node_type,
# loop_node_data=self.node_data,
# index=0,
# pre_loop_output=None,
# )
loop_duration_map = {}
single_loop_variable_map = {} # single loop variable output
try:
check_break_result = False
for i in range(loop_count):
loop_start_time = naive_utc_now()
# run single loop
loop_result = yield from self._run_single_loop(
graph_engine=graph_engine,
loop_graph=loop_graph,
variable_pool=variable_pool,
loop_variable_selectors=loop_variable_selectors,
break_conditions=break_conditions,
logical_operator=logical_operator,
condition_processor=condition_processor,
current_index=i,
start_at=start_at,
inputs=inputs,
)
loop_end_time = naive_utc_now()
reach_break_condition = False
if break_conditions:
with contextlib.suppress(ValueError):
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if reach_break_condition:
loop_count = 0
cost_tokens = 0
for i in range(loop_count):
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now()
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
# Track loop duration
loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
# Accumulate outputs from the sub-graph's response nodes
for key, value in graph_engine.graph_runtime_state.outputs.items():
if key == "answer":
# Concatenate answer outputs with newline
existing_answer = self.graph_runtime_state.get_output("answer", "")
if existing_answer:
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
else:
self.graph_runtime_state.set_output("answer", value)
else:
# For other outputs, just update
self.graph_runtime_state.set_output(key, value)
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Collect loop variable values after iteration
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
item = variable_pool.get(selector)
if item:
single_loop_variable[key] = item.value
else:
single_loop_variable[key] = None
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
single_loop_variable_map[str(i)] = single_loop_variable
check_break_result = loop_result.get("check_break_result", False)
if check_break_result:
if reach_break_node:
break
if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if reach_break_condition:
break
yield LoopNextEvent(
index=i + 1,
pre_loop_output=self._node_data.outputs,
)
self.graph_runtime_state.total_tokens += cost_tokens
# Loop completed successfully
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopSucceededEvent(
start_at=start_at,
inputs=inputs,
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
@ -236,18 +217,12 @@ class LoopNode(BaseNode):
)
except Exception as e:
# Loop failed
logger.exception("Loop run failed")
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopFailedEvent(
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
"completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -255,207 +230,60 @@ class LoopNode(BaseNode):
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
)
finally:
# Clean up
variable_pool.remove([self.node_id, "index"])
def _run_single_loop(
self,
*,
graph_engine: "GraphEngine",
loop_graph: Graph,
variable_pool: "VariablePool",
loop_variable_selectors: dict,
break_conditions: list,
logical_operator: Literal["and", "or"],
condition_processor: ConditionProcessor,
current_index: int,
start_at: datetime,
inputs: dict,
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
"""Run a single loop iteration.
Returns:
dict: {'check_break_result': bool}
"""
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"loop {self.node_id} current index not found")
current_index = current_index_variable.value
) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]:
reach_break_node = False
for event in graph_engine.run():
if isinstance(event, GraphNodeEventBase):
self._append_loop_info_to_event(event=event, loop_run_index=current_index)
check_break_result = False
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
event.in_loop_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.LOOP_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
continue
if isinstance(event, GraphNodeEventBase):
yield event
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
reach_break_node = True
if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error)
if (
isinstance(event, NodeRunSucceededEvent)
and event.node_type == NodeType.LOOP_END
and not isinstance(event, NodeRunStreamChunkEvent)
):
# Check if variables in break conditions exist and process conditions
# Allow loop internal variables to be used in break conditions
available_conditions = []
for condition in break_conditions:
variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector)
if variable:
available_conditions.append(condition)
for loop_var in self._node_data.loop_variables or []:
key, sel = loop_var.label, [self._node_id, loop_var.label]
segment = self.graph_runtime_state.variable_pool.get(sel)
self._node_data.outputs[key] = segment.value if segment else None
self._node_data.outputs["loop_round"] = current_index + 1
# Process conditions if at least one variable is available
if available_conditions:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=available_conditions,
operator=logical_operator,
)
if check_break_result:
break
else:
check_break_result = True
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
break
return reach_break_node
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# Loop run failed
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
),
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
)
},
)
)
return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
)
)
return {"check_break_result": True}
else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment
else:
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self._node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
# Move to next loop
next_index = current_index + 1
variable_pool.add([self.node_id, "index"], next_index)
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
index=next_index,
pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}
def _handle_event_metadata(
def _append_loop_info_to_event(
self,
*,
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
"""
if not isinstance(event, BaseNodeEvent):
return event
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
metadata = {
**metadata,
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event
event: GraphNodeEventBase,
loop_run_index: int,
):
event.in_loop_id = self._node_id
loop_metadata = {
WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index,
}
current_metadata = event.node_run_result.metadata
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -470,13 +298,13 @@ class LoopNode(BaseNode):
variable_mapping = {}
# init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
# Extract loop node IDs statically from graph_config
if not loop_graph:
raise ValueError("loop graph not found")
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
for sub_node_id, sub_node_config in loop_graph.node_id_config_mapping.items():
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
continue
@ -515,12 +343,35 @@ class LoopNode(BaseNode):
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
This method statically analyzes the graph configuration to find all nodes
that are part of the specified loop, without creating actual node instances.
:param graph_config: the complete graph configuration
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
loop_node_ids.add(node_id)
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
@ -552,3 +403,47 @@ class LoopNode(BaseNode):
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
    """
    Build a fresh sub-graph engine for one loop iteration.

    Each iteration gets its own runtime state and node factory; the
    variable pool passed into that state is the parent node's pool.
    """
    # Imported locally to avoid circular imports at module load time.
    from core.workflow.entities import GraphInitParams, GraphRuntimeState
    from core.workflow.graph import Graph
    from core.workflow.graph_engine import GraphEngine
    from core.workflow.graph_engine.command_channels import InMemoryChannel
    from core.workflow.nodes.node_factory import DifyNodeFactory

    # Seed the sub-graph with this node's own execution context.
    init_params = GraphInitParams(
        tenant_id=self.tenant_id,
        app_id=self.app_id,
        workflow_id=self.workflow_id,
        graph_config=self.graph_config,
        user_id=self.user_id,
        user_from=self.user_from.value,
        invoke_from=self.invoke_from.value,
        call_depth=self.workflow_call_depth,
    )

    # Fresh runtime state for this iteration, reusing the parent's variable pool.
    iteration_state = GraphRuntimeState(
        variable_pool=self.graph_runtime_state.variable_pool,
        start_at=start_at.timestamp(),
    )

    factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=iteration_state)

    # Materialize the loop-body graph rooted at the loop's start node.
    sub_graph = Graph.init(graph_config=self.graph_config, node_factory=factory, root_node_id=root_node_id)

    # Sub-graphs run with an in-memory command channel.
    return GraphEngine(
        workflow_id=self.workflow_id,
        graph=sub_graph,
        graph_runtime_state=iteration_state,
        command_channel=InMemoryChannel(),
    )

View File

@ -1,27 +1,26 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode):
class LoopStartNode(Node):
"""
Loop Start Node.
"""
_node_type = NodeType.LOOP_START
node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +29,7 @@ class LoopStartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, final
from typing_extensions import override
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
@final
class DifyNodeFactory(NodeFactory):
    """
    Default implementation of NodeFactory that uses the traditional node mapping.

    This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
    and instantiating the appropriate node class.
    """

    def __init__(
        self,
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
    ) -> None:
        # Shared construction parameters handed to every node this factory creates.
        self.graph_init_params = graph_init_params
        self.graph_runtime_state = graph_runtime_state

    @override
    def create_node(self, node_config: dict[str, object]) -> Node:
        """
        Create a Node instance from node configuration data using the traditional mapping.

        :param node_config: node configuration dictionary containing type and other data
        :return: initialized Node instance
        :raises ValueError: if node type is unknown or configuration is invalid
        """
        # Get node_id from config
        node_id = node_config.get("id")
        if not is_str(node_id):
            raise ValueError("Node config missing id")

        # Fetch and validate the data mapping ONCE; it is reused below for both
        # type resolution and node initialization (the original looked it up and
        # validated it twice).
        node_data = node_config.get("data", {})
        if not is_str_dict(node_data):
            raise ValueError(f"Node {node_id} missing data information")

        node_type_str = node_data.get("type")
        if not is_str(node_type_str):
            raise ValueError(f"Node {node_id} missing or invalid type information")

        try:
            node_type = NodeType(node_type_str)
        except ValueError:
            # Re-raise with a clearer message; the enum's own error adds nothing.
            raise ValueError(f"Unknown node type: {node_type_str}") from None

        # Resolve the registered class for this node type.
        node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
        if not node_mapping:
            raise ValueError(f"No class mapping found for node type: {node_type}")

        node_class = node_mapping.get(LATEST_VERSION)
        if not node_class:
            raise ValueError(f"No latest version class found for node type: {node_type}")

        # Create node instance
        node_instance = node_class(
            id=node_id,
            config=node_config,
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )

        # Initialize node with the already-validated data mapping.
        node_instance.init_node_data(node_data)

        # If node has fail branch, change execution type to branch
        if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
            node_instance.execution_type = NodeExecutionType.BRANCH

        return node_instance

View File

@ -1,15 +1,17 @@
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end import EndNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
@ -33,7 +35,7 @@ LATEST_VERSION = "latest"
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
@ -135,6 +137,14 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
"2": AgentNode,
"1": AgentNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
NodeType.TRIGGER_WEBHOOK: {
LATEST_VERSION: TriggerWebhookNode,
"1": TriggerWebhookNode,

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
@ -31,8 +31,6 @@ _VALID_PARAMETER_TYPES = frozenset(
def _validate_type(parameter_type: str) -> SegmentType:
if not isinstance(parameter_type, str):
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
@ -43,10 +41,6 @@ def _validate_type(parameter_type: str) -> SegmentType:
return SegmentType(parameter_type)
class _ParameterConfigError(Exception):
pass
class ParameterConfig(BaseModel):
"""
Parameter Config.
@ -54,7 +48,7 @@ class ParameterConfig(BaseModel):
name: str
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: Optional[list[str]] = None
options: list[str] | None = None
description: str
required: bool
@ -92,8 +86,8 @@ class ParameterExtractorNodeData(BaseNodeData):
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
reasoning_mode: Literal["function_call", "prompt"]
vision: VisionConfig = Field(default_factory=VisionConfig)
@ -102,7 +96,7 @@ class ParameterExtractorNodeData(BaseNodeData):
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def get_parameter_json_schema(self) -> dict:
def get_parameter_json_schema(self):
"""
Get parameter json schema.

View File

@ -63,7 +63,7 @@ class InvalidValueTypeError(ParameterExtractorNodeError):
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
) -> None:
):
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"

View File

@ -3,14 +3,14 @@ import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -27,19 +27,17 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
@ -52,6 +50,7 @@ from .exc import (
)
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_PROMPT,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
@ -85,19 +84,19 @@ def extract_json(text):
return None
class ParameterExtractorNode(BaseNode):
class ParameterExtractorNode(Node):
"""
Parameter Extractor Node.
"""
_node_type = NodeType.PARAMETER_EXTRACTOR
node_type = NodeType.PARAMETER_EXTRACTOR
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -106,7 +105,7 @@ class ParameterExtractorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -115,11 +114,11 @@ class ParameterExtractorNode(BaseNode):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"model": {
"prompt_templates": {
@ -294,7 +293,7 @@ class ParameterExtractorNode(BaseNode):
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,
@ -305,8 +304,6 @@ class ParameterExtractorNode(BaseNode):
)
# handle invoke result
if not isinstance(invoke_result, LLMResult):
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content or ""
if not isinstance(text, str):
@ -318,9 +315,6 @@ class ParameterExtractorNode(BaseNode):
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
if text is None:
text = ""
return text, usage, tool_call
def _generate_function_call_prompt(
@ -329,9 +323,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
@ -411,9 +405,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
@ -449,9 +443,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate completion prompt.
@ -483,9 +477,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate chat prompt.
@ -545,7 +539,7 @@ class ParameterExtractorNode(BaseNode):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@ -584,20 +578,21 @@ class ParameterExtractorNode(BaseNode):
return int(value)
elif isinstance(value, (int, float)):
return value
elif not isinstance(value, str):
return None
if "." in value:
try:
return float(value)
except ValueError:
return None
elif isinstance(value, str):
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
"""
Transform result into standard format.
"""
@ -656,7 +651,7 @@ class ParameterExtractorNode(BaseNode):
return transformed_result
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
def _extract_complete_json_response(self, result: str) -> dict | None:
"""
Extract complete json response.
"""
@ -671,7 +666,7 @@ class ParameterExtractorNode(BaseNode):
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
"""
Extract json from tool call.
"""
@ -690,7 +685,7 @@ class ParameterExtractorNode(BaseNode):
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
def _generate_default_result(self, data: ParameterExtractorNodeData):
"""
Generate default result.
"""
@ -698,7 +693,7 @@ class ParameterExtractorNode(BaseNode):
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
elif parameter.type == "bool":
elif parameter.type == "boolean":
result[parameter.name] = False
elif parameter.type in {"string", "select"}:
result[parameter.name] = ""
@ -710,7 +705,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -737,7 +732,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -752,7 +747,7 @@ class ParameterExtractorNode(BaseNode):
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
@ -773,7 +768,7 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -10,21 +10,20 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
llm_utils,
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@ -41,11 +40,12 @@ from .template_prompts import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER
class QuestionClassifierNode(Node):
node_type = NodeType.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_node_data: QuestionClassifierNodeData
@ -57,21 +57,15 @@ class QuestionClassifierNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
@ -83,10 +77,10 @@ class QuestionClassifierNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -95,7 +89,7 @@ class QuestionClassifierNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -187,7 +181,8 @@ class QuestionClassifierNode(BaseNode):
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
node_id=self._node_id,
node_type=self.node_type,
)
for event in generator:
@ -259,6 +254,7 @@ class QuestionClassifierNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
@ -275,12 +271,13 @@ class QuestionClassifierNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters.
:param filters: filter by node config parameters (not used in this implementation).
:return:
"""
# filters parameter is not used in this node type
return {"type": "question-classifier", "config": {"instructions": ""}}
def _calculate_rest_token(
@ -288,7 +285,7 @@ class QuestionClassifierNode(BaseNode):
node_data: QuestionClassifierNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
@ -331,7 +328,7 @@ class QuestionClassifierNode(BaseNode):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@ -1,24 +1,24 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode):
_node_type = NodeType.START
class StartNode(Node):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +27,7 @@ class StartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class TemplateTransformNodeData(BaseNodeData):

View File

@ -1,27 +1,26 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM
class TemplateTransformNode(Node):
node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +29,7 @@ class TemplateTransformNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -40,7 +39,7 @@ class TemplateTransformNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -57,7 +56,7 @@ class TemplateTransformNode(BaseNode):
def _run(self) -> NodeRunResult:
# Get variables
variables = {}
variables: dict[str, Any] = {}
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)

View File

@ -1,28 +1,28 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
@ -35,27 +35,31 @@ from .exc import (
ToolParameterError,
)
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
class ToolNode(BaseNode):
class ToolNode(Node):
"""
Tool Node
"""
_node_type = NodeType.TOOL
node_type = NodeType.TOOL
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ToolNodeData.model_validate(data)
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
node_data = self._node_data
@ -78,11 +82,11 @@ class ToolNode(BaseNode):
if node_data.version != "1" or node_data.tool_node_version != "1":
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -115,13 +119,12 @@ class ToolNode(BaseNode):
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -139,11 +142,11 @@ class ToolNode(BaseNode):
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self.node_id,
node_id=self._node_id,
)
except ToolInvokeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -152,8 +155,8 @@ class ToolNode(BaseNode):
)
)
except PluginInvokeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -165,8 +168,8 @@ class ToolNode(BaseNode):
)
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -179,7 +182,7 @@ class ToolNode(BaseNode):
self,
*,
tool_parameters: Sequence[ToolParameter],
variable_pool: VariablePool,
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
) -> dict[str, Any]:
@ -220,7 +223,7 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
@ -238,6 +241,8 @@ class ToolNode(BaseNode):
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
@ -310,17 +315,25 @@ class ToolNode(BaseNode):
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
# JSON message handling for tool node
if message.message.json_object is not None:
if message.message.json_object:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
@ -332,8 +345,10 @@ class ToolNode(BaseNode):
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
@ -393,8 +408,24 @@ class ToolNode(BaseNode):
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[self._node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
@ -431,7 +462,8 @@ class ToolNode(BaseNode):
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "constant":
pass
@ -439,7 +471,7 @@ class ToolNode(BaseNode):
return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -448,7 +480,7 @@ class ToolNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -457,10 +489,6 @@ class ToolNode(BaseNode):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel
from core.variables.types import SegmentType
@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
type: str = "variable-assigner"
output_type: str
variables: list[list[str]]
advanced_settings: Optional[AdvancedSettings] = None
advanced_settings: AdvancedSettings | None = None

View File

@ -1,24 +1,23 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR
class VariableAggregatorNode(Node):
node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +26,7 @@ class VariableAggregatorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel):
name: str
selector: Sequence[str]
value_type: SegmentType
new_value: Any
new_value: Any = None
_T = TypeVar("_T", bound=MutableMapping[str, Any])
@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
if len(selector) < SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
_, var_name = selector[:2]
return UpdatedVariable(
name=var_name,
selector=list(selector[:2]),

View File

@ -1,29 +1,19 @@
from sqlalchemy import Engine, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from models.engine import db
from models.workflow import ConversationVariable
from extensions.ext_database import db
from models import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
_engine: Engine | None
def __init__(self, engine: Engine | None = None) -> None:
self._engine = engine
def _get_engine(self) -> Engine:
if self._engine:
return self._engine
return db.engine
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(self._get_engine()) as session:
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")

View File

@ -1,15 +1,15 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
@ -18,22 +18,22 @@ from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
class VariableAssignerNode(Node):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -42,7 +42,7 @@ class VariableAssignerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -56,20 +56,14 @@ class VariableAssignerNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
self._conv_var_updater_factory = conv_var_updater_factory
@ -123,13 +117,8 @@ class VariableAssignerNode(BaseNode):
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
if income_value is None:
raise VariableOperatorNodeError("income value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
@ -18,9 +18,9 @@ class VariableOperationItem(BaseModel):
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
# 3. During the variable updating procedure: The `value` field is reassigned to hold
# the resolved actual value that will be applied to the target variable.
value: Any | None = None
value: Any = None
class VariableAssignerNodeData(BaseNodeData):
version: str = "2"
items: Sequence[VariableOperationItem]
items: Sequence[VariableOperationItem] = Field(default_factory=list)

View File

@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):
class InvalidDataError(VariableOperatorNodeError):
def __init__(self, message: str) -> None:
def __init__(self, message: str):
super().__init__(message)

View File

@ -25,8 +25,6 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
# Only array variable can be appended or extended
# Only array variable can have elements removed
return variable_type.is_array_type()
case _:
return False
def is_variable_input_supported(*, operation: Operation):

View File

@ -1,17 +1,16 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@ -53,15 +52,15 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
class VariableAssignerNode(Node):
node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -70,7 +69,7 @@ class VariableAssignerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -79,6 +78,23 @@ class VariableAssignerNode(BaseNode):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
Returns True if this node updates any of the requested conversation variables.
"""
# Check each item in this Variable Assigner node
for item in self._node_data.items:
# Convert the item's variable_selector to tuple for comparison
item_selector_tuple = tuple(item.variable_selector)
# Check if this item updates any of the requested variables
if item_selector_tuple in variable_selectors:
return True
return False
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@ -258,5 +274,3 @@ class VariableAssignerNode(BaseNode):
if not variable.value:
return variable.value
return variable.value[:-1]
case _:
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)