mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
r2
This commit is contained in:
@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities import (
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.file import File
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.segments import ArrayAnySegment, FileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import DatasourceNodeData
|
||||
@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
provider_id=node_data.provider_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
return NodeRunResult(
|
||||
@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
error=f"Failed to get datasource runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
|
||||
|
||||
# get parameters
|
||||
datasource_parameters = datasource_runtime.entity.parameters
|
||||
@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
},
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=datasource_info.get("type", ""),
|
||||
transfer_method=datasource_info.get("transfer_method", ""),
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
|
||||
for key, value in datasource_info.items():
|
||||
# construct new key list
|
||||
new_key_list = ["file", key]
|
||||
self._append_variables_recursively(
|
||||
variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": datasource_info,
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
"file_info": file_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
case _:
|
||||
raise DatasourceNodeError(
|
||||
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
|
||||
f"Unsupported datasource provider: {datasource_type}"
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
return NodeRunResult(
|
||||
@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
|
||||
def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
variable_pool.add([node_id] + variable_key_list, variable_value)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, dict):
|
||||
for key, value in variable_value.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
self._append_variables_recursively(
|
||||
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
||||
@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel):
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Optional[Union[Any, list[str]]] = None
|
||||
value: Union[Any, list[str]]
|
||||
type: Optional[Literal["mixed", "variable", "constant"]] = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
|
||||
Reference in New Issue
Block a user