commit 7f59ffe7af (parent 5fc2bc58a9)
Author: jyong
Date:   2025-05-28 17:56:04 +08:00

32 changed files with 680 additions and 202 deletions

View File

@@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities (
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError
-from core.variables.segments import ArrayAnySegment
+from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionStatus
from .entities import DatasourceNodeData
@@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                provider_id=node_data.provider_id,
                datasource_name=node_data.datasource_name or "",
                tenant_id=self.tenant_id,
-                datasource_type=DatasourceProviderType(datasource_type),
+                datasource_type=DatasourceProviderType.value_of(datasource_type),
            )
        except DatasourceNodeError as e:
            return NodeRunResult(
@@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
# get parameters
datasource_parameters = datasource_runtime.entity.parameters
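The change above from the bare enum constructor DatasourceProviderType(datasource_type) to DatasourceProviderType.value_of(datasource_type) follows the lookup-classmethod convention used by other enums in this codebase. A minimal sketch of what such a classmethod typically looks like; the member values and the error message here are illustrative, not copied from the source:

from enum import Enum


class DatasourceProviderType(Enum):
    ONLINE_DOCUMENT = "online_document"
    WEBSITE_CRAWL = "website_crawl"
    LOCAL_FILE = "local_file"

    @classmethod
    def value_of(cls, value: str) -> "DatasourceProviderType":
        # Linear scan over the members; raises a readable, domain-specific
        # message instead of the terse ValueError from Enum's constructor.
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f"invalid datasource provider type value {value}")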
@@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
"datasource_type": datasource_type,
},
)
case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
case DatasourceProviderType.WEBSITE_CRAWL:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
@@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                        },
                    )
                case DatasourceProviderType.LOCAL_FILE:
                    upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
                    if not upload_file:
                        raise ValueError("Invalid upload file Info")
                    file_info = File(
                        id=upload_file.id,
                        filename=upload_file.name,
                        extension="." + upload_file.extension,
                        mime_type=upload_file.mime_type,
                        tenant_id=self.tenant_id,
                        type=datasource_info.get("type", ""),
                        transfer_method=datasource_info.get("transfer_method", ""),
                        remote_url=upload_file.source_url,
                        related_id=upload_file.id,
                        size=upload_file.size,
                        storage_key=upload_file.key,
                    )
                    variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
                    for key, value in datasource_info.items():
                        # construct new key list
                        new_key_list = ["file", key]
                        self._append_variables_recursively(
                            variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
                        )
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
                        inputs=parameters_for_log,
                        metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                        outputs={
-                            "file": datasource_info,
-                            "datasource_type": datasource_runtime.datasource_provider_type,
+                            "file_info": file_info,
+                            "datasource_type": datasource_type,
                        },
                    )
                case _:
                    raise DatasourceNodeError(
-                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
+                        f"Unsupported datasource provider: {datasource_type}"
                    )
        except PluginDaemonClientSideError as e:
            return NodeRunResult(
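For local files, the node now resolves the UploadFile row into a core.file.File, registers it in the variable pool as a FileSegment, and mirrors every datasource_info key under the same "file" prefix. A sketch of the selectors this produces, assuming a node id of "node_1" and a hypothetical two-key payload:

# Hypothetical payload, for illustration only.
datasource_info = {"related_id": "<upload-file-id>", "type": "document"}

# Selectors the LOCAL_FILE branch registers, in order:
#   ["node_1", "file"]               -> [FileSegment(value=file_info)]
#   ["node_1", "file", "related_id"] -> "<upload-file-id>"
#   ["node_1", "file", "type"]       -> "document"
# Downstream nodes can therefore reference either the whole file object or
# any individual datasource_info field through the variable pool.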
@@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
        variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
        return list(variable.value) if variable else []

    def _append_variables_recursively(
        self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
    ):
        """
        Append variables recursively
        :param variable_pool: variable pool to write into
        :param node_id: node id
        :param variable_key_list: variable key list
        :param variable_value: variable value
        :return:
        """
        variable_pool.add([node_id] + variable_key_list, variable_value)
        # if variable_value is a dict, then recursively append variables
        if isinstance(variable_value, dict):
            for key, value in variable_value.items():
                # construct new key list
                new_key_list = variable_key_list + [key]
                self._append_variables_recursively(
                    variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
                )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
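_append_variables_recursively writes each value at its own selector and, for dict values, recurses so every nested level becomes addressable on its own. A standalone re-creation of the traversal that returns selector/value pairs instead of writing into a VariablePool; the function name and sample data are hypothetical:

def append_recursively(node_id: str, key_list: list[str], value) -> list[tuple[list[str], object]]:
    # Record the value at its own selector first, mirroring
    # variable_pool.add([node_id] + variable_key_list, variable_value).
    entries: list[tuple[list[str], object]] = [([node_id] + key_list, value)]
    if isinstance(value, dict):
        # Then recurse into dict values so nested keys get selectors too.
        for key, child in value.items():
            entries += append_recursively(node_id, key_list + [key], child)
    return entries


for selector, v in append_recursively("node_1", ["file", "info"], {"name": "a.txt", "meta": {"size": 42}}):
    print(selector, "->", v)
# ['node_1', 'file', 'info'] -> {'name': 'a.txt', 'meta': {'size': 42}}
# ['node_1', 'file', 'info', 'name'] -> a.txt
# ['node_1', 'file', 'info', 'meta'] -> {'size': 42}
# ['node_1', 'file', 'info', 'meta', 'size'] -> 42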

View File

@@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel):
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
    class DatasourceInput(BaseModel):
        # TODO: check this type
-        value: Optional[Union[Any, list[str]]] = None
+        value: Union[Any, list[str]]
        type: Optional[Literal["mixed", "variable", "constant"]] = None

        @field_validator("type", mode="before")
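As the TODO above hints, the annotation still deserves scrutiny: a union containing Any collapses to Any for type checkers, so Union[Any, list[str]] does not actually constrain the field; the net effect of this change is only that the field becomes required under Pydantic v2 once the "= None" default is dropped. A quick check with a trimmed stand-in model (not the full class):

from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, ValidationError


class DatasourceInput(BaseModel):
    value: Union[Any, list[str]]  # equivalent to `value: Any`, but now required
    type: Optional[Literal["mixed", "variable", "constant"]] = None


print(DatasourceInput(value=123).value)  # 123 -- non-list values validate too
try:
    DatasourceInput()
except ValidationError:
    print("value is required now that the default is gone")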