Merge branch 'main' into feat/support-knowledge-metadata

# Conflicts:
#	api/core/rag/datasource/retrieval_service.py
#	api/core/workflow/nodes/code/code_node.py
#	api/services/dataset_service.py
This commit is contained in:
jyong
2025-02-26 19:59:14 +08:00
3498 changed files with 74030 additions and 293893 deletions

View File

@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
PARALLEL_ID = "parallel_id"
@ -48,3 +49,8 @@ class NodeRunResult(BaseModel):
# single step node run retry
retry_index: int = 0
class AgentNodeStrategyInit(BaseModel):
name: str
icon: str | None = None

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
@ -66,8 +67,10 @@ class BaseNodeEvent(GraphEngineEvent):
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
"""predecessor node id"""
parallel_mode_run_id: Optional[str] = None
"""iteration node parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
class NodeRunStreamChunkEvent(BaseNodeEvent):
@ -164,8 +167,8 @@ class IterationRunStartedEvent(BaseIterationEvent):
class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
duration: Optional[float] = Field(None, description="duration")
pre_iteration_output: Optional[Any] = None
duration: Optional[float] = None
class IterationRunSucceededEvent(BaseIterationEvent):
@ -186,4 +189,24 @@ class IterationRunFailedEvent(BaseIterationEvent):
error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
###########################################
# Agent Events
###########################################
class BaseAgentEvent(GraphEngineEvent):
pass
class AgentLogEvent(BaseAgentEvent):
id: str = Field(..., description="id")
label: str = Field(..., description="label")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent

View File

@ -590,6 +590,8 @@ class Graph(BaseModel):
start_node_id=node_id,
routes_node_ids=routes_node_ids,
)
# Exclude conditional branch nodes
and all(edge.run_condition is None for edge in reverse_edge_mapping.get(node_id, []))
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []

View File

@ -1,3 +1,4 @@
import contextvars
import logging
import queue
import time
@ -13,7 +14,7 @@ from flask import Flask, current_app
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@ -39,6 +40,8 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
@ -477,6 +480,7 @@ class GraphEngine:
**{
"flask_app": current_app._get_current_object(), # type: ignore[attr-defined]
"q": q,
"context": contextvars.copy_context(),
"parallel_id": parallel_id,
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
@ -520,6 +524,7 @@ class GraphEngine:
def _run_parallel_node(
self,
flask_app: Flask,
context: contextvars.Context,
q: queue.Queue,
parallel_id: str,
parallel_start_node_id: str,
@ -530,6 +535,9 @@ class GraphEngine:
"""
Run parallel nodes
"""
for var, val in context.items():
var.set(val)
with flask_app.app_context():
try:
q.put(
@ -600,6 +608,14 @@ class GraphEngine:
Run node
"""
# trigger node run start event
agent_strategy = (
AgentNodeStrategyInit(
name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
icon=cast(AgentNode, node_instance).agent_strategy_icon,
)
if node_instance.node_type == NodeType.AGENT
else None
)
yield NodeRunStartedEvent(
id=node_instance.id,
node_id=node_instance.node_id,
@ -611,6 +627,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
)
db.session.close()

View File

@ -0,0 +1,3 @@
from .agent_node import AgentNode
__all__ = ["AgentNode"]

View File

@ -0,0 +1,254 @@
from ast import literal_eval
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from factories.agent_factory import get_plugin_agent_strategy
from models.workflow import WorkflowNodeExecutionStatus
class AgentNode(ToolNode):
"""
Agent Node
"""
_node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT
def _run(self) -> Generator:
"""
Run the agent node
"""
node_data = cast(AgentNodeData, self.node_data)
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
agent_strategy_provider_name=node_data.agent_strategy_provider_name,
agent_strategy_name=node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
)
)
return
agent_parameters = strategy.get_parameters()
# get parameters
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True,
)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=self.user_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to invoke agent: {str(e)}",
)
)
return
try:
# convert tool messages
yield from self._transform_message(
message_stream,
{
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
parameters_for_log,
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to transform agent message: {str(e)}",
)
)
def _generate_agent_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (AgentNodeData): The data associated with the agent node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise ValueError(f"Variable {agent_input.value} does not exist")
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(agent_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ValueError(f"Unknown agent input type '{agent_input.type}'")
value = parameter_value.strip()
if (parameter_value.startswith("{") and parameter_value.endswith("}")) or (
parameter_value.startswith("[") and parameter_value.endswith("]")
):
value = literal_eval(parameter_value) # transform string to python object
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=ToolProviderType.BUILT_IN,
tool_name=tool.get("tool_name", ""),
tool_parameters=tool.get("parameters", {}),
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
)
extra = tool.get("extra", {})
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("descrption", "") or tool_runtime.entity.description.llm
)
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
}
)
value = tool_value
if parameter.type == "model-selector":
value = cast(dict[str, Any], value)
model_instance = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
provider=value.get("provider", ""),
model_type=ModelType(value.get("model_type", "")),
model=value.get("model", ""),
)
models = model_instance.model_type_instance.plugin_model_provider.declaration.models
finded_model = next((model for model in models if model.model == value.get("model", "")), None)
value["entity"] = finded_model.model_dump(mode="json") if finded_model else None
result[parameter_name] = value
return result
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: BaseNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(AgentNodeData, node_data)
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result
@property
def agent_strategy_icon(self) -> str | None:
"""
Get agent strategy icon
:return:
"""
manager = PluginInstallationManager()
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
icon = None
return icon

