# Mirror of https://github.com/langgenius/dify.git
# Synced 2026-03-04 15:26:21 +08:00 — 504 lines, 21 KiB, Python
from collections.abc import Generator, Mapping, Sequence
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
|
from core.tools.__base.tool import Tool
|
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
|
from core.tools.errors import ToolInvokeError
|
|
from core.tools.tool_engine import ToolEngine
|
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
|
from dify_graph.enums import (
|
|
NodeType,
|
|
SystemVariableKey,
|
|
WorkflowNodeExecutionMetadataKey,
|
|
WorkflowNodeExecutionStatus,
|
|
)
|
|
from dify_graph.file import File, FileTransferMethod
|
|
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
|
from dify_graph.nodes.base.node import Node
|
|
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
|
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
|
|
from dify_graph.variables.variables import ArrayAnyVariable
|
|
from extensions.ext_database import db
|
|
from factories import file_factory
|
|
from models import ToolFile
|
|
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
|
|
|
from .entities import ToolNodeData
|
|
from .exc import (
|
|
ToolFileError,
|
|
ToolNodeError,
|
|
ToolParameterError,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from dify_graph.runtime import VariablePool
|
|
|
|
|
|
class ToolNode(Node[ToolNodeData]):
    """
    Tool Node

    Workflow node that resolves a tool runtime from its configuration,
    builds the tool's arguments from the variable pool, invokes the tool
    through ToolEngine, and streams the transformed results as node events.
    """

    # Discriminator used by the workflow graph to identify this node kind.
    node_type = NodeType.TOOL

    @classmethod
    def version(cls) -> str:
        # Schema version of this node implementation; `_run` also inspects
        # `node_data.version` for backward compatibility with historical data.
        return "1"
|
|
|
|
def _run(self) -> Generator[NodeEventBase, None, None]:
    """
    Run the tool node

    Yields stream chunk events while the tool produces output and always
    ends with a StreamCompletedEvent carrying either a SUCCEEDED result
    (emitted by `_transform_message`) or a FAILED result built here.
    """
    # Deferred import: plugin exception types are only needed at run time.
    from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError

    # fetch tool icon
    tool_info = {
        "provider_type": self.node_data.provider_type.value,
        "provider_id": self.node_data.provider_id,
        "plugin_unique_identifier": self.node_data.plugin_unique_identifier,
    }

    # get tool runtime
    try:
        from core.tools.tool_manager import ToolManager

        # This is an issue that caused problems before.
        # Logically, we shouldn't use the node_data.version field for judgment
        # But for backward compatibility with historical data
        # this version field judgment is still preserved here.
        variable_pool: VariablePool | None = None
        if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
            variable_pool = self.graph_runtime_state.variable_pool
        tool_runtime = ToolManager.get_workflow_tool_runtime(
            self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
        )
    except ToolNodeError as e:
        # Runtime resolution failed: report failure and stop; no tool inputs
        # were computed yet, hence `inputs={}`.
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs={},
                metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                error=f"Failed to get tool runtime: {str(e)}",
                error_type=type(e).__name__,
            )
        )
        return

    # get parameters
    tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
    parameters = self._generate_parameters(
        tool_parameters=tool_parameters,
        variable_pool=self.graph_runtime_state.variable_pool,
        node_data=self.node_data,
    )
    # Second pass with for_log=True renders template inputs via their `.log`
    # form instead of `.text` (see `_generate_parameters`).
    parameters_for_log = self._generate_parameters(
        tool_parameters=tool_parameters,
        variable_pool=self.graph_runtime_state.variable_pool,
        node_data=self.node_data,
        for_log=True,
    )
    # get conversation id
    conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

    try:
        message_stream = ToolEngine.generic_invoke(
            tool=tool_runtime,
            tool_parameters=parameters,
            user_id=self.user_id,
            workflow_tool_callback=DifyWorkflowCallbackHandler(),
            workflow_call_depth=self.workflow_call_depth,
            app_id=self.app_id,
            conversation_id=conversation_id.text if conversation_id else None,
        )
    except ToolNodeError as e:
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                error=f"Failed to invoke tool: {str(e)}",
                error_type=type(e).__name__,
            )
        )
        return

    try:
        # convert tool messages
        # `_transform_message` yields stream events (re-yielded here) and
        # emits the SUCCEEDED StreamCompletedEvent itself on success.
        _ = yield from self._transform_message(
            messages=message_stream,
            tool_info=tool_info,
            parameters_for_log=parameters_for_log,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            node_id=self._node_id,
            tool_runtime=tool_runtime,
        )
    except ToolInvokeError as e:
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
                error_type=type(e).__name__,
            )
        )
    except PluginInvokeError as e:
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                # Plugin errors carry a user-presentable rendering.
                error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
                error_type=type(e).__name__,
            )
        )
    except PluginDaemonClientSideError as e:
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                error=f"Failed to invoke tool, error: {e.description}",
                error_type=type(e).__name__,
            )
        )
