refactor: tool node decouple db (#33166)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei
2026-03-10 01:47:15 +08:00
committed by GitHub
parent a480e9beb1
commit b9d05d3456
9 changed files with 81 additions and 22 deletions

View File

@ -1,9 +1,6 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
from dify_graph.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
@ -36,7 +32,8 @@ from .exc import (
)
if TYPE_CHECKING:
from dify_graph.runtime import VariablePool
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState, VariablePool
class ToolNode(Node[ToolNodeData]):
@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]):
node_type = NodeType.TOOL
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
tool_file_manager_factory: ToolFileManagerProtocol,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._tool_file_manager_factory = tool_file_manager_factory
@classmethod
def version(cls) -> str:
return "1"
@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]):
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not found")
mapping = {
"tool_file_id": tool_file_id,
@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]):
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,