View File

@ -0,0 +1,18 @@
from typing import Any, Literal, Union
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData
class AgentNodeData(BaseNodeData):
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str # redundancy
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
agent_parameters: dict[str, AgentInput]

View File

@ -195,7 +195,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if output_config.type == "object":
# check if output is object
if not isinstance(result.get(output_name), dict):
if result.get(output_name) is None:
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(

View File

@ -1,6 +1,3 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
@ -30,20 +27,3 @@ class EndNode(BaseNode[EndNodeData]):
inputs=outputs,
outputs=outputs,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: EndNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@ -22,6 +22,7 @@ class NodeType(StrEnum):
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
class ErrorStrategy(StrEnum):

View File

@ -1,5 +1,4 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Literal
from typing_extensions import deprecated
@ -88,23 +87,6 @@ class IfElseNode(BaseNode[IfElseNodeData]):
return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(

View File

@ -1,3 +1,4 @@
import contextvars
import logging
import uuid
from collections.abc import Generator, Mapping, Sequence
@ -174,6 +175,7 @@ class IterationNode(BaseNode[IterationNodeData]):
self._run_single_iter_parallel,
flask_app=current_app._get_current_object(), # type: ignore
q=q,
context=contextvars.copy_context(),
iterator_list_value=iterator_list_value,
inputs=inputs,
outputs=outputs,
@ -568,6 +570,7 @@ class IterationNode(BaseNode[IterationNodeData]):
self,
*,
flask_app: Flask,
context: contextvars.Context,
q: Queue,
iterator_list_value: Sequence[str],
inputs: Mapping[str, list],
@ -582,9 +585,12 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
run single iteration in parallel mode
"""
for var, val in context.items():
var.set(val)
with flask_app.app_context():
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
graph_engine_copy.graph_runtime_state.total_tokens = 0
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
variable_pool_copy.add([self.node_id, "index"], index)
variable_pool_copy.add([self.node_id, "item"], item)

View File

@ -1,10 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from core.workflow.nodes.iteration.entities import IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -21,16 +18,3 @@ class IterationStartNode(BaseNode):
Run the node.
"""
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@ -1,6 +1,7 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
from configs import dify_config
@ -29,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import (
@ -236,9 +238,9 @@ class LLMNode(BaseNode[LLMNodeData]):
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
stop=stop,
stop=list(stop or []),
stream=True,
user=self.user_id,
)
@ -758,11 +760,17 @@ class LLMNode(BaseNode[LLMNodeData]):
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_instance.provider,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update({"quota_used": Provider.quota_used + used_quota})
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()
@classmethod

View File

@ -1,5 +1,6 @@
from collections.abc import Mapping
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
@ -101,4 +102,8 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
"1": AgentNode,
},
}

View File

@ -244,8 +244,8 @@ class ParameterExtractorNode(LLMNode):
if not isinstance(invoke_result, LLMResult):
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
if not isinstance(text, str | None):
text = invoke_result.message.content or ""
if not isinstance(text, str):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage

View File

@ -1,6 +1,3 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
@ -23,13 +20,3 @@ class StartNode(BaseNode[StartNodeData]):
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData,
) -> Mapping[str, Sequence[str]]:
return {}

View File

@ -3,16 +3,18 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.nodes.base import BaseNodeData
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData
class ToolEntity(BaseModel):
provider_id: str
provider_type: Literal["builtin", "api", "workflow"]
provider_type: ToolProviderType
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod

View File