|
|
|
|
def _generate_parameters(
    self,
    *,
    tool_parameters: Sequence[ToolParameter],
    variable_pool: "VariablePool",
    node_data: ToolNodeData,
    for_log: bool = False,
) -> dict[str, Any]:
    """Build the concrete argument dict passed to the tool invocation.

    Args:
        tool_parameters: Runtime parameters declared by the tool.
        variable_pool: Pool used to resolve selectors and templates.
        node_data: This tool node's configuration.
        for_log: When True, render template inputs via their log form
            rather than their text form.

    Returns:
        Parameter name -> resolved value. A parameter configured on the
        node but not declared by the tool resolves to None; an optional
        variable that cannot be resolved is omitted entirely.

    Raises:
        ToolParameterError: If a required variable is missing or an
            input has an unrecognized type.
    """
    declared = {spec.name: spec for spec in tool_parameters}

    resolved: dict[str, Any] = {}
    for name, tool_input in node_data.tool_parameters.items():
        spec = declared.get(name)
        if not spec:
            # Configured on the node but unknown to the tool.
            resolved[name] = None
            continue
        if tool_input.type == "variable":
            variable = variable_pool.get(tool_input.value)
            if variable is None:
                if spec.required:
                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")
                # Optional and unresolved: leave the parameter out.
                continue
            resolved[name] = variable.value
        elif tool_input.type in {"mixed", "constant"}:
            rendered = variable_pool.convert_template(str(tool_input.value))
            resolved[name] = rendered.log if for_log else rendered.text
        else:
            raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")

    return resolved
|
|
|
|
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
    """Fetch the workflow's system files (``sys.files``) from the variable pool.

    Returns:
        The files as a plain list, or an empty list when the system
        variable is absent or empty.
    """
    variable = variable_pool.get(["sys", SystemVariableKey.FILES])
    # Guard before the isinstance assertion: previously the assert ran
    # first, so a missing variable raised AssertionError instead of
    # producing the `[]` the trailing conditional clearly intended.
    if not variable:
        return []
    assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
    return list(variable.value)
