This commit is contained in:
jyong
2025-05-25 23:09:01 +08:00
parent 0f10852b6b
commit ec1c4efca9
12 changed files with 147 additions and 110 deletions

View File

@@ -1,3 +1,3 @@
from .tool_node import ToolNode
from .datasource_node import DatasourceNode
__all__ = ["DatasourceNode"]

View File

@@ -40,14 +40,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
node_data = cast(DatasourceNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
if not datasource_type:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = datasource_type.value
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
if not datasource_info:
raise DatasourceNodeError("Datasource info is not set")
datasource_info = datasource_info.value
# get datasource runtime
try:
from core.datasource.datasource_manager import DatasourceManager
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
@@ -84,47 +89,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
try:
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPageContentResponse = (
datasource_runtime._get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
provider_type=datasource_runtime.datasource_provider_type(),
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPageContentResponse = (
datasource_runtime._get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
provider_type=datasource_type,
)
)
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_type,
},
)
)
)
elif (
datasource_runtime.datasource_provider_type in (
DatasourceProviderType.WEBSITE_CRAWL,
DatasourceProviderType.LOCAL_FILE,
)
):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": datasource_info,
"datasource_type": datasource_runtime.datasource_provider_type,
},
case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": datasource_info,
"datasource_type": datasource_type,
},
)
)
case DatasourceProviderType.LOCAL_FILE:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": datasource_info,
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
)
)
else:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -170,23 +183,24 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.datasource_parameters:
parameter = datasource_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
datasource_input = node_data.datasource_parameters[parameter_name]
if datasource_input.type == "variable":
variable = variable_pool.get(datasource_input.value)
if variable is None:
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
parameter_value = variable.value
elif datasource_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(datasource_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
result[parameter_name] = parameter_value
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
parameter = datasource_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
datasource_input = node_data.datasource_parameters[parameter_name]
if datasource_input.type == "variable":
variable = variable_pool.get(datasource_input.value)
if variable is None:
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
parameter_value = variable.value
elif datasource_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(datasource_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
result[parameter_name] = parameter_value
return result

View File

@@ -1,4 +1,4 @@
from typing import Any, Literal, Union
from typing import Any, Literal, Union, Optional
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
@@ -9,30 +9,17 @@ from core.workflow.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel):
provider_id: str
provider_name: str # redundancy
datasource_name: str
tool_label: str # redundancy
datasource_configurations: dict[str, Any]
provider_type: str
datasource_name: Optional[str] = "local_file"
datasource_configurations: dict[str, Any] | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
value: Optional[Union[Any, list[str]]] = None
type: Optional[Literal["mixed", "variable", "constant"]] = None
@field_validator("type", mode="before")
@classmethod
@@ -51,4 +38,4 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
raise ValueError("value must be a string, int, float, or bool")
return typ
datasource_parameters: dict[str, DatasourceInput]
datasource_parameters: dict[str, DatasourceInput] | None = None

View File

@@ -19,6 +19,7 @@ from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
from ..base import BaseNode
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ default_retrieval_model = {
}
class KnowledgeIndexNode(LLMNode):
class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
_node_data_cls = KnowledgeIndexNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_INDEX
@@ -44,7 +45,7 @@ class KnowledgeIndexNode(LLMNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not object type.",
error="Index chunk variable is not object type.",
)
chunks = variable.value
variables = {"chunks": chunks}

View File

@@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end import EndNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
@@ -119,4 +121,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: AgentNode,
"1": AgentNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
}