@ -1,24 +1,31 @@
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import UUID
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod, FileType
from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
from .exc import (
@ -36,11 +43,18 @@ class ToolNode(BaseNode[ToolNodeData]):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
def _run(self) -> NodeRunResult:
def _run(self) -> Generator:
"""
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon
tool_info = {
"provider_type": self.node_data.provider_type,
"provider_id": self.node_data.provider_id,
"provider_type": node_data.provider_type.value,
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
}
# get tool runtime
@ -51,18 +65,19 @@ class ToolNode(BaseNode[ToolNodeData]):
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except ToolNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
)
)
return
# get parameters
tool_parameters = tool_runtime.parameters or []
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
@ -75,52 +90,45 @@ class ToolNode(BaseNode[ToolNodeData]):
for_log=True,
)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
messages = ToolEngine.workflow_invoke(
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
error_type="UnknownError",
return
try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to transform tool message: {str(e)}",
)
)
# convert tool messages
plain_text, files, json = self._convert_tool_messages(messages)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": plain_text,
"files": files,
"json": json,
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
inputs=parameters_for_log,
)
def _generate_parameters(
self,
*,
@ -128,7 +136,7 @@ class ToolNode(BaseNode[ToolNodeData]):
variable_pool: VariablePool,
node_data: ToolNodeData,
for_log: bool = False,
) -> Mapping[str, Any]:
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -164,37 +172,52 @@ class ToolNode(BaseNode[ToolNodeData]):
return result
def _convert_tool_messages(
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _transform_message(
self,
messages: list[ToolInvokeMessage],
):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
# extract plain text and files
files = self._extract_tool_response_binary(messages)
plain_text = self._extract_tool_response_text(messages)
json = self._extract_tool_response_json(messages)
return plain_text, files, json
text = ""
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = str(response.message) if response.message else None
tool_file_id = str(url).split("/")[-1].split(".")[0]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
@ -204,7 +227,7 @@ class ToolNode(BaseNode[ToolNodeData]):
mapping = {
"tool_file_id": tool_file_id,
"type": FileType.IMAGE,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
@ -212,70 +235,139 @@ class ToolNode(BaseNode[ToolNodeData]):
mapping=mapping,
tenant_id=self.tenant_id,
)
result.append(file)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
result.append(file)
elif response.type == ToolInvokeMessage.MessageType.LINK:
url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = url.split("/")[-1].split(".")[0]
try:
UUID(tool_file_id)
except ValueError:
raise ToolFileError(f"cannot extract tool file id from url {url}")
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
result.append(file)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value for key, value in msg_metadata.items() if key in NodeRunMetadataKey
}
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
elif response.type == ToolInvokeMessage.MessageType.FILE:
assert response.meta is not None
result.append(response.meta["file"])
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstallationManager()
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id,
self.tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
except StopIteration:
pass
return result
dict_metadata["icon"] = icon
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
)
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
return "\n".join(
[
str(message.message)
if message.type == ToolInvokeMessage.MessageType.TEXT
else f"Link: {str(message.message)}"
for message in tool_response
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}
]
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
metadata={
**agent_execution_metadata,
NodeRunMetadataKey.TOOL_INFO: tool_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)
)
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]):
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -295,7 +387,8 @@ class ToolNode(BaseNode[ToolNodeData]):
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == "mixed":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":

View File

@ -1,6 +1,3 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
@ -36,16 +33,3 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
break
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from typing import Any, Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@ -194,6 +194,105 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
@classmethod
def run_free_node(
cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
"""
Run free node
NOTE: only parameter_extractor/question_classifier are supported
:param node_data: node data
:param user_id: user id
:param user_inputs: user inputs
:return:
"""
# generate a fake graph
node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
start_node_config = {
"id": "start",
"width": 114,
"height": 514,
"type": "custom",
"data": {
"type": NodeType.START.value,
"title": "Start",
"desc": "Start",
},
}
graph_dict = {
"nodes": [start_node_config, node_config],
"edges": [
{
"source": "start",
"target": node_id,
"sourceHandle": "source",
"targetHandle": "target",
}
],
}
node_type = NodeType(node_data.get("type", ""))
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
graph = Graph.init(graph_config=graph_dict)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=[],
)
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node_instance: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
tenant_id=tenant_id,
app_id="",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="",
graph_config=graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
)
try:
# variable selector to variable mapping
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=tenant_id,
)
# run node
generator = node_instance.run()
return node_instance, generator
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
result = WorkflowEntry._handle_special_values(value)