mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
r2
This commit is contained in:
@ -1,3 +1,3 @@
|
||||
from .tool_node import ToolNode
|
||||
from .datasource_node import DatasourceNode
|
||||
|
||||
__all__ = ["DatasourceNode"]
|
||||
|
||||
@ -40,14 +40,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
|
||||
node_data = cast(DatasourceNodeData, self.node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
if not datasource_type:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = datasource_type.value
|
||||
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if not datasource_info:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info = datasource_info.value
|
||||
# get datasource runtime
|
||||
try:
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
|
||||
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
@ -84,47 +89,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
)
|
||||
|
||||
try:
|
||||
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: GetOnlineDocumentPageContentResponse = (
|
||||
datasource_runtime._get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: GetOnlineDocumentPageContentResponse = (
|
||||
datasource_runtime._get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"online_document": online_document_result.result.model_dump(),
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"online_document": online_document_result.result.model_dump(),
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
elif (
|
||||
datasource_runtime.datasource_provider_type in (
|
||||
DatasourceProviderType.WEBSITE_CRAWL,
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
)
|
||||
):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"website": datasource_info,
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"website": datasource_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": datasource_info,
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise DatasourceNodeError(
|
||||
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise DatasourceNodeError(
|
||||
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@ -170,23 +183,24 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, Literal, Union, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
@ -9,30 +9,17 @@ from core.workflow.nodes.base.entities import BaseNodeData
|
||||
class DatasourceEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_name: str # redundancy
|
||||
datasource_name: str
|
||||
tool_label: str # redundancy
|
||||
datasource_configurations: dict[str, Any]
|
||||
provider_type: str
|
||||
datasource_name: Optional[str] = "local_file"
|
||||
datasource_configurations: dict[str, Any] | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: Optional[Union[Any, list[str]]] = None
|
||||
type: Optional[Literal["mixed", "variable", "constant"]] = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
@ -51,4 +38,4 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
datasource_parameters: dict[str, DatasourceInput]
|
||||
datasource_parameters: dict[str, DatasourceInput] | None = None
|
||||
|
||||
@ -19,6 +19,7 @@ from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
KnowledgeIndexNodeError,
|
||||
)
|
||||
from ..base import BaseNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,7 +32,7 @@ default_retrieval_model = {
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeIndexNode(LLMNode):
|
||||
class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
||||
_node_data_cls = KnowledgeIndexNodeData # type: ignore
|
||||
_node_type = NodeType.KNOWLEDGE_INDEX
|
||||
|
||||
@ -44,7 +45,7 @@ class KnowledgeIndexNode(LLMNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error="Query variable is not object type.",
|
||||
error="Index chunk variable is not object type.",
|
||||
)
|
||||
chunks = variable.value
|
||||
variables = {"chunks": chunks}
|
||||
|
||||
@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.answer import AnswerNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.end import EndNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.if_else import IfElseNode
|
||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
@ -119,4 +121,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
LATEST_VERSION: AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
NodeType.DATASOURCE: {
|
||||
LATEST_VERSION: DatasourceNode,
|
||||
"1": DatasourceNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: {
|
||||
LATEST_VERSION: KnowledgeIndexNode,
|
||||
"1": KnowledgeIndexNode,
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user