from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.repositories.datasource_manager_protocol import (
    DatasourceManagerProtocol,
    DatasourceParameter,
    OnlineDriveDownloadFileParam,
)

from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError

if TYPE_CHECKING:
    from dify_graph.entities import GraphInitParams
    from dify_graph.runtime import GraphRuntimeState


class DatasourceNode(Node[DatasourceNodeData]):
    """
    Datasource node: fetches content from a configured datasource provider
    (online document, online drive, website crawl, or local file).
    """

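    # Runs as a ROOT node, i.e. an entry point of the graph.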
    node_type = NodeType.DATASOURCE
    execution_type = NodeExecutionType.ROOT

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        datasource_manager: DatasourceManagerProtocol,
    ):
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        self.datasource_manager = datasource_manager

    def _run(self) -> Generator:
        """
        Run the datasource node
        """
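        # DATASOURCE_TYPE and DATASOURCE_INFO are system variables placed in the
        # variable pool before this node runs. For illustration (shape inferred from
        # the reads below), datasource_info for an online document may look like:
        #   {"workspace_id": "...", "page": {"page_id": "...", "type": "page"},
        #    "credential_id": "..."}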
        node_data = self.node_data
        variable_pool = self.graph_runtime_state.variable_pool
        datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
        if not datasource_type_segment:
            raise DatasourceNodeError("Datasource type is not set")
        datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
        datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
        if not datasource_info_segment:
            raise DatasourceNodeError("Datasource info is not set")
        datasource_info_value = datasource_info_segment.value
        if not isinstance(datasource_info_value, dict):
            raise DatasourceNodeError("Invalid datasource info format")
        datasource_info: dict[str, Any] = datasource_info_value

        if datasource_type is None:
            raise DatasourceNodeError("Datasource type is not set")

        datasource_type = DatasourceProviderType.value_of(datasource_type)
        provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
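        # Attach the provider's icon URL to the datasource info.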
datasource_info["icon"] = self.datasource_manager.get_icon_url(
|
|
provider_id=provider_id,
|
|
datasource_name=node_data.datasource_name or "",
|
|
tenant_id=self.tenant_id,
|
|
datasource_type=datasource_type.value,
|
|
)
|
|
|
|
parameters_for_log = datasource_info
|
|
|
|
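        # Dispatch on the provider type; each branch yields the node's run events.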
        try:
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
                    # Build typed request objects
                    datasource_parameters = None
                    if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
                        datasource_parameters = DatasourceParameter(
                            workspace_id=datasource_info.get("workspace_id", ""),
                            page_id=datasource_info.get("page", {}).get("page_id", ""),
                            type=datasource_info.get("page", {}).get("type", ""),
                        )

                    online_drive_request = None
                    if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
                        online_drive_request = OnlineDriveDownloadFileParam(
                            id=datasource_info.get("id", ""),
                            bucket=datasource_info.get("bucket", ""),
                        )

                    credential_id = datasource_info.get("credential_id", "")
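                    # Delegate to the datasource manager, which streams the node's
                    # events back through this generator.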
                    yield from self.datasource_manager.stream_node_events(
                        node_id=self._node_id,
                        user_id=self.user_id,
                        datasource_name=node_data.datasource_name or "",
                        datasource_type=datasource_type.value,
                        provider_id=provider_id,
                        tenant_id=self.tenant_id,
                        provider=node_data.provider_name,
                        plugin_id=node_data.plugin_id,
                        credential_id=credential_id,
                        parameters_for_log=parameters_for_log,
                        datasource_info=datasource_info,
                        variable_pool=variable_pool,
                        datasource_param=datasource_parameters,
                        online_drive_request=online_drive_request,
                    )
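                # Website crawl results already live in datasource_info, so the node
                # succeeds immediately and echoes them as outputs.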
                case DatasourceProviderType.WEBSITE_CRAWL:
                    yield StreamCompletedEvent(
                        node_run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                **datasource_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
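                # Local files are uploaded beforehand; resolve the upload record and
                # expose it to downstream nodes as this node's "file" variable.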
                case DatasourceProviderType.LOCAL_FILE:
                    related_id = datasource_info.get("related_id")
                    if not related_id:
                        raise DatasourceNodeError("File does not exist")

                    file_info = self.datasource_manager.get_upload_file_by_id(
                        file_id=related_id, tenant_id=self.tenant_id
                    )
                    variable_pool.add([self._node_id, "file"], file_info)
                    yield StreamCompletedEvent(
                        node_run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                "file": file_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
                case _:
                    raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
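        # Known failures are converted into failed node results instead of
        # propagating, so the caller still receives a terminal event.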
        except PluginDaemonClientSideError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
        except DatasourceNodeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to invoke datasource: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return: mapping from node-scoped variable names to their value selectors
        """
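        # Map each parameter that references other variables to its value selector,
        # so the engine can resolve this node's upstream dependencies.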
        typed_node_data = DatasourceNodeData.model_validate(node_data)
        result: dict[str, Sequence[str]] = {}
        if typed_node_data.datasource_parameters:
            for parameter_name in typed_node_data.datasource_parameters:
                parameter_input = typed_node_data.datasource_parameters[parameter_name]
                match parameter_input.type:
                    case "mixed":
                        # A template string may embed one or more variable selectors.
                        assert isinstance(parameter_input.value, str)
                        selectors = VariableTemplateParser(parameter_input.value).extract_variable_selectors()
                        for selector in selectors:
                            result[selector.variable] = selector.value_selector
                    case "variable":
                        # The parameter directly references another variable.
                        result[parameter_name] = parameter_input.value
                    case "constant" | None:
                        # Constants contribute no variable dependencies.
                        pass
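        # Prefix keys with the node id so mappings from different nodes cannot collide.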
        result = {node_id + "." + key: value for key, value in result.items()}

        return result

    @classmethod
    def version(cls) -> str:
        return "1"