refactor: replace BuiltinToolManageService with RagPipelineManageService for datasource management and remove unused datasource engine and related code

This commit is contained in:
Yeuoly
2025-05-16 18:42:07 +08:00
parent 8bea88c8cc
commit c5a2f43ceb
22 changed files with 390 additions and 1496 deletions

View File

@@ -1,35 +1,24 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.datasource.datasource_engine import DatasourceEngine
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
from core.datasource.errors import DatasourceInvokeError
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.datasource.entities.datasource_entities import (
DatasourceParameter,
)
from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(BaseNode[DatasourceNodeData]):
@@ -49,7 +38,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
# fetch datasource icon
datasource_info = {
"provider_type": node_data.provider_type.value,
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
}
@@ -58,8 +46,10 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
try:
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=node_data.provider_id,
datasource_name=node_data.datasource_name,
tenant_id=self.tenant_id,
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
@@ -74,7 +64,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return
# get parameters
datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
datasource_parameters = datasource_runtime.entity.parameters
parameters = self._generate_parameters(
datasource_parameters=datasource_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
@@ -91,15 +81,20 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = DatasourceEngine.generic_invoke(
datasource=datasource_runtime,
datasource_parameters=parameters,
# TODO: handle result
result = datasource_runtime._invoke_second_step(
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
datasource_parameters=parameters,
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
@@ -113,20 +108,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
return
try:
# convert datasource messages
yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
)
def _generate_parameters(
self,
*,
@@ -175,200 +156,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _transform_message(
self,
messages: Generator[DatasourceInvokeMessage, None, None],
datasource_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
) -> Generator:
"""
Consume the datasource invoke message stream and re-emit it as workflow events.

Yields RunStreamChunkEvent for streamed text and streamed variables,
AgentLogEvent for LOG messages, and finally one RunCompletedEvent whose
outputs aggregate the accumulated text, files, JSON payloads and variables.

:param messages: raw DatasourceInvokeMessage stream from the datasource runtime
:param datasource_info: provider metadata recorded into node run metadata
:param parameters_for_log: sanitized inputs echoed back in the final result
:raises ToolFileError: when a referenced ToolFile row does not exist
:raises ValueError: when a streamed VARIABLE message carries a non-string value
"""
# transform message and handle file storage
# (persists blobs/links into tool files before we interpret the stream)
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
# Accumulators folded into the final RunCompletedEvent outputs.
text = ""
files: list[File] = []
json: list[dict] = []  # NOTE: shadows the stdlib `json` module inside this method
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
variables: dict[str, Any] = {}
for message in message_stream:
# --- file-link messages: resolve the ToolFile row and build a File ---
if message.type in {
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
DatasourceInvokeMessage.MessageType.BINARY_LINK,
DatasourceInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
# tool_file_id is the last URL path segment with its extension stripped
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
# --- blob messages: already stored as a tool file by the transformer ---
elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
# --- plain text: accumulate and stream to the "text" output selector ---
elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
# --- JSON payloads; AGENT nodes may embed execution metadata ---
elif message.type == DatasourceInvokeMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
# pop execution_metadata out of the payload and keep only keys
# that are valid NodeRunMetadataKey members
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value
for key, value in msg_metadata.items()
if key in NodeRunMetadataKey.__members__.values()
}
json.append(message.message.json_object)
# --- links are rendered into the text stream ---
elif message.type == DatasourceInvokeMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
# --- named variables; streamed ones must be string chunks ---
elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
# --- pre-built File objects arrive via message.meta ---
elif message.type == DatasourceInvokeMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
# --- log messages become (possibly updated) AgentLogEvents ---
elif message.type == DatasourceInvokeMessage.MessageType.LOG:
assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
if message.message.metadata:
icon = datasource_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
# try to resolve a provider icon: first from installed plugins,
# then from builtin tool providers; fall back to the default icon
manager = PluginInstallationManager()
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id,
self.tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
except StopIteration:
pass
dict_metadata["icon"] = icon
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=self.node_id,
)
# check if the agent log is already in the list
# (logs are keyed by id; a repeated id updates the earlier entry in place)
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
# re-emit even when updating, so downstream sees the latest state
yield agent_log
# final event: everything accumulated above becomes the node outputs
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
metadata={
**agent_execution_metadata,
NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -3,17 +3,15 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel):
provider_id: str
provider_type: ToolProviderType
provider_name: str # redundancy
tool_name: str
datasource_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
datasource_configurations: dict[str, Any]
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")

View File

@@ -1,7 +1,8 @@
import datetime
import logging
import time
from typing import Any, cast, Mapping
from collections.abc import Mapping
from typing import Any, cast
from flask_login import current_user