mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
refactor: tool node decouple db (#33166)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,9 +1,6 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from dify_graph.variables.variables import ArrayAnyVariable
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
@ -36,7 +32,8 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
|
||||
class ToolNode(Node[ToolNodeData]):
|
||||
@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
tool_file_manager_factory: ToolFileManagerProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not found")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
|
||||
Reference in New Issue
Block a user