mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 07:16:20 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -1,3 +1,3 @@
|
||||
from .enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
__all__ = ["NodeType"]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
@ -9,16 +9,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
@ -29,17 +25,25 @@ from core.tools.entities.tool_entities import (
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
@ -57,19 +61,23 @@ from .exc import (
|
||||
ToolFileNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
class AgentNode(BaseNode):
|
||||
|
||||
class AgentNode(Node):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
_node_type = NodeType.AGENT
|
||||
node_type = NodeType.AGENT
|
||||
_node_data: AgentNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AgentNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -78,7 +86,7 @@ class AgentNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -91,7 +99,9 @@ class AgentNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -99,12 +109,12 @@ class AgentNode(BaseNode):
|
||||
agent_strategy_name=self._node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
)
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
@ -139,8 +149,8 @@ class AgentNode(BaseNode):
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
@ -158,16 +168,16 @@ class AgentNode(BaseNode):
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_type=self.type_,
|
||||
node_id=self.node_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
@ -181,7 +191,7 @@ class AgentNode(BaseNode):
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
strategy: "PluginAgentStrategy",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@ -320,7 +330,7 @@ class AgentNode(BaseNode):
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
@ -339,10 +349,11 @@ class AgentNode(BaseNode):
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> InvokeCredentials:
|
||||
) -> "InvokeCredentials":
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
@ -388,6 +399,8 @@ class AgentNode(BaseNode):
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
try:
|
||||
@ -401,7 +414,7 @@ class AgentNode(BaseNode):
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
@ -450,7 +463,9 @@ class AgentNode(BaseNode):
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
def _filter_mcp_type_tool(
|
||||
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
@ -473,11 +488,13 @@ class AgentNode(BaseNode):
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator:
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
@ -491,7 +508,7 @@ class AgentNode(BaseNode):
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage: LLMUsage | None = None
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
@ -553,7 +570,11 @@ class AgentNode(BaseNode):
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
@ -564,13 +585,17 @@ class AgentNode(BaseNode):
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if message.message.json_object is not None:
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
@ -587,8 +612,10 @@ class AgentNode(BaseNode):
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
@ -639,7 +666,7 @@ class AgentNode(BaseNode):
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
id=message.message.id,
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
@ -652,7 +679,7 @@ class AgentNode(BaseNode):
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.id == agent_log.id:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
@ -673,7 +700,7 @@ class AgentNode(BaseNode):
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.id,
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
@ -689,8 +716,24 @@ class AgentNode(BaseNode):
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from enum import Enum, StrEnum
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class ParamsAutoGenerated(Enum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = auto()
|
||||
OPEN = auto()
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AgentNodeError(Exception):
|
||||
"""Base exception for all agent node errors."""
|
||||
|
||||
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
|
||||
class AgentStrategyError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent strategy."""
|
||||
|
||||
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
|
||||
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
|
||||
self.strategy_name = strategy_name
|
||||
self.provider_name = provider_name
|
||||
super().__init__(message)
|
||||
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
|
||||
class AgentStrategyNotFoundError(AgentStrategyError):
|
||||
"""Exception raised when the specified agent strategy is not found."""
|
||||
|
||||
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
|
||||
def __init__(self, strategy_name: str, provider_name: str | None = None):
|
||||
super().__init__(
|
||||
f"Agent strategy '{strategy_name}' not found"
|
||||
+ (f" for provider '{provider_name}'" if provider_name else ""),
|
||||
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
|
||||
class AgentInvocationError(AgentNodeError):
|
||||
"""Exception raised when there's an error invoking the agent."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
|
||||
class AgentParameterError(AgentNodeError):
|
||||
"""Exception raised when there's an error with agent parameters."""
|
||||
|
||||
def __init__(self, message: str, parameter_name: Optional[str] = None):
|
||||
def __init__(self, message: str, parameter_name: str | None = None):
|
||||
self.parameter_name = parameter_name
|
||||
super().__init__(message)
|
||||
|
||||
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
|
||||
class AgentVariableError(AgentNodeError):
|
||||
"""Exception raised when there's an error with variables in the agent node."""
|
||||
|
||||
def __init__(self, message: str, variable_name: Optional[str] = None):
|
||||
def __init__(self, message: str, variable_name: str | None = None):
|
||||
self.variable_name = variable_name
|
||||
super().__init__(message)
|
||||
|
||||
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
|
||||
class ToolFileError(AgentNodeError):
|
||||
"""Exception raised when there's an error with a tool file."""
|
||||
|
||||
def __init__(self, message: str, file_id: Optional[str] = None):
|
||||
def __init__(self, message: str, file_id: str | None = None):
|
||||
self.file_id = file_id
|
||||
super().__init__(message)
|
||||
|
||||
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
|
||||
class AgentMessageTransformError(AgentNodeError):
|
||||
"""Exception raised when there's an error transforming agent messages."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
|
||||
class AgentModelError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the model used by the agent."""
|
||||
|
||||
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
|
||||
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
super().__init__(message)
|
||||
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
|
||||
class AgentMemoryError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent's memory."""
|
||||
|
||||
def __init__(self, message: str, conversation_id: Optional[str] = None):
|
||||
def __init__(self, message: str, conversation_id: str | None = None):
|
||||
self.conversation_id = conversation_id
|
||||
super().__init__(message)
|
||||
|
||||
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
variable_name: Optional[str] = None,
|
||||
expected_type: Optional[str] = None,
|
||||
actual_type: Optional[str] = None,
|
||||
variable_name: str | None = None,
|
||||
expected_type: str | None = None,
|
||||
actual_type: str | None = None,
|
||||
):
|
||||
self.variable_name = variable_name
|
||||
self.expected_type = expected_type
|
||||
|
||||
@ -1,4 +0,0 @@
|
||||
from .answer_node import AnswerNode
|
||||
from .entities import AnswerStreamGenerateRoute
|
||||
|
||||
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]
|
||||
|
||||
@ -1,31 +1,26 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
GenerateRouteChunk,
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerNode(BaseNode):
|
||||
_node_type = NodeType.ANSWER
|
||||
class AnswerNode(Node):
|
||||
node_type = NodeType.ANSWER
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: AnswerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AnswerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -34,7 +29,7 @@ class AnswerNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -48,35 +43,29 @@ class AnswerNode(BaseNode):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
# generate routes
|
||||
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
|
||||
|
||||
answer = ""
|
||||
files = []
|
||||
for part in generate_routes:
|
||||
if part.type == GenerateRouteChunk.ChunkType.VAR:
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
variable = self.graph_runtime_state.variable_pool.get(value_selector)
|
||||
if variable:
|
||||
if isinstance(variable, FileSegment):
|
||||
files.append(variable.value)
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
files.extend(variable.value)
|
||||
answer += variable.markdown
|
||||
else:
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
|
||||
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
|
||||
files = self._extract_files_from_segments(segments.value)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
|
||||
outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
|
||||
)
|
||||
|
||||
def _extract_files_from_segments(self, segments: Sequence[Segment]):
|
||||
"""Extract all files from segments containing FileSegment or ArrayFileSegment instances.
|
||||
|
||||
FileSegment contains a single file, while ArrayFileSegment contains multiple files.
|
||||
This method flattens all files into a single list.
|
||||
"""
|
||||
files = []
|
||||
for segment in segments:
|
||||
if isinstance(segment, FileSegment):
|
||||
# Single file - wrap in list for consistency
|
||||
files.append(segment.value)
|
||||
elif isinstance(segment, ArrayFileSegment):
|
||||
# Multiple files - extend the list
|
||||
files.extend(segment.value)
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@ -96,3 +85,12 @@ class AnswerNode(BaseNode):
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def get_streaming_template(self) -> Template:
|
||||
"""
|
||||
Get the template for streaming.
|
||||
|
||||
Returns:
|
||||
Template instance for this Answer node
|
||||
"""
|
||||
return Template.from_answer_template(self._node_data.answer)
|
||||
|
||||
@ -1,174 +0,0 @@
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
AnswerStreamGenerateRoute,
|
||||
GenerateRouteChunk,
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerStreamGeneratorRouter:
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
) -> AnswerStreamGenerateRoute:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
"""
|
||||
# parse stream output node value selectors of answer nodes
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
|
||||
for answer_node_id, node_config in node_id_config_mapping.items():
|
||||
if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
|
||||
continue
|
||||
|
||||
# get generate route for stream output
|
||||
generate_route = cls._extract_generate_route_selectors(node_config)
|
||||
answer_generate_route[answer_node_id] = generate_route
|
||||
|
||||
# fetch answer dependencies
|
||||
answer_node_ids = list(answer_generate_route.keys())
|
||||
answer_dependencies = cls._fetch_answers_dependencies(
|
||||
answer_node_ids=answer_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
)
|
||||
|
||||
return AnswerStreamGenerateRoute(
|
||||
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route from node data
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
value_selector_mapping = {
|
||||
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
|
||||
}
|
||||
|
||||
variable_keys = list(value_selector_mapping.keys())
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
|
||||
template_variable_keys = template_parser.variable_keys
|
||||
|
||||
# Take the intersection of variable_keys and template_variable_keys
|
||||
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
|
||||
|
||||
generate_routes: list[GenerateRouteChunk] = []
|
||||
for part in template.split("Ω"):
|
||||
if part:
|
||||
if cls._is_variable(part, variable_keys):
|
||||
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
|
||||
value_selector = value_selector_mapping[var_key]
|
||||
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
|
||||
else:
|
||||
generate_routes.append(TextGenerateRouteChunk(text=part))
|
||||
|
||||
return generate_routes
|
||||
|
||||
@classmethod
|
||||
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route selectors
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = AnswerNodeData(**config.get("data", {}))
|
||||
return cls.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
@classmethod
|
||||
def _is_variable(cls, part, variable_keys):
|
||||
cleaned_part = part.replace("{{", "").replace("}}", "")
|
||||
return part.startswith("{{") and cleaned_part in variable_keys
|
||||
|
||||
@classmethod
|
||||
def _fetch_answers_dependencies(
|
||||
cls,
|
||||
answer_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch answer dependencies
|
||||
:param answer_node_ids: answer node ids
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:return:
|
||||
"""
|
||||
answer_dependencies: dict[str, list[str]] = {}
|
||||
for answer_node_id in answer_node_ids:
|
||||
if answer_dependencies.get(answer_node_id) is None:
|
||||
answer_dependencies[answer_node_id] = []
|
||||
|
||||
cls._recursive_fetch_answer_dependencies(
|
||||
current_node_id=answer_node_id,
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies,
|
||||
)
|
||||
|
||||
return answer_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_answer_dependencies(
|
||||
cls,
|
||||
current_node_id: str,
|
||||
answer_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch answer dependencies
|
||||
:param current_node_id: current node id
|
||||
:param answer_node_id: answer node id
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param answer_dependencies: answer dependencies
|
||||
:return:
|
||||
"""
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
if source_node_id not in node_id_config_mapping:
|
||||
continue
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
|
||||
if (
|
||||
source_node_type
|
||||
in {
|
||||
NodeType.ANSWER,
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.LOOP,
|
||||
NodeType.VARIABLE_ASSIGNER,
|
||||
}
|
||||
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
|
||||
):
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
cls._recursive_fetch_answer_dependencies(
|
||||
current_node_id=source_node_id,
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies,
|
||||
)
|
||||
@ -1,199 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerStreamProcessor(StreamProcessor):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.generate_routes = graph.answer_stream_generate_routes
|
||||
self.route_position = {}
|
||||
for answer_node_id in self.generate_routes.answer_generate_route:
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
self.reset()
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id or event.in_loop_id:
|
||||
yield event
|
||||
continue
|
||||
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
]
|
||||
else:
|
||||
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
|
||||
stream_out_answer_node_ids
|
||||
)
|
||||
|
||||
for _ in stream_out_answer_node_ids:
|
||||
yield event
|
||||
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
|
||||
yield event
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
# update self.route_position after all stream event finished
|
||||
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
# generate stream outputs
|
||||
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
self.route_position = {}
|
||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(
|
||||
self, event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
:return:
|
||||
"""
|
||||
for answer_node_id in self.route_position:
|
||||
# all depends on answer node id not in rest node ids
|
||||
if event.route_node_state.node_id != answer_node_id and (
|
||||
answer_node_id not in self.rest_node_ids
|
||||
or not all(
|
||||
dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[answer_node_id]
|
||||
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=route_chunk.text,
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
from_variable_selector=[answer_node_id, "answer"],
|
||||
node_version=event.node_version,
|
||||
)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
if not value_selector:
|
||||
break
|
||||
|
||||
value = self.variable_pool.get(value_selector)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
|
||||
text = value.markdown
|
||||
|
||||
if text:
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=text,
|
||||
from_variable_selector=list(value_selector),
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.from_variable_selector:
|
||||
return []
|
||||
|
||||
stream_output_value_selector = event.from_variable_selector
|
||||
stream_out_answer_node_ids = []
|
||||
for answer_node_id, route_position in self.route_position.items():
|
||||
if answer_node_id not in self.rest_node_ids:
|
||||
continue
|
||||
# Remove current node id from answer dependencies to support stream output if it is a success branch
|
||||
answer_dependencies = self.generate_routes.answer_dependencies
|
||||
edge_mapping = self.graph.edge_mapping.get(event.node_id)
|
||||
success_edge = (
|
||||
next(
|
||||
(
|
||||
edge
|
||||
for edge in edge_mapping
|
||||
if edge.run_condition
|
||||
and edge.run_condition.type == "branch_identify"
|
||||
and edge.run_condition.branch_identify == "success-branch"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if edge_mapping
|
||||
else None
|
||||
)
|
||||
if (
|
||||
event.node_id in answer_dependencies[answer_node_id]
|
||||
and success_edge
|
||||
and success_edge.target_node_id == answer_node_id
|
||||
):
|
||||
answer_dependencies[answer_node_id].remove(event.node_id)
|
||||
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
||||
# all depends on answer node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
|
||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||
continue
|
||||
|
||||
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
|
||||
|
||||
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
|
||||
continue
|
||||
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
continue
|
||||
|
||||
stream_out_answer_node_ids.append(answer_node_id)
|
||||
|
||||
return stream_out_answer_node_ids
|
||||
@ -1,109 +0,0 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
self.graph = graph
|
||||
self.variable_pool = variable_pool
|
||||
self.rest_node_ids = graph.node_ids.copy()
|
||||
|
||||
@abstractmethod
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
if finished_node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
# remove finished node id
|
||||
self.rest_node_ids.remove(finished_node_id)
|
||||
|
||||
run_result = event.route_node_state.node_run_result
|
||||
if not run_result:
|
||||
return
|
||||
|
||||
if run_result.edge_source_handle:
|
||||
reachable_node_ids: list[str] = []
|
||||
unreachable_first_node_ids: list[str] = []
|
||||
if finished_node_id not in self.graph.edge_mapping:
|
||||
logger.warning("node %s has no edge mapping", finished_node_id)
|
||||
return
|
||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
||||
if (
|
||||
edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify
|
||||
):
|
||||
# remove unreachable nodes
|
||||
# FIXME: because of the code branch can combine directly, so for answer node
|
||||
# we remove the node maybe shortcut the answer node, so comment this code for now
|
||||
# there is not effect on the answer node and the workflow, when we have a better solution
|
||||
# we can open this code. Issues: #11542 #9560 #10638 #10564
|
||||
# ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
|
||||
# if "answer" in ids:
|
||||
# continue
|
||||
# else:
|
||||
# reachable_node_ids.extend(ids)
|
||||
|
||||
# The branch_identify parameter is added to ensure that
|
||||
# only nodes in the correct logical branch are included.
|
||||
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
|
||||
reachable_node_ids.extend(ids)
|
||||
else:
|
||||
# if the condition edge in parallel, and the target node is not in parallel, we should not remove it
|
||||
# Issues: #13626
|
||||
if (
|
||||
finished_node_id in self.graph.node_parallel_mapping
|
||||
and edge.target_node_id not in self.graph.node_parallel_mapping
|
||||
):
|
||||
continue
|
||||
unreachable_first_node_ids.append(edge.target_node_id)
|
||||
unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
|
||||
for node_id in unreachable_first_node_ids:
|
||||
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
||||
|
||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
|
||||
if node_id not in self.rest_node_ids:
|
||||
self.rest_node_ids.append(node_id)
|
||||
node_ids = []
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id == self.graph.root_node_id:
|
||||
continue
|
||||
|
||||
# Only follow edges that match the branch_identify or have no run_condition
|
||||
if edge.run_condition and edge.run_condition.branch_identify:
|
||||
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
|
||||
continue
|
||||
|
||||
node_ids.append(edge.target_node_id)
|
||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
|
||||
return node_ids
|
||||
|
||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
||||
"""
|
||||
remove target node ids until merge
|
||||
"""
|
||||
if node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
if node_id in reachable_node_ids:
|
||||
return
|
||||
|
||||
self.rest_node_ids.remove(node_id)
|
||||
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
|
||||
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id in reachable_node_ids:
|
||||
continue
|
||||
|
||||
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
|
||||
Generate Route Chunk.
|
||||
"""
|
||||
|
||||
class ChunkType(Enum):
|
||||
VAR = "var"
|
||||
TEXT = "text"
|
||||
class ChunkType(StrEnum):
|
||||
VAR = auto()
|
||||
TEXT = auto()
|
||||
|
||||
type: ChunkType = Field(..., description="generate route chunk type")
|
||||
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .node import BaseNode
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNode",
|
||||
"BaseNodeData",
|
||||
]
|
||||
|
||||
@ -1,12 +1,37 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.workflow.nodes.base.exc import DefaultValueTypeError
|
||||
from core.workflow.nodes.enums import ErrorStrategy
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
|
||||
from .exc import DefaultValueTypeError
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
@ -19,16 +44,13 @@ class DefaultValueType(StrEnum):
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
NumberType = Union[int, float]
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str) -> Any:
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
@ -51,9 +73,6 @@ class DefaultValue(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
if self.type is None:
|
||||
raise DefaultValueTypeError("type field is required")
|
||||
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
@ -61,7 +80,7 @@ class DefaultValue(BaseModel):
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": NumberType,
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
@ -70,7 +89,7 @@ class DefaultValue(BaseModel):
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": NumberType,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
@ -107,24 +126,12 @@ class DefaultValue(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
@ -135,7 +142,7 @@ class BaseNodeData(ABC, BaseModel):
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
@ -150,7 +157,7 @@ class BaseIterationState(BaseModel):
|
||||
|
||||
|
||||
class BaseLoopNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseLoopState(BaseModel):
|
||||
|
||||
@ -1,81 +1,175 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
||||
from functools import singledispatchmethod
|
||||
from typing import Any, ClassVar
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
IterationSucceededEvent,
|
||||
LoopFailedEvent,
|
||||
LoopNextEvent,
|
||||
LoopStartedEvent,
|
||||
LoopSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseNode:
|
||||
_node_type: ClassVar[NodeType]
|
||||
class Node:
|
||||
node_type: ClassVar["NodeType"]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
self.workflow_type = graph_init_params.workflow_type
|
||||
self.workflow_id = graph_init_params.workflow_id
|
||||
self.graph_config = graph_init_params.graph_config
|
||||
self.user_id = graph_init_params.user_id
|
||||
self.user_from = graph_init_params.user_from
|
||||
self.invoke_from = graph_init_params.invoke_from
|
||||
self.user_from = UserFrom(graph_init_params.user_from)
|
||||
self.invoke_from = InvokeFrom(graph_init_params.invoke_from)
|
||||
self.workflow_call_depth = graph_init_params.call_depth
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.previous_node_id = previous_node_id
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_id = node_id
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
# Generate a single node execution ID to use for all events
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.title,
|
||||
in_iteration_id=None,
|
||||
start_at=self._start_at,
|
||||
)
|
||||
|
||||
# === FIXME(-LAN-): Needs to refactor.
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
|
||||
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
# ===
|
||||
yield start_event
|
||||
|
||||
try:
|
||||
result = self._run()
|
||||
|
||||
# Handle NodeRunResult
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield self._convert_node_run_result_to_graph_node_event(result)
|
||||
return
|
||||
|
||||
# Handle event stream
|
||||
for event in result:
|
||||
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self._node_execution_id
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self.node_id)
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
else:
|
||||
yield from result
|
||||
yield NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(
|
||||
@ -140,13 +234,21 @@ class BaseNode:
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
return {}
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this node blocks the output of specific variables.
|
||||
|
||||
@property
|
||||
def type_(self) -> NodeType:
|
||||
return self._node_type
|
||||
This method is used to determine if a node must complete execution before
|
||||
the specified variables can be used in streaming output.
|
||||
|
||||
:param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str'))
|
||||
:return: True if this node blocks output of any of the specified variables, False otherwise
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@ -158,10 +260,6 @@ class BaseNode:
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@property
|
||||
def continue_on_error(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
@ -170,7 +268,7 @@ class BaseNode:
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
@ -185,7 +283,7 @@ class BaseNode:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
...
|
||||
|
||||
@ -201,7 +299,7 @@ class BaseNode:
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
@ -216,7 +314,7 @@ class BaseNode:
|
||||
return self._get_title()
|
||||
|
||||
@property
|
||||
def description(self) -> Optional[str]:
|
||||
def description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
return self._get_description()
|
||||
|
||||
@ -224,3 +322,198 @@ class BaseNode:
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._get_default_value_dict()
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
error=result.error,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"result status {result.status} not supported")
|
||||
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=event.node_run_result,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"Node {self._node_id} does not support status {event.node_run_result.status}"
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
return NodeRunAgentLogEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
message_id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
return NodeRunLoopNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
return NodeRunLoopSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
return NodeRunLoopFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
return NodeRunIterationStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
return NodeRunIterationNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
return NodeRunIterationSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
return NodeRunIterationFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
retriever_resources=event.retriever_resources,
|
||||
context=event.context,
|
||||
node_version=self.version(),
|
||||
)
|
||||
|
||||
148
api/core/workflow/nodes/base/template.py
Normal file
148
api/core/workflow/nodes/base/template.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""Template structures for Response nodes (Answer and End).
|
||||
|
||||
This module provides a unified template structure for both Answer and End nodes,
|
||||
similar to SegmentGroup but focused on template representation without values.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemplateSegment(ABC):
|
||||
"""Base class for template segments."""
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the segment."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextSegment(TemplateSegment):
|
||||
"""A text segment in a template."""
|
||||
|
||||
text: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableSegment(TemplateSegment):
|
||||
"""A variable reference segment in a template."""
|
||||
|
||||
selector: Sequence[str]
|
||||
variable_name: str | None = None # Optional variable name for End nodes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{{#" + ".".join(self.selector) + "#}}"
|
||||
|
||||
|
||||
# Type alias for segments
|
||||
TemplateSegmentUnion = Union[TextSegment, VariableSegment]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Template:
|
||||
"""Unified template structure for Response nodes.
|
||||
|
||||
Similar to SegmentGroup, but represents the template structure
|
||||
without variable values - only marking variable selectors.
|
||||
"""
|
||||
|
||||
segments: list[TemplateSegmentUnion]
|
||||
|
||||
@classmethod
|
||||
def from_answer_template(cls, template_str: str) -> "Template":
|
||||
"""Create a Template from an Answer node template string.
|
||||
|
||||
Example:
|
||||
"Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])]
|
||||
|
||||
Args:
|
||||
template_str: The answer template string
|
||||
|
||||
Returns:
|
||||
Template instance
|
||||
"""
|
||||
parser = VariableTemplateParser(template_str)
|
||||
segments: list[TemplateSegmentUnion] = []
|
||||
|
||||
# Extract variable selectors to find all variables
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
var_map = {var.variable: var.value_selector for var in variable_selectors}
|
||||
|
||||
# Parse template to get ordered segments
|
||||
# We need to split the template by variable placeholders while preserving order
|
||||
import re
|
||||
|
||||
# Create a regex pattern that matches variable placeholders
|
||||
pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}"
|
||||
|
||||
# Split template while keeping the delimiters (variable placeholders)
|
||||
parts = re.split(pattern, template_str)
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# Check if this part is a variable reference (odd indices after split)
|
||||
if i % 2 == 1: # Odd indices are variable keys
|
||||
# Remove the # symbols from the variable key
|
||||
var_key = part
|
||||
if var_key in var_map:
|
||||
segments.append(VariableSegment(selector=list(var_map[var_key])))
|
||||
else:
|
||||
# This shouldn't happen with valid templates
|
||||
segments.append(TextSegment(text="{{" + part + "}}"))
|
||||
else:
|
||||
# Even indices are text segments
|
||||
segments.append(TextSegment(text=part))
|
||||
|
||||
return cls(segments=segments)
|
||||
|
||||
@classmethod
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
|
||||
"""Create a Template from an End node outputs configuration.
|
||||
|
||||
End nodes are treated as templates of concatenated variables with newlines.
|
||||
|
||||
Example:
|
||||
[{"variable": "text", "value_selector": ["node1", "text"]},
|
||||
{"variable": "result", "value_selector": ["node2", "result"]}]
|
||||
->
|
||||
[VariableSegment(["node1", "text"]),
|
||||
TextSegment("\n"),
|
||||
VariableSegment(["node2", "result"])]
|
||||
|
||||
Args:
|
||||
outputs_config: List of output configurations with variable and value_selector
|
||||
|
||||
Returns:
|
||||
Template instance
|
||||
"""
|
||||
segments: list[TemplateSegmentUnion] = []
|
||||
|
||||
for i, output in enumerate(outputs_config):
|
||||
if i > 0:
|
||||
# Add newline separator between variables
|
||||
segments.append(TextSegment(text="\n"))
|
||||
|
||||
value_selector = output.get("value_selector", [])
|
||||
variable_name = output.get("variable", "")
|
||||
if value_selector:
|
||||
segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name))
|
||||
|
||||
if len(segments) > 0 and isinstance(segments[-1], TextSegment):
|
||||
segments = segments[:-1]
|
||||
|
||||
return cls(segments=segments)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the template."""
|
||||
return "".join(str(segment) for segment in self.segments)
|
||||
130
api/core/workflow/nodes/base/variable_template_parser.py
Normal file
130
api/core/workflow/nodes/base/variable_template_parser.py
Normal file
@ -0,0 +1,130 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from .entities import VariableSelector
|
||||
|
||||
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
|
||||
|
||||
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
|
||||
|
||||
|
||||
def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
|
||||
parts = SELECTOR_PATTERN.split(template)
|
||||
selectors = []
|
||||
for part in filter(lambda x: x, parts):
|
||||
if "." in part and part[0] == "#" and part[-1] == "#":
|
||||
selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split(".")))
|
||||
return selectors
|
||||
|
||||
|
||||
class VariableTemplateParser:
|
||||
"""
|
||||
!NOTE: Consider to use the new `segments` module instead of this class.
|
||||
|
||||
A class for parsing and manipulating template variables in a string.
|
||||
|
||||
Rules:
|
||||
|
||||
1. Template variables must be enclosed in `{{}}`.
|
||||
2. The template variable Key can only be: #node_id.var1.var2#.
|
||||
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
|
||||
|
||||
Example usage:
|
||||
|
||||
template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
|
||||
parser = VariableTemplateParser(template)
|
||||
|
||||
# Extract template variable keys
|
||||
variable_keys = parser.extract()
|
||||
print(variable_keys)
|
||||
# Output: ['#node_id.query.name#', '#node_id.query.age#']
|
||||
|
||||
# Extract variable selectors
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
print(variable_selectors)
|
||||
# Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
|
||||
# VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]
|
||||
|
||||
# Format the template string
|
||||
inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}}
|
||||
formatted_string = parser.format(inputs)
|
||||
print(formatted_string)
|
||||
# Output: "Hello, John! Your age is 25."
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.template = template
|
||||
self.variable_keys = self.extract()
|
||||
|
||||
def extract(self):
|
||||
"""
|
||||
Extracts all the template variable keys from the template string.
|
||||
|
||||
Returns:
|
||||
A list of template variable keys.
|
||||
"""
|
||||
# Regular expression to match the template rules
|
||||
matches = re.findall(REGEX, self.template)
|
||||
|
||||
first_group_matches = [match[0] for match in matches]
|
||||
|
||||
return list(set(first_group_matches))
|
||||
|
||||
def extract_variable_selectors(self) -> list[VariableSelector]:
|
||||
"""
|
||||
Extracts the variable selectors from the template variable keys.
|
||||
|
||||
Returns:
|
||||
A list of VariableSelector objects representing the variable selectors.
|
||||
"""
|
||||
variable_selectors = []
|
||||
for variable_key in self.variable_keys:
|
||||
remove_hash = variable_key.replace("#", "")
|
||||
split_result = remove_hash.split(".")
|
||||
if len(split_result) < 2:
|
||||
continue
|
||||
|
||||
variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
|
||||
|
||||
return variable_selectors
|
||||
|
||||
def format(self, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Formats the template string by replacing the template variables with their corresponding values.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary containing the values for the template variables.
|
||||
|
||||
Returns:
|
||||
The formatted string with template variables replaced by their values.
|
||||
"""
|
||||
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
value = inputs.get(key, match.group(0)) # return original matched string if key not found
|
||||
|
||||
if value is None:
|
||||
value = ""
|
||||
# convert the value to string
|
||||
if isinstance(value, list | dict | bool | int | float):
|
||||
value = str(value)
|
||||
|
||||
# remove template variables if required
|
||||
return VariableTemplateParser.remove_template_variables(value)
|
||||
|
||||
prompt = re.sub(REGEX, replacer, self.template)
|
||||
return re.sub(r"<\|.*?\|>", "", prompt)
|
||||
|
||||
@classmethod
|
||||
def remove_template_variables(cls, text: str):
|
||||
"""
|
||||
Removes the template variables from the given text.
|
||||
|
||||
Args:
|
||||
text: The text from which to remove the template variables.
|
||||
|
||||
Returns:
|
||||
The text with template variables removed.
|
||||
"""
|
||||
return re.sub(REGEX, r"{\1}", text)
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, Optional
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
@ -9,12 +9,11 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
@ -23,15 +22,15 @@ from .exc import (
|
||||
)
|
||||
|
||||
|
||||
class CodeNode(BaseNode):
|
||||
_node_type = NodeType.CODE
|
||||
class CodeNode(Node):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
_node_data: CodeNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = CodeNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -40,7 +39,7 @@ class CodeNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -50,7 +49,7 @@ class CodeNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
@ -58,7 +57,7 @@ class CodeNode(BaseNode):
|
||||
"""
|
||||
code_language = CodeLanguage.PYTHON3
|
||||
if filters:
|
||||
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
@ -109,8 +108,6 @@ class CodeNode(BaseNode):
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, str):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise OutputValidationError(
|
||||
@ -123,8 +120,6 @@ class CodeNode(BaseNode):
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, bool):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
|
||||
|
||||
return value
|
||||
|
||||
@ -137,8 +132,6 @@ class CodeNode(BaseNode):
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, int | float):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise OutputValidationError(
|
||||
@ -161,7 +154,7 @@ class CodeNode(BaseNode):
|
||||
def _transform_result(
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: Optional[dict[str, CodeNodeData.Output]],
|
||||
output_schema: dict[str, CodeNodeData.Output] | None,
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
@ -262,7 +255,13 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
|
||||
value = result.get(output_name)
|
||||
if value is not None and not isinstance(value, (int, float)):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not a number,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
@ -272,8 +271,13 @@ class CodeNode(BaseNode):
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
value = result.get(output_name)
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead"
|
||||
)
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
value=value,
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
@ -283,31 +287,36 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
if not isinstance(inner_value, (int, float)):
|
||||
raise OutputValidationError(
|
||||
f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" a number."
|
||||
)
|
||||
_ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = [
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
self._convert_boolean_to_int(v)
|
||||
for v in value
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
@ -370,8 +379,9 @@ class CodeNode(BaseNode):
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
@ -379,10 +389,14 @@ class CodeNode(BaseNode):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = [
|
||||
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
for i, inner_value in enumerate(value):
|
||||
if inner_value is not None and not isinstance(inner_value, bool):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not a boolean,"
|
||||
f" got {type(inner_value)} instead."
|
||||
)
|
||||
_ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = value
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
@ -403,6 +417,7 @@ class CodeNode(BaseNode):
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = CodeNodeData.model_validate(node_data)
|
||||
|
||||
@ -411,10 +426,6 @@ class CodeNode(BaseNode):
|
||||
for variable_selector in typed_node_data.variables
|
||||
}
|
||||
|
||||
@property
|
||||
def continue_on_error(self) -> bool:
|
||||
return self._node_data.error_strategy is not None
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from typing import Annotated, Literal, Optional
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
children: dict[str, Self] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData):
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
code: str
|
||||
outputs: dict[str, Output]
|
||||
dependencies: Optional[list[Dependency]] = None
|
||||
dependencies: list[Dependency] | None = None
|
||||
|
||||
3
api/core/workflow/nodes/datasource/__init__.py
Normal file
3
api/core/workflow/nodes/datasource/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .datasource_node import DatasourceNode
|
||||
|
||||
__all__ = ["DatasourceNode"]
|
||||
502
api/core/workflow/nodes/datasource/datasource_node.py
Normal file
502
api/core/workflow/nodes/datasource/datasource_node.py
Normal file
@ -0,0 +1,502 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.file import File
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(Node):
|
||||
"""
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data: DatasourceNodeData
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = DatasourceNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
if not datasource_type_segement:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if not datasource_info_segement:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info_value = datasource_info_segement.value
|
||||
if not isinstance(datasource_info_value, dict):
|
||||
raise DatasourceNodeError("Invalid datasource info format")
|
||||
datasource_info: dict[str, Any] = datasource_info_value
|
||||
# get datasource runtime
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
parameters_for_log = datasource_info
|
||||
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=datasource_info.get("credential_id", ""),
|
||||
)
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_message(
|
||||
messages=online_document_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
)
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_datasource_file_message(
|
||||
messages=online_drive_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
variable_pool=variable_pool,
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
**datasource_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
related_id = datasource_info.get("related_id")
|
||||
if not related_id:
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=upload_file.source_url,
|
||||
)
|
||||
variable_pool.add([self._node_id, "file"], file_info)
|
||||
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to transform datasource message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to invoke datasource: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
datasource_parameters: Sequence[DatasourceParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: DatasourceNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
typed_node_data = DatasourceNodeData.model_validate(node_data)
|
||||
result = {}
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={**variables},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _transform_datasource_file_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
file = None
|
||||
for message in message_stream:
|
||||
if message.type == DatasourceMessage.MessageType.BINARY_LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
if file:
|
||||
variable_pool.add([self._node_id, "file"], file)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
41
api/core/workflow/nodes/datasource/entities.py
Normal file
41
api/core/workflow/nodes/datasource/entities.py
Normal file
@ -0,0 +1,41 @@
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
plugin_id: str
|
||||
provider_name: str # redundancy
|
||||
provider_type: str
|
||||
datasource_name: str | None = "local_file"
|
||||
datasource_configurations: dict[str, Any] | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"] | None = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
datasource_parameters: dict[str, DatasourceInput] | None = None
|
||||
16
api/core/workflow/nodes/datasource/exc.py
Normal file
16
api/core/workflow/nodes/datasource/exc.py
Normal file
@ -0,0 +1,16 @@
|
||||
class DatasourceNodeError(ValueError):
|
||||
"""Base exception for datasource node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceParameterError(DatasourceNodeError):
|
||||
"""Exception raised for errors in datasource parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceFileError(DatasourceNodeError):
|
||||
"""Exception raised for errors related to datasource files."""
|
||||
|
||||
pass
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any
|
||||
|
||||
import chardet
|
||||
import docx
|
||||
@ -25,11 +25,10 @@ from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
@ -37,20 +36,20 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentExtractorNode(BaseNode):
|
||||
class DocumentExtractorNode(Node):
|
||||
"""
|
||||
Extracts text content from various file types.
|
||||
Supports plain text, PDF, and DOC/DOCX files.
|
||||
"""
|
||||
|
||||
_node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
_node_data: DocumentExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = DocumentExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -59,7 +58,7 @@ class DocumentExtractorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -428,9 +427,9 @@ def _download_file_content(file: File) -> bytes:
|
||||
raise FileDownloadError("Missing URL for remote file")
|
||||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return cast(bytes, response.content)
|
||||
return response.content
|
||||
else:
|
||||
return cast(bytes, file_manager.download(file))
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
||||
@ -1,4 +0,0 @@
|
||||
from .end_node import EndNode
|
||||
from .entities import EndStreamParam
|
||||
|
||||
__all__ = ["EndNode", "EndStreamParam"]
|
||||
|
||||
@ -1,23 +1,24 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
||||
|
||||
class EndNode(BaseNode):
|
||||
_node_type = NodeType.END
|
||||
class EndNode(Node):
|
||||
node_type = NodeType.END
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: EndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = EndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -26,7 +27,7 @@ class EndNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -41,8 +42,10 @@ class EndNode(BaseNode):
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
Run node - collect all outputs at once.
|
||||
|
||||
This method runs after streaming is complete (if streaming was enabled).
|
||||
It collects all output variables and returns them.
|
||||
"""
|
||||
output_variables = self._node_data.outputs
|
||||
|
||||
@ -57,3 +60,15 @@ class EndNode(BaseNode):
|
||||
inputs=outputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def get_streaming_template(self) -> Template:
|
||||
"""
|
||||
Get the template for streaming.
|
||||
|
||||
Returns:
|
||||
Template instance for this End node
|
||||
"""
|
||||
outputs_config = [
|
||||
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
|
||||
]
|
||||
return Template.from_end_outputs(outputs_config)
|
||||
|
||||
@ -1,152 +0,0 @@
|
||||
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class EndStreamGeneratorRouter:
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_parallel_mapping: dict[str, str],
|
||||
) -> EndStreamParam:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
"""
|
||||
# parse stream output node value selector of end nodes
|
||||
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
|
||||
for end_node_id, node_config in node_id_config_mapping.items():
|
||||
if node_config.get("data", {}).get("type") != NodeType.END.value:
|
||||
continue
|
||||
|
||||
# skip end node in parallel
|
||||
if end_node_id in node_parallel_mapping:
|
||||
continue
|
||||
|
||||
# get generate route for stream output
|
||||
stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
|
||||
end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
|
||||
|
||||
# fetch end dependencies
|
||||
end_node_ids = list(end_stream_variable_selectors_mapping.keys())
|
||||
end_dependencies = cls._fetch_ends_dependencies(
|
||||
end_node_ids=end_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
)
|
||||
|
||||
return EndStreamParam(
|
||||
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_stream_variable_selector_from_node_data(
|
||||
cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node data
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
variable_selectors = node_data.outputs
|
||||
|
||||
value_selectors = []
|
||||
for variable_selector in variable_selectors:
|
||||
if not variable_selector.value_selector:
|
||||
continue
|
||||
|
||||
node_id = variable_selector.value_selector[0]
|
||||
if node_id != "sys" and node_id in node_id_config_mapping:
|
||||
node = node_id_config_mapping[node_id]
|
||||
node_type = node.get("data", {}).get("type")
|
||||
if (
|
||||
variable_selector.value_selector not in value_selectors
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == "text"
|
||||
):
|
||||
value_selectors.append(list(variable_selector.value_selector))
|
||||
|
||||
return value_selectors
|
||||
|
||||
@classmethod
|
||||
def _extract_stream_variable_selector(
|
||||
cls, node_id_config_mapping: dict[str, dict], config: dict
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node config
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = EndNodeData(**config.get("data", {}))
|
||||
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
|
||||
|
||||
@classmethod
|
||||
def _fetch_ends_dependencies(
|
||||
cls,
|
||||
end_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch end dependencies
|
||||
:param end_node_ids: end node ids
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:return:
|
||||
"""
|
||||
end_dependencies: dict[str, list[str]] = {}
|
||||
for end_node_id in end_node_ids:
|
||||
if end_dependencies.get(end_node_id) is None:
|
||||
end_dependencies[end_node_id] = []
|
||||
|
||||
cls._recursive_fetch_end_dependencies(
|
||||
current_node_id=end_node_id,
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
||||
return end_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_end_dependencies(
|
||||
cls,
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch end dependencies
|
||||
:param current_node_id: current node id
|
||||
:param end_node_id: end node id
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param end_dependencies: end dependencies
|
||||
:return:
|
||||
"""
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
if source_node_id not in node_id_config_mapping:
|
||||
continue
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in {
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
}:
|
||||
end_dependencies[end_node_id].append(source_node_id)
|
||||
else:
|
||||
cls._recursive_fetch_end_dependencies(
|
||||
current_node_id=source_node_id,
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
@ -1,188 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndStreamProcessor(StreamProcessor):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.end_stream_param = graph.end_stream_param
|
||||
self.route_position = {}
|
||||
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||
self.route_position[end_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
self.has_output = False
|
||||
self.output_node_ids: set[str] = set()
|
||||
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
self.reset()
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id or event.in_loop_id:
|
||||
if self.has_output and event.node_id not in self.output_node_ids:
|
||||
event.chunk_content = "\n" + event.chunk_content
|
||||
|
||||
self.output_node_ids.add(event.node_id)
|
||||
self.has_output = True
|
||||
yield event
|
||||
continue
|
||||
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
]
|
||||
else:
|
||||
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
|
||||
stream_out_end_node_ids
|
||||
)
|
||||
|
||||
if stream_out_end_node_ids:
|
||||
if self.has_output and event.node_id not in self.output_node_ids:
|
||||
event.chunk_content = "\n" + event.chunk_content
|
||||
|
||||
self.output_node_ids.add(event.node_id)
|
||||
self.has_output = True
|
||||
yield event
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
yield event
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
# update self.route_position after all stream event finished
|
||||
for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
||||
self.route_position[end_node_id] += 1
|
||||
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||
|
||||
# remove unreachable nodes
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
# generate stream outputs
|
||||
yield from self._generate_stream_outputs_when_node_finished(event)
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
self.route_position = {}
|
||||
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||
self.route_position[end_node_id] = 0
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(
|
||||
self, event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
:return:
|
||||
"""
|
||||
for end_node_id, position in self.route_position.items():
|
||||
# all depends on end node id not in rest node ids
|
||||
if event.route_node_state.node_id != end_node_id and (
|
||||
end_node_id not in self.rest_node_ids
|
||||
or not all(
|
||||
dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[end_node_id]
|
||||
|
||||
position = 0
|
||||
value_selectors = []
|
||||
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
|
||||
if position >= route_position:
|
||||
value_selectors.append(current_value_selectors)
|
||||
|
||||
position += 1
|
||||
|
||||
for value_selector in value_selectors:
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
value = self.variable_pool.get(value_selector)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
|
||||
text = value.markdown
|
||||
|
||||
if text:
|
||||
current_node_id = value_selector[0]
|
||||
if self.has_output and current_node_id not in self.output_node_ids:
|
||||
text = "\n" + text
|
||||
|
||||
self.output_node_ids.add(current_node_id)
|
||||
self.has_output = True
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=text,
|
||||
from_variable_selector=value_selector,
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[end_node_id] += 1
|
||||
|
||||
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.from_variable_selector:
|
||||
return []
|
||||
|
||||
stream_output_value_selector = event.from_variable_selector
|
||||
if not stream_output_value_selector:
|
||||
return []
|
||||
|
||||
stream_out_end_node_ids = []
|
||||
for end_node_id, route_position in self.route_position.items():
|
||||
if end_node_id not in self.rest_node_ids:
|
||||
continue
|
||||
|
||||
# all depends on end node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
|
||||
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
|
||||
continue
|
||||
|
||||
position = 0
|
||||
value_selector = None
|
||||
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
|
||||
if position == route_position:
|
||||
value_selector = current_value_selectors
|
||||
break
|
||||
|
||||
position += 1
|
||||
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
continue
|
||||
|
||||
stream_out_end_node_ids.append(end_node_id)
|
||||
|
||||
return stream_out_end_node_ids
|
||||
@ -1,7 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TEMPLATE_TRANSFORM = "template-transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
LOOP_START = "loop-start"
|
||||
LOOP_END = "loop-end"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
TRIGGER_WEBHOOK = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE = "trigger-schedule"
|
||||
TRIGGER_PLUGIN = "trigger-plugin"
|
||||
|
||||
@property
|
||||
def is_start_node(self) -> bool:
|
||||
return self in [
|
||||
NodeType.START,
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
]
|
||||
|
||||
|
||||
class ErrorStrategy(StrEnum):
|
||||
FAIL_BRANCH = "fail-branch"
|
||||
DEFAULT_VALUE = "default-value"
|
||||
|
||||
|
||||
class FailBranchSourceHandle(StrEnum):
|
||||
FAILED = "fail-branch"
|
||||
SUCCESS = "success-branch"
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
from .event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
RunCompletedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
RunStreamChunkEvent,
|
||||
)
|
||||
from .types import NodeEvent
|
||||
|
||||
__all__ = [
|
||||
"ModelInvokeCompletedEvent",
|
||||
"NodeEvent",
|
||||
"RunCompletedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunRetryEvent",
|
||||
"RunStreamChunkEvent",
|
||||
]
|
||||
@ -1,40 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
||||
|
||||
class RunCompletedEvent(BaseModel):
|
||||
run_result: NodeRunResult = Field(..., description="run result")
|
||||
|
||||
|
||||
class RunStreamChunkEvent(BaseModel):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: list[str] = Field(..., description="from variable selector")
|
||||
|
||||
|
||||
class RunRetrieverResourceEvent(BaseModel):
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class ModelInvokeCompletedEvent(BaseModel):
|
||||
"""
|
||||
Model invoke completed
|
||||
"""
|
||||
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class RunRetryEvent(BaseModel):
|
||||
"""Node Run Retry event"""
|
||||
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="Retry attempt number")
|
||||
start_at: datetime = Field(..., description="Retry start time")
|
||||
@ -1,3 +0,0 @@
|
||||
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
|
||||
NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent
|
||||
@ -1,7 +1,7 @@
|
||||
import mimetypes
|
||||
from collections.abc import Sequence
|
||||
from email.message import Message
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel):
|
||||
|
||||
class HttpRequestNodeAuthorization(BaseModel):
|
||||
type: Literal["no-auth", "api-key"]
|
||||
config: Optional[HttpRequestNodeAuthorizationConfig] = None
|
||||
config: HttpRequestNodeAuthorizationConfig | None = None
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
authorization: HttpRequestNodeAuthorization
|
||||
headers: str
|
||||
params: str
|
||||
body: Optional[HttpRequestNodeBody] = None
|
||||
timeout: Optional[HttpRequestNodeTimeout] = None
|
||||
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
body: HttpRequestNodeBody | None = None
|
||||
timeout: HttpRequestNodeTimeout | None = None
|
||||
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
|
||||
class Response:
|
||||
@ -183,7 +183,7 @@ class Response:
|
||||
return f"{(self.size / 1024 / 1024):.2f} MB"
|
||||
|
||||
@property
|
||||
def parsed_content_disposition(self) -> Optional[Message]:
|
||||
def parsed_content_disposition(self) -> Message | None:
|
||||
content_disposition = self.headers.get("content-disposition", "")
|
||||
if content_disposition:
|
||||
msg = Message()
|
||||
|
||||
@ -15,7 +15,7 @@ from core.file import file_manager
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import VariablePool
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
@ -263,9 +263,6 @@ class Executor:
|
||||
if authorization.config is None:
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
|
||||
if self.auth.config.api_key is None:
|
||||
raise AuthorizationConfigError("api_key is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
||||
@ -409,30 +406,25 @@ class Executor:
|
||||
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
|
||||
for file_entry in self.files:
|
||||
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
|
||||
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
|
||||
if len(file_entry) != 2 or len(file_entry[1]) < 2:
|
||||
continue # skip malformed entries
|
||||
key = file_entry[0]
|
||||
content = file_entry[1][1]
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
# decode content safely
|
||||
if isinstance(content, bytes):
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
body_string += content.decode("utf-8", errors="replace")
|
||||
elif isinstance(content, str):
|
||||
body_string += content
|
||||
else:
|
||||
body_string += f"[Unsupported content type: {type(content).__name__}]"
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
body_string += content.decode("utf-8", errors="replace")
|
||||
body_string += "\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
if self.content:
|
||||
if isinstance(self.content, str):
|
||||
body_string = self.content
|
||||
elif isinstance(self.content, bytes):
|
||||
if isinstance(self.content, bytes):
|
||||
body_string = self.content.decode("utf-8", errors="replace")
|
||||
else:
|
||||
body_string = self.content
|
||||
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
|
||||
body_string = urlencode(self.data)
|
||||
elif self.data and self.node_data.body.type == "form-data":
|
||||
|
||||
@ -1,20 +1,18 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from factories import file_factory
|
||||
|
||||
from .entities import (
|
||||
@ -33,15 +31,15 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode):
|
||||
_node_type = NodeType.HTTP_REQUEST
|
||||
class HttpRequestNode(Node):
|
||||
node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
_node_data: HttpRequestNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = HttpRequestNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -50,7 +48,7 @@ class HttpRequestNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -60,7 +58,7 @@ class HttpRequestNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
@ -101,7 +99,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
response = http_executor.invoke()
|
||||
files = self.extract_files(url=http_executor.url, response=response)
|
||||
if not response.response.is_success and (self.continue_on_error or self.retry):
|
||||
if not response.response.is_success and (self.error_strategy or self.retry):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs={
|
||||
@ -129,7 +127,7 @@ class HttpRequestNode(BaseNode):
|
||||
},
|
||||
)
|
||||
except HttpRequestNodeError as e:
|
||||
logger.warning("http request node %s failed to run: %s", self.node_id, e)
|
||||
logger.warning("http request node %s failed to run: %s", self._node_id, e)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
@ -244,10 +242,6 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
return ArrayFileSegment(value=files)
|
||||
|
||||
@property
|
||||
def continue_on_error(self) -> bool:
|
||||
return self._node_data.error_strategy is not None
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData):
|
||||
logical_operator: Literal["and", "or"]
|
||||
conditions: list[Condition]
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
cases: Optional[list[Case]] = None
|
||||
cases: list[Case] | None = None
|
||||
|
||||
@ -1,28 +1,28 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
_node_type = NodeType.IF_ELSE
|
||||
class IfElseNode(Node):
|
||||
node_type = NodeType.IF_ELSE
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_node_data: IfElseNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IfElseNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -31,7 +31,7 @@ class IfElseNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -49,13 +49,13 @@ class IfElseNode(BaseNode):
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_inputs: dict[str, list] = {"conditions": []}
|
||||
node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []}
|
||||
|
||||
process_data: dict[str, list] = {"condition_results": []}
|
||||
|
||||
input_conditions = []
|
||||
input_conditions: Sequence[Mapping[str, Any]] = []
|
||||
final_result = False
|
||||
selected_case_id = None
|
||||
selected_case_id = "false"
|
||||
condition_processor = ConditionProcessor()
|
||||
try:
|
||||
# Check if the new cases structure is used
|
||||
@ -83,7 +83,7 @@ class IfElseNode(BaseNode):
|
||||
else:
|
||||
# TODO: Update database then remove this
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function(
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated]
|
||||
condition_processor=condition_processor,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=self._node_data.conditions or [],
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
is_parallel: bool = False # open the parallel mode or not
|
||||
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
"""
|
||||
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
|
||||
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@ -1,48 +1,40 @@
|
||||
import contextvars
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, wait
|
||||
from datetime import datetime
|
||||
from queue import Empty, Queue
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunResult,
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
BaseParallelBranchEvent,
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunFailedEvent,
|
||||
InNodeEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
IterationSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
@ -54,23 +46,26 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||
|
||||
class IterationNode(BaseNode):
|
||||
|
||||
class IterationNode(Node):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
node_type = NodeType.ITERATION
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
_node_data: IterationNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -79,7 +74,7 @@ class IterationNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -89,7 +84,7 @@ class IterationNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
@ -103,225 +98,266 @@ class IterationNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore
|
||||
variable = self._get_iterator_variable()
|
||||
|
||||
if self._is_empty_iteration(variable):
|
||||
yield from self._handle_empty_iteration(variable)
|
||||
return
|
||||
|
||||
iterator_list_value = self._validate_and_get_iterator_list(variable)
|
||||
inputs = {"iterator_selector": iterator_list_value}
|
||||
|
||||
self._validate_start_node()
|
||||
|
||||
started_at = naive_utc_now()
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[object] = []
|
||||
|
||||
yield IterationStartedEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
metadata={"iteration_length": len(iterator_list_value)},
|
||||
)
|
||||
|
||||
try:
|
||||
yield from self._execute_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
|
||||
yield from self._handle_iteration_success(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
yield from self._handle_iteration_failure(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
error=e,
|
||||
)
|
||||
|
||||
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
||||
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||
|
||||
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
|
||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
|
||||
# Try our best to preserve the type informat.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
)
|
||||
)
|
||||
return
|
||||
return variable
|
||||
|
||||
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
|
||||
return isinstance(variable, NoneSegment) or len(variable.value) == 0
|
||||
|
||||
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
|
||||
# Try our best to preserve the type information.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
)
|
||||
)
|
||||
|
||||
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
|
||||
iterator_list_value = variable.to_object()
|
||||
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
|
||||
inputs = {"iterator_selector": iterator_list_value}
|
||||
|
||||
graph_config = self.graph_config
|
||||
return cast(list[object], iterator_list_value)
|
||||
|
||||
def _validate_start_node(self) -> None:
|
||||
if not self._node_data.start_node_id:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
|
||||
root_node_id = self._node_data.start_node_id
|
||||
def _execute_iterations(
|
||||
self,
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
if self._node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
yield from self._execute_parallel_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
else:
|
||||
# Sequential mode execution
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
yield IterationNextEvent(index=index)
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# append iteration variable (item, index) to variable pool
|
||||
variable_pool.add([self.node_id, "index"], 0)
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_type=self.workflow_type,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=iteration_graph,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
|
||||
start_at = naive_utc_now()
|
||||
|
||||
yield IterationRunStartedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={"iterator_length": len(iterator_list_value)},
|
||||
predecessor_node_id=self.previous_node_id,
|
||||
)
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
index=0,
|
||||
pre_iteration_output=None,
|
||||
duration=None,
|
||||
)
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[Any] = [None] * len(iterator_list_value)
|
||||
try:
|
||||
if self._node_data.is_parallel:
|
||||
futures: list[Future] = []
|
||||
q: Queue = Queue()
|
||||
thread_pool = GraphEngineThreadPool(
|
||||
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
|
||||
# Run the iteration
|
||||
yield from self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs,
|
||||
graph_engine=graph_engine,
|
||||
)
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
future: Future = thread_pool.submit(
|
||||
self._run_single_iter_parallel,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
q=q,
|
||||
context=contextvars.copy_context(),
|
||||
iterator_list_value=iterator_list_value,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine,
|
||||
iteration_graph=iteration_graph,
|
||||
index=index,
|
||||
item=item,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
future.add_done_callback(thread_pool.task_done_callback)
|
||||
futures.append(future)
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
if isinstance(event, IterationRunNextEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(futures):
|
||||
q.put(None)
|
||||
yield event
|
||||
if isinstance(event, RunCompletedEvent):
|
||||
q.put(None)
|
||||
for f in futures:
|
||||
if not f.done():
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
def _execute_parallel_iterations(
|
||||
self,
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
# Initialize outputs list with None values to maintain order
|
||||
outputs.extend([None] * len(iterator_list_value))
|
||||
|
||||
# Determine the number of parallel workers
|
||||
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
yield IterationNextEvent(index=index)
|
||||
future = executor.submit(
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
)
|
||||
future_to_index[future] = index
|
||||
|
||||
# Process completed iterations as they finish
|
||||
for future in as_completed(future_to_index):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
|
||||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
# Cancel remaining futures and re-raise
|
||||
for f in future_to_index:
|
||||
if f != future:
|
||||
f.cancel()
|
||||
yield event
|
||||
if isinstance(event, IterationRunFailedEvent):
|
||||
q.put(None)
|
||||
yield event
|
||||
except Empty:
|
||||
continue
|
||||
raise IterationNodeError(str(e))
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
outputs[index] = None
|
||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[index] = None # Will be filtered later
|
||||
|
||||
# wait all threads
|
||||
wait(futures)
|
||||
else:
|
||||
for _ in range(len(iterator_list_value)):
|
||||
yield from self._run_single_iter(
|
||||
iterator_list_value=iterator_list_value,
|
||||
variable_pool=variable_pool,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine,
|
||||
iteration_graph=iteration_graph,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs = [output for output in outputs if output is not None]
|
||||
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[:] = [output for output in outputs if output is not None]
|
||||
|
||||
# Flatten the list of lists
|
||||
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
|
||||
outputs = [item for sublist in outputs for item in sublist]
|
||||
output_segment = build_segment(outputs)
|
||||
def _execute_single_iteration_parallel(
|
||||
self,
|
||||
index: int,
|
||||
item: object,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Collect events instead of yielding them directly
|
||||
for event in self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs_temp,
|
||||
graph_engine=graph_engine,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
started_at: datetime,
|
||||
inputs: dict[str, Sequence[object]],
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
yield IterationSucceededEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
)
|
||||
|
||||
# Yield final success event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": output_segment},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
},
|
||||
)
|
||||
def _handle_iteration_failure(
|
||||
self,
|
||||
started_at: datetime,
|
||||
inputs: dict[str, Sequence[object]],
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
error: IterationNodeError,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
yield IterationFailedEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
error=str(error),
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(error),
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
# iteration run failed
|
||||
logger.warning("Iteration run failed")
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
# remove iteration variable (item, index) from variable pool after iteration run completed
|
||||
variable_pool.remove([self.node_id, "index"])
|
||||
variable_pool.remove([self.node_id, "item"])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@ -337,14 +373,20 @@ class IterationNode(BaseNode):
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
|
||||
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
|
||||
continue
|
||||
|
||||
@ -376,303 +418,115 @@ class IterationNode(BaseNode):
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
|
||||
}
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _handle_event_metadata(
|
||||
def _append_iteration_info_to_event(
|
||||
self,
|
||||
*,
|
||||
event: BaseNodeEvent | InNodeEvent,
|
||||
event: GraphNodeEventBase,
|
||||
iter_run_index: int,
|
||||
parallel_mode_run_id: str | None,
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
|
||||
"""
|
||||
if not isinstance(event, BaseNodeEvent):
|
||||
return event
|
||||
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
|
||||
event.parallel_mode_run_id = parallel_mode_run_id
|
||||
|
||||
):
|
||||
event.in_iteration_id = self._node_id
|
||||
iter_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
}
|
||||
if parallel_mode_run_id:
|
||||
# for parallel, the specific branch ID is more important than the sequential index
|
||||
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
|
||||
if event.route_node_state.node_run_result:
|
||||
current_metadata = event.route_node_state.node_run_result.metadata or {}
|
||||
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
|
||||
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
||||
|
||||
return event
|
||||
current_metadata = event.node_run_result.metadata
|
||||
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
|
||||
event.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
||||
|
||||
def _run_single_iter(
|
||||
self,
|
||||
*,
|
||||
iterator_list_value: Sequence[str],
|
||||
variable_pool: VariablePool,
|
||||
inputs: Mapping[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
outputs: list[object],
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
iter_run_map: dict[str, float],
|
||||
parallel_mode_run_id: Optional[str] = None,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration
|
||||
"""
|
||||
iter_start_at = naive_utc_now()
|
||||
) -> Generator[GraphNodeEventBase, None, None]:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
index_variable = variable_pool.get([self._node_id, "index"])
|
||||
if not isinstance(index_variable, IntegerVariable):
|
||||
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
|
||||
current_index = index_variable.value
|
||||
for event in rst:
|
||||
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
|
||||
continue
|
||||
|
||||
try:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(index_variable, IntegerVariable):
|
||||
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
|
||||
current_index = index_variable.value
|
||||
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
|
||||
next_index = int(current_index) + 1
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.ITERATION_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(
|
||||
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
||||
)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
if self._node_data.is_parallel:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
else:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
||||
yield event
|
||||
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
||||
result = variable_pool.get(self._node_data.output_selector)
|
||||
if result is None:
|
||||
outputs.append(None)
|
||||
else:
|
||||
outputs.append(result.to_object())
|
||||
return
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
match self._node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
raise IterationNodeError(event.error)
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
outputs.append(None)
|
||||
return
|
||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
return
|
||||
elif isinstance(event, InNodeEvent):
|
||||
# event = cast(InNodeEvent, event)
|
||||
metadata_event = self._handle_event_metadata(
|
||||
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
||||
)
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
outputs[current_index] = None
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (naive_utc_now() - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
duration=duration,
|
||||
)
|
||||
return
|
||||
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (naive_utc_now() - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
duration=duration,
|
||||
)
|
||||
return
|
||||
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
outputs[current_index] = None
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
||||
# clean nodes resources
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
invoke_from=self.invoke_from.value,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
# Create a deep copy of the variable pool for each iteration
|
||||
variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
|
||||
|
||||
# iteration run failed
|
||||
if self._node_data.is_parallel:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
else:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
# append iteration variable (item, index) to variable pool
|
||||
variable_pool_copy.add([self._node_id, "index"], index)
|
||||
variable_pool_copy.add([self._node_id, "item"], item)
|
||||
|
||||
# stop the iterator
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
return
|
||||
yield metadata_event
|
||||
# Create a new GraphRuntimeState for this iteration
|
||||
graph_runtime_state_copy = GraphRuntimeState(
|
||||
variable_pool=variable_pool_copy,
|
||||
start_at=self.graph_runtime_state.start_at,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
|
||||
current_output_segment = variable_pool.get(self._node_data.output_selector)
|
||||
if current_output_segment is None:
|
||||
raise IterationNodeError("iteration output selector not found")
|
||||
current_iteration_output = current_output_segment.value
|
||||
outputs[current_index] = current_iteration_output
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# move to next iteration
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
# Initialize the iteration graph with the new node factory
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
|
||||
)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (naive_utc_now() - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=current_iteration_output or None,
|
||||
duration=duration,
|
||||
)
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
except IterationNodeError as e:
|
||||
logger.warning("Iteration run failed:%s", str(e))
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.type_,
|
||||
iteration_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": None},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=str(e),
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=iteration_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
def _run_single_iter_parallel(
|
||||
self,
|
||||
*,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
q: Queue,
|
||||
iterator_list_value: Sequence[str],
|
||||
inputs: Mapping[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
index: int,
|
||||
item: Any,
|
||||
iter_run_map: dict[str, float],
|
||||
):
|
||||
"""
|
||||
run single iteration in parallel mode
|
||||
"""
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
parallel_mode_run_id = uuid.uuid4().hex
|
||||
graph_engine_copy = graph_engine.create_copy()
|
||||
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
|
||||
variable_pool_copy.add([self.node_id, "index"], index)
|
||||
variable_pool_copy.add([self.node_id, "item"], item)
|
||||
for event in self._run_single_iter(
|
||||
iterator_list_value=iterator_list_value,
|
||||
variable_pool=variable_pool_copy,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine_copy,
|
||||
iteration_graph=iteration_graph,
|
||||
iter_run_map=iter_run_map,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
):
|
||||
q.put(event)
|
||||
graph_engine.graph_runtime_state.total_tokens += graph_engine_copy.graph_runtime_state.total_tokens
|
||||
return graph_engine
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import IterationStartNodeData
|
||||
|
||||
|
||||
class IterationStartNode(BaseNode):
|
||||
class IterationStartNode(Node):
|
||||
"""
|
||||
Iteration Start Node.
|
||||
"""
|
||||
|
||||
_node_type = NodeType.ITERATION_START
|
||||
node_type = NodeType.ITERATION_START
|
||||
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +29,7 @@ class IterationStartNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
3
api/core/workflow/nodes/knowledge_index/__init__.py
Normal file
3
api/core/workflow/nodes/knowledge_index/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .knowledge_index_node import KnowledgeIndexNode
|
||||
|
||||
__all__ = ["KnowledgeIndexNode"]
|
||||
159
api/core/workflow/nodes/knowledge_index/entities.py
Normal file
159
api/core/workflow/nodes/knowledge_index/entities.py
Normal file
@ -0,0 +1,159 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class EmbeddingSetting(BaseModel):
|
||||
"""
|
||||
Embedding Setting.
|
||||
"""
|
||||
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class EconomySetting(BaseModel):
|
||||
"""
|
||||
Economy Setting.
|
||||
"""
|
||||
|
||||
keyword_number: int
|
||||
|
||||
|
||||
class RetrievalSetting(BaseModel):
|
||||
"""
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: RerankingModelConfig | None = None
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class IndexMethod(BaseModel):
|
||||
"""
|
||||
Knowledge Index Setting.
|
||||
"""
|
||||
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_setting: EmbeddingSetting
|
||||
economy_setting: EconomySetting
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
"""
|
||||
File Info.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
|
||||
|
||||
class OnlineDocumentIcon(BaseModel):
|
||||
"""
|
||||
Document Icon.
|
||||
"""
|
||||
|
||||
icon_url: str
|
||||
icon_type: str
|
||||
icon_emoji: str
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
"""
|
||||
Online document info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
workspace_id: str | None = None
|
||||
page_id: str
|
||||
page_type: str
|
||||
icon: OnlineDocumentIcon | None = None
|
||||
|
||||
|
||||
class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
website import info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
url: str
|
||||
|
||||
|
||||
class GeneralStructureChunk(BaseModel):
|
||||
"""
|
||||
General Structure Chunk.
|
||||
"""
|
||||
|
||||
general_chunks: list[str]
|
||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||
|
||||
|
||||
class ParentChildChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Chunk.
|
||||
"""
|
||||
|
||||
parent_content: str
|
||||
child_contents: list[str]
|
||||
|
||||
|
||||
class ParentChildStructureChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Structure Chunk.
|
||||
"""
|
||||
|
||||
parent_child_chunks: list[ParentChildChunk]
|
||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||
|
||||
|
||||
class KnowledgeIndexNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-index"
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
22
api/core/workflow/nodes/knowledge_index/exc.py
Normal file
22
api/core/workflow/nodes/knowledge_index/exc.py
Normal file
@ -0,0 +1,22 @@
|
||||
class KnowledgeIndexNodeError(ValueError):
|
||||
"""Base class for KnowledgeIndexNode errors."""
|
||||
|
||||
|
||||
class ModelNotExistError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model does not exist."""
|
||||
|
||||
|
||||
class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model credentials are not initialized."""
|
||||
|
||||
|
||||
class ModelNotSupportedError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model is not supported."""
|
||||
|
||||
|
||||
class ModelQuotaExceededError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model provider quota is exceeded."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
209
api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
Normal file
209
api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
Normal file
@ -0,0 +1,209 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
KnowledgeIndexNodeError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeIndexNode(Node):
|
||||
_node_data: KnowledgeIndexNodeData
|
||||
node_type = NodeType.KNOWLEDGE_INDEX
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = KnowledgeIndexNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = cast(KnowledgeIndexNodeData, self._node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||
if not dataset_id:
|
||||
raise KnowledgeIndexNodeError("Dataset ID is required.")
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
|
||||
|
||||
# extract variables
|
||||
variable = variable_pool.get(node_data.index_chunk_variable_selector)
|
||||
if not variable:
|
||||
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
if invoke_from:
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
|
||||
else:
|
||||
is_preview = False
|
||||
chunks = variable.value
|
||||
variables = {"chunks": chunks}
|
||||
if not chunks:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
|
||||
)
|
||||
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs=outputs,
|
||||
)
|
||||
results = self._invoke_knowledge_index(
|
||||
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
|
||||
|
||||
except KnowledgeIndexNodeError as e:
|
||||
logger.warning("Error when running knowledge index node")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
def _invoke_knowledge_index(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
node_data: KnowledgeIndexNodeData,
|
||||
chunks: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
) -> Any:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if not document_id:
|
||||
raise KnowledgeIndexNodeError("Document ID is required.")
|
||||
original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
|
||||
|
||||
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
||||
if not batch:
|
||||
raise KnowledgeIndexNodeError("Batch is required.")
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||
if original_document_id:
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
index_processor.index(dataset, document, chunks)
|
||||
indexing_end_at = time.perf_counter()
|
||||
document.indexing_latency = indexing_end_at - indexing_start_at
|
||||
# update document status
|
||||
document.indexing_status = "completed"
|
||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.word_count = (
|
||||
db.session.query(func.sum(DocumentSegment.word_count))
|
||||
.where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"batch": batch.value,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"display_status": document.indexing_status,
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
return index_processor.format_preview(chunks)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def get_streaming_template(self) -> Template:
|
||||
"""
|
||||
Get the template for streaming.
|
||||
|
||||
Returns:
|
||||
Template instance for this knowledge index node
|
||||
"""
|
||||
return Template(segments=[])
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: Optional[RerankingModelConfig] = None
|
||||
weights: Optional[WeightedScoreConfig] = None
|
||||
reranking_model: RerankingModelConfig | None = None
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class SingleRetrievalConfig(BaseModel):
|
||||
@ -91,7 +91,7 @@ SupportedComparisonOperator = Literal[
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Conditon detail
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel):
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal["single", "multiple"]
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
|
||||
metadata_model_config: Optional[ModelConfig] = None
|
||||
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
|
||||
multiple_retrieval_config: MultipleRetrievalConfig | None = None
|
||||
single_retrieval_config: SingleRetrievalConfig | None = None
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
|
||||
@ -4,9 +4,9 @@ import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, text
|
||||
from sqlalchemy import Float, and_, func, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -32,14 +32,11 @@ from core.variables import (
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.segments import ArrayObjectSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||
@ -70,7 +67,7 @@ from .exc import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -83,8 +80,8 @@ default_retrieval_model = {
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
class KnowledgeRetrievalNode(Node):
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
_node_data: KnowledgeRetrievalNodeData
|
||||
|
||||
@ -99,21 +96,15 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
@ -125,10 +116,10 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -137,7 +128,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -197,7 +188,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data=None,
|
||||
process_data={},
|
||||
outputs=outputs, # type: ignore
|
||||
)
|
||||
|
||||
@ -259,7 +250,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
@ -291,7 +282,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
@ -367,15 +358,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.first()
|
||||
stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
document = db.session.scalar(stmt)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"metadata": {
|
||||
@ -422,14 +410,14 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
||||
document_query = db.session.query(Document).where(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
filters = [] # type: ignore
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None, None
|
||||
@ -443,7 +431,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters, # type: ignore
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
@ -514,7 +502,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> list[dict[str, Any]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
if node_data.metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
@ -552,7 +541,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self.node_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
@ -573,15 +563,15 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return []
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
|
||||
):
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return
|
||||
return filters
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
@ -666,6 +656,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Optional, TypeAlias, TypeVar
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
@ -36,15 +35,15 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||
return wrapper
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
class ListOperatorNode(Node):
|
||||
node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
_node_data: ListOperatorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ListOperatorNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -53,7 +52,7 @@ class ListOperatorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -67,8 +66,8 @@ class ListOperatorNode(BaseNode):
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
inputs: dict[str, Sequence[object]] = {}
|
||||
process_data: dict[str, Sequence[object]] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
||||
@ -171,27 +170,23 @@ class ListOperatorNode(BaseNode):
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayBooleanSegment):
|
||||
else:
|
||||
if not isinstance(condition.value, bool):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}")
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statment should be unreachable.")
|
||||
return variable
|
||||
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
else:
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable")
|
||||
|
||||
return variable
|
||||
|
||||
@ -305,7 +300,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
if key in {"type", "transfer_method"} and isinstance(value, Sequence):
|
||||
if key in {"type", "transfer_method"}:
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
elif key == "size" and isinstance(value, str):
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
@ -18,7 +18,7 @@ class ModelConfig(BaseModel):
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
enabled: bool
|
||||
variable_selector: Optional[list[str]] = None
|
||||
variable_selector: list[str] | None = None
|
||||
|
||||
|
||||
class VisionConfigOptions(BaseModel):
|
||||
@ -51,23 +51,40 @@ class PromptConfig(BaseModel):
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
text: str = ""
|
||||
jinja2_text: Optional[str] = None
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: Optional[str] = None
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
memory: Optional[MemoryConfig] = None
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||
# Keep tagged as default for backward compatibility
|
||||
default="tagged",
|
||||
description=(
|
||||
"""
|
||||
Strategy for handling model reasoning output.
|
||||
|
||||
separated: Return clean text (without <think> tags) + reasoning_content field.
|
||||
Recommended for new workflows. Enables safe downstream parsing and
|
||||
workflow variable access: {{#node_id.reasoning_content#}}
|
||||
|
||||
tagged : Return original text (with <think> tags) + reasoning_content field.
|
||||
Maintains full backward compatibility while still providing reasoning_content
|
||||
for workflow automation. Frontend thinking panels work as before.
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
|
||||
@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):
|
||||
|
||||
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str) -> None:
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
|
||||
@ -8,7 +8,7 @@ from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
@ -13,16 +13,16 @@ from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
|
||||
|
||||
@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
@ -107,7 +107,7 @@ def fetch_memory(
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
|
||||
@ -2,8 +2,9 @@ import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
@ -50,22 +51,25 @@ from core.variables import (
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
NodeEvent,
|
||||
RunCompletedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunStreamChunkEvent,
|
||||
from core.workflow.entities import GraphInitParams, VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.node_events import (
|
||||
ModelInvokeCompletedEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from . import llm_utils
|
||||
from .entities import (
|
||||
@ -88,17 +92,19 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
_node_type = NodeType.LLM
|
||||
class LLMNode(Node):
|
||||
node_type = NodeType.LLM
|
||||
|
||||
_node_data: LLMNodeData
|
||||
|
||||
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
|
||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
@ -110,21 +116,15 @@ class LLMNode(BaseNode):
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
@ -136,10 +136,10 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LLMNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -148,7 +148,7 @@ class LLMNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -161,12 +161,13 @@ class LLMNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
def _run(self) -> Generator:
|
||||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
reasoning_content = None
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
@ -182,8 +183,6 @@ class LLMNode(BaseNode):
|
||||
# merge inputs
|
||||
inputs.update(jinja_inputs)
|
||||
|
||||
node_inputs = {}
|
||||
|
||||
# fetch files
|
||||
files = (
|
||||
llm_utils.fetch_files(
|
||||
@ -201,9 +200,8 @@ class LLMNode(BaseNode):
|
||||
generator = self._fetch_context(node_data=self._node_data)
|
||||
context = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunRetrieverResourceEvent):
|
||||
context = event.context
|
||||
yield event
|
||||
context = event.context
|
||||
yield event
|
||||
if context:
|
||||
node_inputs["#context#"] = context
|
||||
|
||||
@ -221,7 +219,7 @@ class LLMNode(BaseNode):
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
query = None
|
||||
query: str | None = None
|
||||
if self._node_data.memory:
|
||||
query = self._node_data.memory.query_prompt_template
|
||||
if not query and (
|
||||
@ -255,18 +253,31 @@ class LLMNode(BaseNode):
|
||||
structured_output=self._node_data.structured_output,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self.node_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
reasoning_format=self._node_data.reasoning_format,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, RunStreamChunkEvent):
|
||||
if isinstance(event, StreamChunkEvent):
|
||||
yield event
|
||||
elif isinstance(event, ModelInvokeCompletedEvent):
|
||||
# Raw text
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
reasoning_content = event.reasoning_content or ""
|
||||
|
||||
# For downstream nodes, determine clean text based on reasoning_format
|
||||
if self._node_data.reasoning_format == "tagged":
|
||||
# Keep <think> tags for backward compatibility
|
||||
clean_text = result_text
|
||||
else:
|
||||
# Extract clean text from <think> tags
|
||||
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
@ -284,14 +295,26 @@ class LLMNode(BaseNode):
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
outputs = {
|
||||
"text": clean_text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs is not None:
|
||||
if self._file_outputs:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
# Send final chunk event to indicate streaming is complete
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
@ -305,8 +328,8 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
@ -316,8 +339,8 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("error while executing llm node")
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
@ -331,14 +354,16 @@ class LLMNode(BaseNode):
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Optional[Mapping[str, Any]] = None,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
node_id: str,
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
node_data_model.name, model_instance.credentials
|
||||
)
|
||||
@ -374,6 +399,8 @@ class LLMNode(BaseNode):
|
||||
file_saver=file_saver,
|
||||
file_outputs=file_outputs,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
reasoning_format=reasoning_format,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -383,13 +410,16 @@ class LLMNode(BaseNode):
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
node_id: str,
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
||||
# For blocking mode
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
event = LLMNode.handle_blocking_result(
|
||||
invoke_result=invoke_result,
|
||||
saver=file_saver,
|
||||
file_outputs=file_outputs,
|
||||
reasoning_format=reasoning_format,
|
||||
)
|
||||
yield event
|
||||
return
|
||||
@ -414,7 +444,11 @@ class LLMNode(BaseNode):
|
||||
file_outputs=file_outputs,
|
||||
):
|
||||
full_text_buffer.write(text_part)
|
||||
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=text_part,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
@ -430,13 +464,66 @@ class LLMNode(BaseNode):
|
||||
except OutputParserError as e:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {e}")
|
||||
|
||||
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
|
||||
# Extract reasoning content from <think> tags in the main text
|
||||
full_text = full_text_buffer.getvalue()
|
||||
|
||||
if reasoning_format == "tagged":
|
||||
# Keep <think> tags in text for backward compatibility
|
||||
clean_text = full_text
|
||||
reasoning_content = ""
|
||||
else:
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
|
||||
yield ModelInvokeCompletedEvent(
|
||||
# Use clean_text for separated mode, full_text for tagged mode
|
||||
text=clean_text if reasoning_format == "separated" else full_text,
|
||||
usage=usage,
|
||||
finish_reason=finish_reason,
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _image_file_to_markdown(file: "File", /):
|
||||
text_chunk = f"})"
|
||||
return text_chunk
|
||||
|
||||
@classmethod
|
||||
def _split_reasoning(
|
||||
cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Split reasoning content from text based on reasoning_format strategy.
|
||||
|
||||
Args:
|
||||
text: Full text that may contain <think> blocks
|
||||
reasoning_format: Strategy for handling reasoning content
|
||||
- "separated": Remove <think> tags and return clean text + reasoning_content field
|
||||
- "tagged": Keep <think> tags in text, return empty reasoning_content
|
||||
|
||||
Returns:
|
||||
tuple of (clean_text, reasoning_content)
|
||||
"""
|
||||
|
||||
if reasoning_format == "tagged":
|
||||
return text, ""
|
||||
|
||||
# Find all <think>...</think> blocks (case-insensitive)
|
||||
matches = cls._THINK_PATTERN.findall(text)
|
||||
|
||||
# Extract reasoning content from all <think> blocks
|
||||
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
|
||||
|
||||
# Remove all <think>...</think> blocks from original text
|
||||
clean_text = cls._THINK_PATTERN.sub("", text)
|
||||
|
||||
# Clean up extra whitespace
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
|
||||
# Separated mode: always return clean text and reasoning_content
|
||||
return clean_text, reasoning_content or ""
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
@ -629,7 +716,7 @@ class LLMNode(BaseNode):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
tenant_id: str,
|
||||
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
@ -811,14 +898,14 @@ class LLMNode(BaseNode):
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LLMNodeData.model_validate(node_data)
|
||||
|
||||
prompt_template = typed_node_data.prompt_template
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list) and all(
|
||||
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
|
||||
):
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
@ -872,7 +959,7 @@ class LLMNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "llm",
|
||||
"config": {
|
||||
@ -900,7 +987,7 @@ class LLMNode(BaseNode):
|
||||
def handle_list_messages(
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
@ -964,6 +1051,7 @@ class LLMNode(BaseNode):
|
||||
invoke_result: LLMResult,
|
||||
saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
) -> ModelInvokeCompletedEvent:
|
||||
buffer = io.StringIO()
|
||||
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
|
||||
@ -973,10 +1061,24 @@ class LLMNode(BaseNode):
|
||||
):
|
||||
buffer.write(text_part)
|
||||
|
||||
# Extract reasoning content from <think> tags in the main text
|
||||
full_text = buffer.getvalue()
|
||||
|
||||
if reasoning_format == "tagged":
|
||||
# Keep <think> tags in text for backward compatibility
|
||||
clean_text = full_text
|
||||
reasoning_content = ""
|
||||
else:
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
|
||||
return ModelInvokeCompletedEvent(
|
||||
text=buffer.getvalue(),
|
||||
# Use clean_text for separated mode, full_text for tagged mode
|
||||
text=clean_text if reasoning_format == "separated" else full_text,
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -1052,7 +1154,7 @@ class LLMNode(BaseNode):
|
||||
return
|
||||
if isinstance(contents, str):
|
||||
yield contents
|
||||
elif isinstance(contents, list):
|
||||
else:
|
||||
for item in contents:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
yield item.data
|
||||
@ -1066,13 +1168,6 @@ class LLMNode(BaseNode):
|
||||
else:
|
||||
logger.warning("unknown item type encountered, type=%s", type(item))
|
||||
yield str(item)
|
||||
else:
|
||||
logger.warning("unknown contents type encountered, type=%s", type(contents))
|
||||
yield str(contents)
|
||||
|
||||
@property
|
||||
def continue_on_error(self) -> bool:
|
||||
return self._node_data.error_strategy is not None
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
@ -1080,7 +1175,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
):
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
@ -1089,7 +1184,8 @@ def _combine_message_content_with_role(
|
||||
return AssistantPromptMessage(content=contents)
|
||||
case PromptMessageRole.SYSTEM:
|
||||
return SystemPromptMessage(content=contents)
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
case _:
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
|
||||
|
||||
def _render_jinja2_message(
|
||||
@ -1185,7 +1281,7 @@ def _handle_memory_completion_mode(
|
||||
def _handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
) -> Sequence[PromptMessage]:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
@ -35,19 +34,22 @@ class LoopVariableData(BaseModel):
|
||||
label: str
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Optional[Any | list[str]] = None
|
||||
value: Any | list[str] | None = None
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
"""
|
||||
Loop Node Data.
|
||||
"""
|
||||
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
@classmethod
|
||||
def validate_outputs(cls, v):
|
||||
if v is None:
|
||||
return {}
|
||||
return v
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
@ -72,7 +74,7 @@ class LoopState(BaseLoopState):
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
@ -81,7 +83,7 @@ class LoopState(BaseLoopState):
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@ -89,7 +91,7 @@ class LoopState(BaseLoopState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
||||
|
||||
|
||||
class LoopEndNode(BaseNode):
|
||||
class LoopEndNode(Node):
|
||||
"""
|
||||
Loop End Node.
|
||||
"""
|
||||
|
||||
_node_type = NodeType.LOOP_END
|
||||
node_type = NodeType.LOOP_END
|
||||
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopEndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +29,7 @@ class LoopEndNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -1,63 +1,58 @@
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import (
|
||||
IntegerSegment,
|
||||
Segment,
|
||||
SegmentType,
|
||||
from core.variables import Segment, SegmentType
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
BaseParallelBranchEvent,
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunFailedEvent,
|
||||
InNodeEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.node_events import (
|
||||
LoopFailedEvent,
|
||||
LoopNextEvent,
|
||||
LoopStartedEvent,
|
||||
LoopSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopNode(BaseNode):
|
||||
class LoopNode(Node):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
node_type = NodeType.LOOP
|
||||
_node_data: LoopNodeData
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -66,7 +61,7 @@ class LoopNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -79,7 +74,7 @@ class LoopNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> Generator:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
loop_count = self._node_data.loop_count
|
||||
@ -89,144 +84,130 @@ class LoopNode(BaseNode):
|
||||
inputs = {"loop_count": loop_count}
|
||||
|
||||
if not self._node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
|
||||
# Initialize graph
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
root_node_id = self._node_data.start_node_id
|
||||
|
||||
# Initialize variable pool
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
variable_pool.add([self.node_id, "index"], 0)
|
||||
|
||||
# Initialize loop variables
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors = {}
|
||||
if self._node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
|
||||
if isinstance(var.value, list)
|
||||
else None,
|
||||
}
|
||||
for loop_variable in self._node_data.loop_variables:
|
||||
value_processor = {
|
||||
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var=loop_variable: variable_pool.get(var.value),
|
||||
}
|
||||
|
||||
if loop_variable.value_type not in value_processor:
|
||||
raise ValueError(
|
||||
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
||||
)
|
||||
|
||||
processed_segment = value_processor[loop_variable.value_type]()
|
||||
processed_segment = value_processor[loop_variable.value_type](loop_variable)
|
||||
if not processed_segment:
|
||||
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
||||
variable_selector = [self.node_id, loop_variable.label]
|
||||
variable_pool.add(variable_selector, processed_segment.value)
|
||||
variable_selector = [self._node_id, loop_variable.label]
|
||||
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
|
||||
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
|
||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||
inputs[loop_variable.label] = processed_segment.value
|
||||
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_type=self.workflow_type,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=loop_graph,
|
||||
graph_config=self.graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
|
||||
start_at = naive_utc_now()
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
|
||||
# Start Loop event
|
||||
yield LoopRunStartedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
yield LoopStartedEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={"loop_length": loop_count},
|
||||
predecessor_node_id=self.previous_node_id,
|
||||
)
|
||||
|
||||
# yield LoopRunNextEvent(
|
||||
# loop_id=self.id,
|
||||
# loop_node_id=self.node_id,
|
||||
# loop_node_type=self.node_type,
|
||||
# loop_node_data=self.node_data,
|
||||
# index=0,
|
||||
# pre_loop_output=None,
|
||||
# )
|
||||
loop_duration_map = {}
|
||||
single_loop_variable_map = {} # single loop variable output
|
||||
try:
|
||||
check_break_result = False
|
||||
for i in range(loop_count):
|
||||
loop_start_time = naive_utc_now()
|
||||
# run single loop
|
||||
loop_result = yield from self._run_single_loop(
|
||||
graph_engine=graph_engine,
|
||||
loop_graph=loop_graph,
|
||||
variable_pool=variable_pool,
|
||||
loop_variable_selectors=loop_variable_selectors,
|
||||
break_conditions=break_conditions,
|
||||
logical_operator=logical_operator,
|
||||
condition_processor=condition_processor,
|
||||
current_index=i,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
)
|
||||
loop_end_time = naive_utc_now()
|
||||
reach_break_condition = False
|
||||
if break_conditions:
|
||||
with contextlib.suppress(ValueError):
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
|
||||
if reach_break_condition:
|
||||
loop_count = 0
|
||||
cost_tokens = 0
|
||||
|
||||
for i in range(loop_count):
|
||||
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
||||
|
||||
loop_start_time = naive_utc_now()
|
||||
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
|
||||
# Track loop duration
|
||||
loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
|
||||
|
||||
# Accumulate outputs from the sub-graph's response nodes
|
||||
for key, value in graph_engine.graph_runtime_state.outputs.items():
|
||||
if key == "answer":
|
||||
# Concatenate answer outputs with newline
|
||||
existing_answer = self.graph_runtime_state.get_output("answer", "")
|
||||
if existing_answer:
|
||||
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
|
||||
else:
|
||||
self.graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
# For other outputs, just update
|
||||
self.graph_runtime_state.set_output(key, value)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
item = variable_pool.get(selector)
|
||||
if item:
|
||||
single_loop_variable[key] = item.value
|
||||
else:
|
||||
single_loop_variable[key] = None
|
||||
segment = self.graph_runtime_state.variable_pool.get(selector)
|
||||
single_loop_variable[key] = segment.value if segment else None
|
||||
|
||||
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
|
||||
single_loop_variable_map[str(i)] = single_loop_variable
|
||||
|
||||
check_break_result = loop_result.get("check_break_result", False)
|
||||
|
||||
if check_break_result:
|
||||
if reach_break_node:
|
||||
break
|
||||
|
||||
if break_conditions:
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if reach_break_condition:
|
||||
break
|
||||
|
||||
yield LoopNextEvent(
|
||||
index=i + 1,
|
||||
pre_loop_output=self._node_data.outputs,
|
||||
)
|
||||
|
||||
self.graph_runtime_state.total_tokens += cost_tokens
|
||||
# Loop completed successfully
|
||||
yield LoopRunSucceededEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
yield LoopSucceededEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs=self._node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "loop_break" if check_break_result else "loop_completed",
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
|
||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
@ -236,18 +217,12 @@ class LoopNode(BaseNode):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Loop failed
|
||||
logger.exception("Loop run failed")
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
yield LoopFailedEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
@ -255,207 +230,60 @@ class LoopNode(BaseNode):
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
variable_pool.remove([self.node_id, "index"])
|
||||
|
||||
def _run_single_loop(
|
||||
self,
|
||||
*,
|
||||
graph_engine: "GraphEngine",
|
||||
loop_graph: Graph,
|
||||
variable_pool: "VariablePool",
|
||||
loop_variable_selectors: dict,
|
||||
break_conditions: list,
|
||||
logical_operator: Literal["and", "or"],
|
||||
condition_processor: ConditionProcessor,
|
||||
current_index: int,
|
||||
start_at: datetime,
|
||||
inputs: dict,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
|
||||
"""Run a single loop iteration.
|
||||
Returns:
|
||||
dict: {'check_break_result': bool}
|
||||
"""
|
||||
# Run workflow
|
||||
rst = graph_engine.run()
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"loop {self.node_id} current index not found")
|
||||
current_index = current_index_variable.value
|
||||
) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]:
|
||||
reach_break_node = False
|
||||
for event in graph_engine.run():
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
self._append_loop_info_to_event(event=event, loop_run_index=current_index)
|
||||
|
||||
check_break_result = False
|
||||
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
|
||||
event.in_loop_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.LOOP_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
|
||||
continue
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
yield event
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
|
||||
reach_break_node = True
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
raise Exception(event.error)
|
||||
|
||||
if (
|
||||
isinstance(event, NodeRunSucceededEvent)
|
||||
and event.node_type == NodeType.LOOP_END
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
# Check if variables in break conditions exist and process conditions
|
||||
# Allow loop internal variables to be used in break conditions
|
||||
available_conditions = []
|
||||
for condition in break_conditions:
|
||||
variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector)
|
||||
if variable:
|
||||
available_conditions.append(condition)
|
||||
for loop_var in self._node_data.loop_variables or []:
|
||||
key, sel = loop_var.label, [self._node_id, loop_var.label]
|
||||
segment = self.graph_runtime_state.variable_pool.get(sel)
|
||||
self._node_data.outputs[key] = segment.value if segment else None
|
||||
self._node_data.outputs["loop_round"] = current_index + 1
|
||||
|
||||
# Process conditions if at least one variable is available
|
||||
if available_conditions:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=available_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
else:
|
||||
check_break_result = True
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
break
|
||||
return reach_break_node
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
||||
graph_engine.graph_runtime_state.total_tokens
|
||||
),
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
||||
graph_engine.graph_runtime_state.total_tokens
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
else:
|
||||
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs: dict[str, Segment | int | None] = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
_outputs["loop_round"] = current_index + 1
|
||||
self._node_data.outputs = _outputs
|
||||
|
||||
if check_break_result:
|
||||
return {"check_break_result": True}
|
||||
|
||||
# Move to next loop
|
||||
next_index = current_index + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
index=next_index,
|
||||
pre_loop_output=self._node_data.outputs,
|
||||
)
|
||||
|
||||
return {"check_break_result": False}
|
||||
|
||||
def _handle_event_metadata(
|
||||
def _append_loop_info_to_event(
|
||||
self,
|
||||
*,
|
||||
event: BaseNodeEvent | InNodeEvent,
|
||||
iter_run_index: int,
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
"""
|
||||
if not isinstance(event, BaseNodeEvent):
|
||||
return event
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
|
||||
metadata = {
|
||||
**metadata,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
event: GraphNodeEventBase,
|
||||
loop_run_index: int,
|
||||
):
|
||||
event.in_loop_id = self._node_id
|
||||
loop_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index,
|
||||
}
|
||||
|
||||
current_metadata = event.node_run_result.metadata
|
||||
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
|
||||
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@ -470,13 +298,13 @@ class LoopNode(BaseNode):
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# init graph
|
||||
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
|
||||
# Extract loop node IDs statically from graph_config
|
||||
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
|
||||
for sub_node_id, sub_node_config in loop_graph.node_id_config_mapping.items():
|
||||
# Get node configs from graph_config
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
continue
|
||||
|
||||
@ -515,12 +343,35 @@ class LoopNode(BaseNode):
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
|
||||
}
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
This method statically analyzes the graph configuration to find all nodes
|
||||
that are part of the specified loop, without creating actual node instances.
|
||||
|
||||
:param graph_config: the complete graph configuration
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
loop_node_ids.add(node_id)
|
||||
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
@ -552,3 +403,47 @@ class LoopNode(BaseNode):
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
invoke_from=self.invoke_from.value,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
|
||||
# Create a new GraphRuntimeState for this iteration
|
||||
graph_runtime_state_copy = GraphRuntimeState(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
start_at=start_at.timestamp(),
|
||||
)
|
||||
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# Initialize the loop graph with the new node factory
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=loop_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
return graph_engine
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
||||
|
||||
|
||||
class LoopStartNode(BaseNode):
|
||||
class LoopStartNode(Node):
|
||||
"""
|
||||
Loop Start Node.
|
||||
"""
|
||||
|
||||
_node_type = NodeType.LOOP_START
|
||||
node_type = NodeType.LOOP_START
|
||||
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +29,7 @@ class LoopStartNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
88
api/core/workflow/nodes/node_factory.py
Normal file
88
api/core/workflow/nodes/node_factory.py
Normal file
@ -0,0 +1,88 @@
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Default implementation of NodeFactory that uses the traditional node mapping.
|
||||
|
||||
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
|
||||
and instantiating the appropriate node class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config.get("id")
|
||||
if not is_str(node_id):
|
||||
raise ValueError("Node config missing id")
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
|
||||
node_type_str = node_data.get("type")
|
||||
if not is_str(node_type_str):
|
||||
raise ValueError(f"Node {node_id} missing or invalid type information")
|
||||
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_type_str}")
|
||||
|
||||
# Get node class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
node_class = node_mapping.get(LATEST_VERSION)
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
# Create node instance
|
||||
node_instance = node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node with provided data
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
node_instance.init_node_data(node_data)
|
||||
|
||||
# If node has fail branch, change execution type to branch
|
||||
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
node_instance.execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
return node_instance
|
||||
@ -1,15 +1,17 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.answer import AnswerNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.end import EndNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.if_else import IfElseNode
|
||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
@ -33,7 +35,7 @@ LATEST_VERSION = "latest"
|
||||
#
|
||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||
# hook. Try to avoid duplication of node information.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
"1": StartNode,
|
||||
@ -135,6 +137,14 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
"2": AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
NodeType.DATASOURCE: {
|
||||
LATEST_VERSION: DatasourceNode,
|
||||
"1": DatasourceNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: {
|
||||
LATEST_VERSION: KnowledgeIndexNode,
|
||||
"1": KnowledgeIndexNode,
|
||||
},
|
||||
NodeType.TRIGGER_WEBHOOK: {
|
||||
LATEST_VERSION: TriggerWebhookNode,
|
||||
"1": TriggerWebhookNode,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@ -31,8 +31,6 @@ _VALID_PARAMETER_TYPES = frozenset(
|
||||
|
||||
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if not isinstance(parameter_type, str):
|
||||
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
@ -43,10 +41,6 @@ def _validate_type(parameter_type: str) -> SegmentType:
|
||||
return SegmentType(parameter_type)
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
"""
|
||||
Parameter Config.
|
||||
@ -54,7 +48,7 @@ class ParameterConfig(BaseModel):
|
||||
|
||||
name: str
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: Optional[list[str]] = None
|
||||
options: list[str] | None = None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
@ -92,8 +86,8 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
reasoning_mode: Literal["function_call", "prompt"]
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@ -102,7 +96,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
|
||||
def get_parameter_json_schema(self) -> dict:
|
||||
def get_parameter_json_schema(self):
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
) -> None:
|
||||
):
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
|
||||
@ -3,14 +3,14 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@ -27,19 +27,17 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.llm import ModelConfig, llm_utils
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
@ -52,6 +50,7 @@ from .exc import (
|
||||
)
|
||||
from .prompts import (
|
||||
CHAT_EXAMPLE,
|
||||
CHAT_GENERATE_JSON_PROMPT,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||
COMPLETION_GENERATE_JSON_PROMPT,
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
|
||||
@ -85,19 +84,19 @@ def extract_json(text):
|
||||
return None
|
||||
|
||||
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
class ParameterExtractorNode(Node):
|
||||
"""
|
||||
Parameter Extractor Node.
|
||||
"""
|
||||
|
||||
_node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
|
||||
_node_data: ParameterExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ParameterExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -106,7 +105,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -115,11 +114,11 @@ class ParameterExtractorNode(BaseNode):
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
_model_instance: Optional[ModelInstance] = None
|
||||
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
|
||||
_model_instance: ModelInstance | None = None
|
||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
@ -294,7 +293,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
@ -305,8 +304,6 @@ class ParameterExtractorNode(BaseNode):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
if not isinstance(invoke_result, LLMResult):
|
||||
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content or ""
|
||||
if not isinstance(text, str):
|
||||
@ -318,9 +315,6 @@ class ParameterExtractorNode(BaseNode):
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
if text is None:
|
||||
text = ""
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
@ -329,9 +323,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
"""
|
||||
Generate function call prompt.
|
||||
@ -411,9 +405,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
@ -449,9 +443,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate completion prompt.
|
||||
@ -483,9 +477,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate chat prompt.
|
||||
@ -545,7 +539,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
@ -584,20 +578,21 @@ class ParameterExtractorNode(BaseNode):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
return None
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
elif isinstance(value, str):
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
@ -656,7 +651,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
@ -671,7 +666,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
@ -690,7 +685,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
@ -698,7 +693,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
elif parameter.type == "bool":
|
||||
elif parameter.type == "boolean":
|
||||
result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
result[parameter.name] = ""
|
||||
@ -710,7 +705,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@ -737,7 +732,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@ -752,7 +747,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
|
||||
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
@ -773,7 +768,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
|
||||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -10,21 +10,20 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.llm import (
|
||||
LLMNode,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
llm_utils,
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
@ -41,11 +40,12 @@ from .template_prompts import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
|
||||
|
||||
class QuestionClassifierNode(BaseNode):
|
||||
_node_type = NodeType.QUESTION_CLASSIFIER
|
||||
class QuestionClassifierNode(Node):
|
||||
node_type = NodeType.QUESTION_CLASSIFIER
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_node_data: QuestionClassifierNodeData
|
||||
|
||||
@ -57,21 +57,15 @@ class QuestionClassifierNode(BaseNode):
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
@ -83,10 +77,10 @@ class QuestionClassifierNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = QuestionClassifierNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -95,7 +89,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -187,7 +181,8 @@ class QuestionClassifierNode(BaseNode):
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self.node_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
@ -259,6 +254,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
|
||||
|
||||
@ -275,12 +271,13 @@ class QuestionClassifierNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:param filters: filter by node config parameters (not used in this implementation).
|
||||
:return:
|
||||
"""
|
||||
# filters parameter is not used in this node type
|
||||
return {"type": "question-classifier", "config": {"instructions": ""}}
|
||||
|
||||
def _calculate_rest_token(
|
||||
@ -288,7 +285,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
@ -331,7 +328,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
|
||||
@ -1,24 +1,24 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
|
||||
|
||||
class StartNode(BaseNode):
|
||||
_node_type = NodeType.START
|
||||
class StartNode(Node):
|
||||
node_type = NodeType.START
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
_node_data: StartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = StartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -27,7 +27,7 @@ class StartNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class TemplateTransformNodeData(BaseNodeData):
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
|
||||
|
||||
|
||||
class TemplateTransformNode(BaseNode):
|
||||
_node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
class TemplateTransformNode(Node):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
|
||||
_node_data: TemplateTransformNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = TemplateTransformNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +29,7 @@ class TemplateTransformNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -40,7 +39,7 @@ class TemplateTransformNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
@ -57,7 +56,7 @@ class TemplateTransformNode(BaseNode):
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
variables = {}
|
||||
variables: dict[str, Any] = {}
|
||||
for variable_selector in self._node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
|
||||
@ -1,28 +1,28 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
@ -35,27 +35,31 @@ from .exc import (
|
||||
ToolParameterError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import VariablePool
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
|
||||
class ToolNode(Node):
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
|
||||
_node_type = NodeType.TOOL
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
_node_data: ToolNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ToolNodeData.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
node_data = self._node_data
|
||||
|
||||
@ -78,11 +82,11 @@ class ToolNode(BaseNode):
|
||||
if node_data.version != "1" or node_data.tool_node_version != "1":
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
|
||||
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
@ -115,13 +119,12 @@ class ToolNode(BaseNode):
|
||||
user_id=self.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
@ -139,11 +142,11 @@ class ToolNode(BaseNode):
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_id=self.node_id,
|
||||
node_id=self._node_id,
|
||||
)
|
||||
except ToolInvokeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
@ -152,8 +155,8 @@ class ToolNode(BaseNode):
|
||||
)
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
@ -165,8 +168,8 @@ class ToolNode(BaseNode):
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
@ -179,7 +182,7 @@ class ToolNode(BaseNode):
|
||||
self,
|
||||
*,
|
||||
tool_parameters: Sequence[ToolParameter],
|
||||
variable_pool: VariablePool,
|
||||
variable_pool: "VariablePool",
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
@ -220,7 +223,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
@ -238,6 +241,8 @@ class ToolNode(BaseNode):
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
@ -310,17 +315,25 @@ class ToolNode(BaseNode):
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
# JSON message handling for tool node
|
||||
if message.message.json_object is not None:
|
||||
if message.message.json_object:
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
@ -332,8 +345,10 @@ class ToolNode(BaseNode):
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
@ -393,8 +408,24 @@ class ToolNode(BaseNode):
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||
metadata={
|
||||
@ -431,7 +462,8 @@ class ToolNode(BaseNode):
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
@ -439,7 +471,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
return result
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -448,7 +480,7 @@ class ToolNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -457,10 +489,6 @@ class ToolNode(BaseNode):
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@property
|
||||
def continue_on_error(self) -> bool:
|
||||
return self._node_data.error_strategy is not None
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
|
||||
type: str = "variable-assigner"
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: Optional[AdvancedSettings] = None
|
||||
advanced_settings: AdvancedSettings | None = None
|
||||
|
||||
@ -1,24 +1,23 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
class VariableAggregatorNode(Node):
|
||||
node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -27,7 +26,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel):
|
||||
name: str
|
||||
selector: Sequence[str]
|
||||
value_type: SegmentType
|
||||
new_value: Any
|
||||
new_value: Any = None
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||
@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise Exception("selector too short")
|
||||
node_id, var_name = selector[:2]
|
||||
_, var_name = selector[:2]
|
||||
return UpdatedVariable(
|
||||
name=var_name,
|
||||
selector=list(selector[:2]),
|
||||
|
||||
@ -1,29 +1,19 @@
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from models.engine import db
|
||||
from models.workflow import ConversationVariable
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
from .exc import VariableOperatorNodeError
|
||||
|
||||
|
||||
class ConversationVariableUpdaterImpl:
|
||||
_engine: Engine | None
|
||||
|
||||
def __init__(self, engine: Engine | None = None) -> None:
|
||||
self._engine = engine
|
||||
|
||||
def _get_engine(self) -> Engine:
|
||||
if self._engine:
|
||||
return self._engine
|
||||
return db.engine
|
||||
|
||||
def update(self, conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(self._get_engine()) as session:
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
@ -18,22 +18,22 @@ from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
class VariableAssignerNode(Node):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||
|
||||
_node_data: VariableAssignerData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -42,7 +42,7 @@ class VariableAssignerNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -56,20 +56,14 @@ class VariableAssignerNode(BaseNode):
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
self._conv_var_updater_factory = conv_var_updater_factory
|
||||
|
||||
@ -123,13 +117,8 @@ class VariableAssignerNode(BaseNode):
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
if income_value is None:
|
||||
raise VariableOperatorNodeError("income value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
@ -18,9 +18,9 @@ class VariableOperationItem(BaseModel):
|
||||
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
|
||||
# 3. During the variable updating procedure: The `value` field is reassigned to hold
|
||||
# the resolved actual value that will be applied to the target variable.
|
||||
value: Any | None = None
|
||||
value: Any = None
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
version: str = "2"
|
||||
items: Sequence[VariableOperationItem]
|
||||
items: Sequence[VariableOperationItem] = Field(default_factory=list)
|
||||
|
||||
@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):
|
||||
|
||||
|
||||
class InvalidDataError(VariableOperatorNodeError):
|
||||
def __init__(self, message: str) -> None:
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
@ -25,8 +25,6 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
# Only array variable can be appended or extended
|
||||
# Only array variable can have elements removed
|
||||
return variable_type.is_array_type()
|
||||
case _:
|
||||
return False
|
||||
|
||||
|
||||
def is_variable_input_supported(*, operation: Operation):
|
||||
|
||||
@ -1,17 +1,16 @@
|
||||
import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
@ -53,15 +52,15 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
||||
mapping[key] = selector
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
class VariableAssignerNode(Node):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -70,7 +69,7 @@ class VariableAssignerNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -79,6 +78,23 @@ class VariableAssignerNode(BaseNode):
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
|
||||
Returns True if this node updates any of the requested conversation variables.
|
||||
"""
|
||||
# Check each item in this Variable Assigner node
|
||||
for item in self._node_data.items:
|
||||
# Convert the item's variable_selector to tuple for comparison
|
||||
item_selector_tuple = tuple(item.variable_selector)
|
||||
|
||||
# Check if this item updates any of the requested variables
|
||||
if item_selector_tuple in variable_selectors:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
|
||||
return conversation_variable_updater_factory()
|
||||
|
||||
@ -258,5 +274,3 @@ class VariableAssignerNode(BaseNode):
|
||||
if not variable.value:
|
||||
return variable.value
|
||||
return variable.value[:-1]
|
||||
case _:
|
||||
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)
|
||||
|
||||
Reference in New Issue
Block a user