mirror of https://github.com/langgenius/dify.git
r2
@@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities import (
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
+from core.file import File
 from core.plugin.impl.exc import PluginDaemonClientSideError
-from core.variables.segments import ArrayAnySegment
+from core.variables.segments import ArrayAnySegment, FileSegment
 from core.variables.variables import ArrayAnyVariable
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities.variable_pool import VariablePool, VariableValue
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
+from models.model import UploadFile
 from models.workflow import WorkflowNodeExecutionStatus

 from .entities import DatasourceNodeData
@@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 provider_id=node_data.provider_id,
                 datasource_name=node_data.datasource_name or "",
                 tenant_id=self.tenant_id,
-                datasource_type=DatasourceProviderType(datasource_type),
+                datasource_type=DatasourceProviderType.value_of(datasource_type),
             )
         except DatasourceNodeError as e:
             return NodeRunResult(
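Side note on the `DatasourceProviderType.value_of(datasource_type)` call introduced above: this swaps the bare `Enum(value)` constructor for the `value_of` lookup helper style that the diff itself relies on. A minimal sketch of such a helper, with illustrative members and an assumed error message (not the actual dify implementation):

from enum import Enum

class DatasourceProviderType(Enum):
    # Illustrative members and values only; the real enum lives in
    # core.datasource.entities.datasource_entities.
    ONLINE_DOCUMENT = "online_document"
    WEBSITE_CRAWL = "website_crawl"
    LOCAL_FILE = "local_file"

    @classmethod
    def value_of(cls, value: str) -> "DatasourceProviderType":
        # Look up the member by its string value and fail with a readable error.
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f"invalid datasource provider type value {value}")

# Usage: DatasourceProviderType.value_of("local_file") -> DatasourceProviderType.LOCAL_FILE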
@@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 error=f"Failed to get datasource runtime: {str(e)}",
                 error_type=type(e).__name__,
             )

-
+
         # get parameters
         datasource_parameters = datasource_runtime.entity.parameters
@@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                            "datasource_type": datasource_type,
                        },
                    )
-            case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
+            case DatasourceProviderType.WEBSITE_CRAWL:
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
                        inputs=parameters_for_log,
@@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                        },
                    )
+            case DatasourceProviderType.LOCAL_FILE:
+                upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
+                if not upload_file:
+                    raise ValueError("Invalid upload file Info")
+
+                file_info = File(
+                    id=upload_file.id,
+                    filename=upload_file.name,
+                    extension="." + upload_file.extension,
+                    mime_type=upload_file.mime_type,
+                    tenant_id=self.tenant_id,
+                    type=datasource_info.get("type", ""),
+                    transfer_method=datasource_info.get("transfer_method", ""),
+                    remote_url=upload_file.source_url,
+                    related_id=upload_file.id,
+                    size=upload_file.size,
+                    storage_key=upload_file.key,
+                )
+                variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
+                for key, value in datasource_info.items():
+                    # construct new key list
+                    new_key_list = ["file", key]
+                    self._append_variables_recursively(
+                        variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
+                    )
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=parameters_for_log,
                     metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                     outputs={
-                        "file": datasource_info,
-                        "datasource_type": datasource_runtime.datasource_provider_type,
+                        "file_info": file_info,
+                        "datasource_type": datasource_type,
                     },
                 )
             case _:
                 raise DatasourceNodeError(
-                    f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
+                    f"Unsupported datasource provider: {datasource_type}"
                 )
         except PluginDaemonClientSideError as e:
             return NodeRunResult(
@@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []

+
+    def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
+        """
+        Append variables recursively
+        :param node_id: node id
+        :param variable_key_list: variable key list
+        :param variable_value: variable value
+        :return:
+        """
+        variable_pool.add([node_id] + variable_key_list, variable_value)
+
+        # if variable_value is a dict, then recursively append variables
+        if isinstance(variable_value, dict):
+            for key, value in variable_value.items():
+                # construct new key list
+                new_key_list = variable_key_list + [key]
+                self._append_variables_recursively(
+                    variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
+                )
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
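The `_append_variables_recursively` helper added in the hunk above flattens nested datasource metadata so that every leaf value gets its own selector in the variable pool. A self-contained sketch of the same flattening idea, using a plain dict keyed by selector tuples as a stand-in for the real VariablePool API (names and the example payload are illustrative assumptions):

from typing import Any

# Stand-in for the variable pool: maps a selector tuple to a value.
pool: dict[tuple[str, ...], Any] = {}

def append_recursively(node_id: str, key_list: list[str], value: Any) -> None:
    # Register the value under its full selector, e.g. ("node_1", "file", "name").
    pool[(node_id, *key_list)] = value
    # Dict values are also exposed one level deeper, so nested fields become
    # addressable as their own selectors (e.g. node_1.file.meta.page_count).
    if isinstance(value, dict):
        for key, child in value.items():
            append_recursively(node_id, key_list + [key], child)

datasource_info = {"name": "report.pdf", "size": 1024, "meta": {"page_count": 3}}
for key, value in datasource_info.items():
    append_recursively("node_1", ["file", key], value)

assert pool[("node_1", "file", "meta", "page_count")] == 3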
@@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel):
 class DatasourceNodeData(BaseNodeData, DatasourceEntity):
     class DatasourceInput(BaseModel):
         # TODO: check this type
-        value: Optional[Union[Any, list[str]]] = None
+        value: Union[Any, list[str]]
         type: Optional[Literal["mixed", "variable", "constant"]] = None

         @field_validator("type", mode="before")
@@ -39,15 +39,30 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
     def _run(self) -> NodeRunResult: # type: ignore
         node_data = cast(KnowledgeIndexNodeData, self.node_data)
         variable_pool = self.graph_runtime_state.variable_pool
+        dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
+        if not dataset_id:
+            raise KnowledgeIndexNodeError("Dataset ID is required.")
+        dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
+        if not dataset:
+            raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
+
         # extract variables
         variable = variable_pool.get(node_data.index_chunk_variable_selector)
-        is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER
         if not variable:
             raise KnowledgeIndexNodeError("Index chunk variable is required.")
+        invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
+        if invoke_from:
+            is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
+        else:
+            is_preview = False
         chunks = variable.value
         variables = {"chunks": chunks}
         if not chunks:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
             )
+        outputs = self._get_preview_output(dataset.chunk_structure, chunks)

         # retrieve knowledge
         try:
             if is_preview:
@@ -55,12 +70,12 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=variables,
                     process_data=None,
-                    outputs={"result": "success"},
+                    outputs=outputs,
                 )
-            results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool)
-            outputs = {"result": results}
+            results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
+                                                   variable_pool=variable_pool)
             return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
+                status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
             )

         except KnowledgeIndexNodeError as e:
@@ -81,24 +96,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
             )

     def _invoke_knowledge_index(
-        self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool
+        self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
+        variable_pool: VariablePool
     ) -> Any:
-        dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
-        if not dataset_id:
-            raise KnowledgeIndexNodeError("Dataset ID is required.")
         document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
         if not document_id:
             raise KnowledgeIndexNodeError("Document ID is required.")
         batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
         if not batch:
             raise KnowledgeIndexNodeError("Batch is required.")
-        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
-        if not dataset:
-            raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
-
-        document = db.session.query(Document).filter_by(id=document_id).first()
+        document = db.session.query(Document).filter_by(id=document_id.value).first()
         if not document:
-            raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
+            raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")

         index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
         index_processor.index(dataset, document, chunks)
@@ -106,14 +115,19 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
         # update document status
         document.indexing_status = "completed"
         document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.add(document)
         db.session.commit()

         return {
             "dataset_id": dataset.id,
             "dataset_name": dataset.name,
-            "batch": batch,
+            "batch": batch.value,
             "document_id": document.id,
             "document_name": document.name,
-            "created_at": document.created_at,
+            "created_at": document.created_at.timestamp(),
             "display_status": document.indexing_status,
         }
+
+    def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
+        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
+        return index_processor.format_preview(chunks)
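One detail worth noting in the hunk above: values fetched from the variable pool are segment objects rather than raw values, and node outputs should be plain JSON-friendly types, which is why the result dict unwraps `batch.value` and converts `created_at` with `.timestamp()`. A tiny illustration of the same unwrapping (the Segment dataclass here is a simplified stand-in, not the real dify segment types):

import datetime
import json
from dataclasses import dataclass
from typing import Any

@dataclass
class Segment:
    # Simplified stand-in for a variable-pool segment wrapper.
    value: Any

batch = Segment(value="20240601-batch")
created_at = datetime.datetime(2024, 6, 1, 12, 0, tzinfo=datetime.timezone.utc)

# Unwrap the segment and convert the datetime so the payload serializes cleanly.
outputs = {"batch": batch.value, "created_at": created_at.timestamp()}
print(json.dumps(outputs))  # {"batch": "20240601-batch", "created_at": 1717243200.0}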