|
|
|
|
def _transform_message(
    self,
    messages: Generator[ToolInvokeMessage, None, None],
    tool_info: Mapping[str, Any],
    parameters_for_log: dict[str, Any],
    user_id: str,
    tenant_id: str,
    node_id: str,
    tool_runtime: Tool,
) -> Generator[NodeEventBase, None, LLMUsage]:
    """
    Convert ToolInvokeMessages into tuple[plain_text, files]

    Consumes the tool's message stream, accumulating text / files / JSON /
    variables, yielding StreamChunkEvents as output arrives, and finally
    yielding the SUCCEEDED StreamCompletedEvent. The generator's return
    value is the extracted LLMUsage.

    Raises:
        ToolFileError: If a referenced tool file id is not in the database.
        ToolNodeError: On malformed VARIABLE/FILE messages.
    """
    # transform message and handle file storage
    from core.plugin.impl.plugin import PluginInstaller

    message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
        messages=messages,
        user_id=user_id,
        tenant_id=tenant_id,
        conversation_id=None,
    )

    # Accumulators for the node's outputs.
    text = ""
    files: list[File] = []
    # NOTE(review): local name `json` shadows the stdlib module name.
    json: list[dict | list] = []

    variables: dict[str, Any] = {}

    for message in message_stream:
        if message.type in {
            ToolInvokeMessage.MessageType.IMAGE_LINK,
            ToolInvokeMessage.MessageType.BINARY_LINK,
            ToolInvokeMessage.MessageType.IMAGE,
        }:
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)

            url = message.message.text
            if message.meta:
                transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
            else:
                transfer_method = FileTransferMethod.TOOL_FILE

            # The tool file id is the last path segment without extension.
            tool_file_id = str(url).split("/")[-1].split(".")[0]

            with Session(db.engine) as session:
                stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                tool_file = session.scalar(stmt)
                if tool_file is None:
                    raise ToolFileError(f"Tool file {tool_file_id} does not exist")

            mapping = {
                "tool_file_id": tool_file_id,
                "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
                "transfer_method": transfer_method,
                "url": url,
            }
            file = file_factory.build_from_mapping(
                mapping=mapping,
                tenant_id=tenant_id,
            )
            files.append(file)
        elif message.type == ToolInvokeMessage.MessageType.BLOB:
            # get tool file id
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)
            assert message.meta

            tool_file_id = message.message.text.split("/")[-1].split(".")[0]
            with Session(db.engine) as session:
                stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                tool_file = session.scalar(stmt)
                if tool_file is None:
                    raise ToolFileError(f"tool file {tool_file_id} not exists")

            mapping = {
                "tool_file_id": tool_file_id,
                "transfer_method": FileTransferMethod.TOOL_FILE,
            }

            files.append(
                file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
            )
        elif message.type == ToolInvokeMessage.MessageType.TEXT:
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)
            text += message.message.text
            # Stream each text fragment immediately under the "text" selector.
            yield StreamChunkEvent(
                selector=[node_id, "text"],
                chunk=message.message.text,
                is_final=False,
            )
        elif message.type == ToolInvokeMessage.MessageType.JSON:
            assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
            # JSON message handling for tool node
            if message.message.json_object:
                json.append(message.message.json_object)
        elif message.type == ToolInvokeMessage.MessageType.LINK:
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)

            # Check if this LINK message is a file link
            file_obj = (message.meta or {}).get("file")
            if isinstance(file_obj, File):
                files.append(file_obj)
                stream_text = f"File: {message.message.text}\n"
            else:
                stream_text = f"Link: {message.message.text}\n"

            text += stream_text
            yield StreamChunkEvent(
                selector=[node_id, "text"],
                chunk=stream_text,
                is_final=False,
            )
        elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
            assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
            variable_name = message.message.variable_name
            variable_value = message.message.variable_value
            if message.message.stream:
                # Streamed variables are concatenated chunk by chunk and
                # also forwarded as chunk events under their own selector.
                if not isinstance(variable_value, str):
                    raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
                if variable_name not in variables:
                    variables[variable_name] = ""
                variables[variable_name] += variable_value

                yield StreamChunkEvent(
                    selector=[node_id, variable_name],
                    chunk=variable_value,
                    is_final=False,
                )
            else:
                variables[variable_name] = variable_value
        elif message.type == ToolInvokeMessage.MessageType.FILE:
            assert message.meta is not None
            assert isinstance(message.meta, dict)
            # Validate that meta contains a 'file' key
            if "file" not in message.meta:
                raise ToolNodeError("File message is missing 'file' key in meta")

            # Validate that the file is an instance of File
            if not isinstance(message.meta["file"], File):
                raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
            files.append(message.meta["file"])
        elif message.type == ToolInvokeMessage.MessageType.LOG:
            assert isinstance(message.message, ToolInvokeMessage.LogMessage)
            if message.message.metadata:
                # Enrich log metadata with provider icons (plugin icon first,
                # then builtin-tool icon overrides when the provider matches).
                icon = tool_info.get("icon", "")
                dict_metadata = dict(message.message.metadata)
                if dict_metadata.get("provider"):
                    manager = PluginInstaller()
                    plugins = manager.list_plugins(tenant_id)
                    try:
                        current_plugin = next(
                            plugin
                            for plugin in plugins
                            if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                        )
                        icon = current_plugin.declaration.icon
                    except StopIteration:
                        # No matching plugin: keep the icon from tool_info.
                        pass
                    icon_dark = None
                    try:
                        builtin_tool = next(
                            provider
                            for provider in BuiltinToolManageService.list_builtin_tools(
                                user_id,
                                tenant_id,
                            )
                            if provider.name == dict_metadata["provider"]
                        )
                        icon = builtin_tool.icon
                        icon_dark = builtin_tool.icon_dark
                    except StopIteration:
                        pass

                    dict_metadata["icon"] = icon
                    dict_metadata["icon_dark"] = icon_dark
                    message.message.metadata = dict_metadata

    # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
    json_output: list[dict[str, Any] | list[Any]] = []

    # Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
    if json:
        json_output.extend(json)
    else:
        json_output.append({"data": []})

    # Send final chunk events for all streamed outputs
    # Final chunk for text stream
    yield StreamChunkEvent(
        selector=[self._node_id, "text"],
        chunk="",
        is_final=True,
    )

    # Final chunks for any streamed variables
    for var_name in variables:
        yield StreamChunkEvent(
            selector=[self._node_id, var_name],
            chunk="",
            is_final=True,
        )

    usage = self._extract_tool_usage(tool_runtime)

    metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
        WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
    }
    # Token/price metadata is only attached when the tool reported real usage.
    if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
        metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
        metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
        metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency

    yield StreamCompletedEvent(
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
            metadata=metadata,
            inputs=parameters_for_log,
            llm_usage=usage,
        )
    )

    # Generator return value: consumed by callers via `yield from`.
    return usage
|
|
|
|
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
    """Normalize a tool runtime's ``latest_usage`` attribute into an LLMUsage.

    Uses duck typing (deliberately no WorkflowTool import at module load
    time). The attribute may be absent, an LLMUsage instance, a dict
    payload from an external runtime, or an arbitrary object synthesized
    by a mock — only the first two usable shapes are honored.
    """
    usage_candidate = getattr(tool_runtime, "latest_usage", None)

    # Type-check explicitly: MagicMock returns truthy objects for any
    # attribute name, so truthiness alone is not a valid signal.
    if isinstance(usage_candidate, LLMUsage):
        return usage_candidate
    if isinstance(usage_candidate, dict):
        # Dict payloads from external runtimes are validated into the model.
        return LLMUsage.model_validate(usage_candidate)

    # Missing attribute or unusable payload: treat as "no usage recorded".
    return LLMUsage.empty_usage()
|
|
|
|
@classmethod
|
|
def _extract_variable_selector_to_variable_mapping(
|
|
cls,
|
|
*,
|
|
graph_config: Mapping[str, Any],
|
|
node_id: str,
|
|
node_data: Mapping[str, Any],
|
|
) -> Mapping[str, Sequence[str]]:
|
|
"""
|
|
Extract variable selector to variable mapping
|
|
:param graph_config: graph config
|
|
:param node_id: node id
|
|
:param node_data: node data
|
|
:return:
|
|
"""
|
|
# Create typed NodeData from dict
|
|
typed_node_data = ToolNodeData.model_validate(node_data)
|
|
|
|
result = {}
|
|
for parameter_name in typed_node_data.tool_parameters:
|
|
input = typed_node_data.tool_parameters[parameter_name]
|
|
match input.type:
|
|
case "mixed":
|
|
assert isinstance(input.value, str)
|
|
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
|
for selector in selectors:
|
|
result[selector.variable] = selector.value_selector
|
|
case "variable":
|
|
selector_key = ".".join(input.value)
|
|
result[f"#{selector_key}#"] = input.value
|
|
case "constant":
|
|
pass
|
|
|
|
result = {node_id + "." + key: value for key, value in result.items()}
|
|
|
|
return result
|
|
|
|
@property
def retry(self) -> bool:
    """Whether this node's retry behavior is enabled (per its retry_config)."""
    retry_config = self.node_data.retry_config
    return retry_config.retry_enabled
|