Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox

# Conflicts:
#	api/core/file/file_manager.py
#	api/core/workflow/graph_engine/response_coordinator/coordinator.py
#	api/core/workflow/nodes/llm/node.py
#	api/core/workflow/nodes/tool/tool_node.py
#	api/pyproject.toml
#	web/package.json
#	web/pnpm-lock.yaml
Harry committed 2026-02-04 13:15:49 +08:00
131 changed files with 7256 additions and 3245 deletions

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import sys
from pydantic import TypeAdapter, with_config
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigData(TypedDict):
type: str
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
data: NodeConfigData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
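
A minimal usage sketch of the new adapter (the import path is inferred from the graph.py hunk below; extra="allow" means unrecognized keys are preserved rather than rejected):

    from core.workflow.entities.graph_config import NodeConfigDictAdapter

    raw = {"id": "start-1", "data": {"type": "start", "title": "Start"}}
    # validate_python enforces the required "id" key and the nested "type" key,
    # while the with_config(extra="allow") decorators keep extras like "title".
    node_config = NodeConfigDictAdapter.validate_python(raw)
    assert node_config["data"]["type"] == "start"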

View File

@@ -5,15 +5,20 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from pydantic import TypeAdapter
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from libs.typing import is_str
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
class NodeFactory(Protocol):
"""
@@ -23,7 +28,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: dict[str, object]) -> Node:
def create_node(self, node_config: NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data.
@@ -63,28 +68,24 @@ class Graph:
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, dict[str, object]] = {}
node_configs_map: dict[str, NodeConfigDict] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
node_configs_map[node_config["id"]] = node_config
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, Mapping[str, object]],
node_configs_map: Mapping[str, NodeConfigDict],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
@@ -113,10 +114,8 @@ class Graph:
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data")
if not is_str_dict(node_data):
continue
node_type = node_data.get("type")
node_data = node_configs_map[nid]["data"]
node_type = node_data["type"]
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
@@ -176,7 +175,7 @@ class Graph:
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_configs_map: dict[str, NodeConfigDict],
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
@@ -303,7 +302,7 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")

View File

@@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool
if TYPE_CHECKING:
@@ -90,7 +89,7 @@ class GraphEngine:
self._graph_execution.workflow_id = workflow_id
# === Execution Queues ===
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
self._ready_queue = self._graph_runtime_state.ready_queue
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

View File

@@ -25,6 +25,7 @@ from core.workflow.graph_events import (
)
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .path import Path
from .session import ResponseSession
@@ -81,7 +82,7 @@ class ResponseStreamCoordinator:
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
"""
Initialize coordinator with variable pool.

View File

@@ -10,10 +10,10 @@ from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.runtime.graph_runtime_state import NodeProtocol
@dataclass
@@ -29,21 +29,26 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> ResponseSession:
def from_node(cls, node: NodeProtocol) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
Args:
node: Must be either an AnswerNode or EndNode instance
node: Node from the materialized workflow graph.
Returns:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node is not an AnswerNode or EndNode
TypeError: If node is not a supported response node type.
"""
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
return cls(
node_id=node.id,
template=node.get_streaming_template(),
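
As the docstring explains, the annotation widens to the protocol type, so the isinstance gate is what recovers the concrete classes at runtime. A reduced sketch of the pattern (stand-in class, with Template shrunk to str):

    from dataclasses import dataclass

    @dataclass
    class AnswerNode:
        id: str

        def get_streaming_template(self) -> str:
            return "answer template"

    def from_node(node: object) -> str:
        # Accept the wide, protocol-level type, then narrow before use,
        # mirroring ResponseSession.from_node above.
        if not isinstance(node, AnswerNode):
            raise TypeError("only response-capable nodes are supported")
        return node.get_streaming_template()

    print(from_node(AnswerNode("n1")))  # answer template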

View File

@@ -205,32 +205,33 @@ class AgentNode(Node[AgentNodeData]):
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
else:
raise AgentInputTypeError(agent_input.type)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
@@ -387,12 +388,13 @@ class AgentNode(Node[AgentNodeData]):
result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
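
The refactor above is effectively behavior-preserving: the if/elif ladders become match statements, with or-patterns grouping branches that share a body. A generic sketch of the construct (input types reduced to plain strings):

    def handling_for(input_type: str) -> str:
        match input_type:
            case "variable":
                return "resolve from the variable pool"
            case "mixed" | "constant":  # or-pattern: one body for both literals
                return "render through convert_template"
            case _:  # mirrors the AgentInputTypeError fallthrough
                raise ValueError(f"unsupported input type: {input_type!r}")

    print(handling_for("constant"))  # render through convert_template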

View File

@@ -115,7 +115,7 @@ class DefaultValue(BaseModel):
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators = {
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Literal, Self
from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
@@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, Self] | None = None
children: dict[str, "CodeNodeData.Output"] | None = None
class Dependency(BaseModel):
name: str
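
Replacing Self with an explicit forward reference keeps the recursive children field resolvable by name even though Output is a nested class. A self-contained sketch of the recursive model (SegmentType simplified to str):

    from pydantic import BaseModel

    class Output(BaseModel):
        type: str  # stands in for the validated SegmentType field
        children: dict[str, "Output"] | None = None

    out = Output.model_validate(
        {"type": "object", "children": {"name": {"type": "string"}}}
    )
    assert out.children is not None
    assert isinstance(out.children["name"], Output)  # nested dicts become Output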

View File

@@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
datasource_type=datasource_type,
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
@@ -268,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
case "constant":
pass
case None:
pass
result = {node_id + "." + key: value for key, value in result.items()}
@@ -306,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, DatasourceMessage.TextMessage)
match message.type:
case (
DatasourceMessage.MessageType.IMAGE_LINK
| DatasourceMessage.MessageType.BINARY_LINK
| DatasourceMessage.MessageType.IMAGE
):
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
elif message.type == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
elif message.type == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
files.append(file)
case DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
case DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
case DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
case DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
case DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case (
DatasourceMessage.MessageType.BLOB_CHUNK
| DatasourceMessage.MessageType.LOG
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
):
pass
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],

View File

@@ -2,7 +2,7 @@ import base64
import json
import secrets
import string
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from copy import deepcopy
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
@@ -11,9 +11,9 @@ import httpx
from json_repair import repair_json
from configs import dify_config
from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
@@ -79,8 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol = ssrf_proxy,
file_manager: FileManagerProtocol = file_manager,
http_client: HttpClientProtocol | None = None,
file_manager: FileManagerProtocol | None = None,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -107,8 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
self._http_client = http_client
self._file_manager = file_manager
self._http_client = http_client or ssrf_proxy
self._file_manager = file_manager or default_file_manager
# init template
self.variable_pool = variable_pool
@@ -336,7 +336,7 @@ class Executor:
"""
do http request depending on api bundle
"""
_METHOD_MAP = {
_METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
"get": self._http_client.get,
"head": self._http_client.head,
"post": self._http_client.post,
@@ -348,7 +348,7 @@ class Executor:
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
request_args: dict[str, Any] = {
"data": self.data,
"files": self.files,
"json": self.json,
@@ -361,14 +361,13 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response: httpx.Response = _METHOD_MAP[method_lc](
response = _METHOD_MAP[method_lc](
url=self.url,
**request_args,
max_retries=self.max_retries,
)
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response
def invoke(self) -> Response:
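
Switching the constructor defaults from module objects to None sentinels defers the collaborator lookup from function-definition time to construction time, a common pattern for keeping test doubles and monkeypatching effective. A generic sketch of the pattern (illustrative names, not the Dify classes):

    from typing import Protocol

    class Client(Protocol):
        def get(self, url: str) -> str: ...

    class _DefaultClient:
        def get(self, url: str) -> str:
            return f"GET {url}"

    default_client = _DefaultClient()

    class Executor:
        def __init__(self, client: Client | None = None) -> None:
            # Resolved per instance: rebinding default_client (e.g. in a test
            # fixture) is picked up, unlike a client=default_client default.
            self._client = client or default_client

    class StubClient:
        def get(self, url: str) -> str:
            return "stubbed"

    print(Executor()._client.get("/x"))              # GET /x
    print(Executor(StubClient())._client.get("/x"))  # stubbed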

View File

@@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.file import File, FileTransferMethod
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
http_client: HttpClientProtocol = ssrf_proxy,
http_client: HttpClientProtocol | None = None,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol = file_manager,
file_manager: FileManagerProtocol | None = None,
) -> None:
super().__init__(
id=id,
@@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_client = http_client
self._http_client = http_client or ssrf_proxy
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
self._file_manager = file_manager or default_file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:

View File

@@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return outputs
# Check if all non-None outputs are lists
non_none_outputs = [output for output in outputs if output is not None]
non_none_outputs: list[object] = [output for output in outputs if output is not None]
if not non_none_outputs:
return outputs

View File

@@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
# Try to get document language if document_id is available
doc_language = None
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document and document.doc_language:
doc_language = document.doc_language
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
doc_language=doc_language,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
doc_language: str | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
@@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
doc_language: Optional document language to ensure summary is generated in the correct language
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
@@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
@@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary

View File

@@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
case "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
else:
reranking_model = None
weights = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
case _:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id,
tenant_id=self.tenant_id,
@@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
)
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None, usage
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
case _:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions

View File

@@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "name":
return lambda x: x.filename or ""
case "type":
return lambda x: x.type
return lambda x: str(x.type)
case "extension":
return lambda x: x.extension or ""
case "mime_type":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: x.transfer_method
return lambda x: str(x.transfer_method)
case "url":
return lambda x: x.remote_url or ""
case "related_id":
@@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[File], bool]:
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str):
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
extract_number = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
else:
raise InvalidKeyError(f"Invalid key: {key}")

View File

@@ -1288,18 +1288,16 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
prompt_content_type = type(prompt_content)
if prompt_content_type == str:
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
@@ -1309,13 +1307,12 @@ class LLMNode(Node[LLMNodeData]):
# Add current query to the prompt message
if sys_query:
if prompt_content_type == str:
if isinstance(prompt_content, str):
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
@@ -1481,13 +1478,14 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
for item in prompt_template:
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type == "jinja2":
enable_jinja = True
else:
for prompt in prompt_template:
if prompt.edition_type == "jinja2":
enable_jinja = True
break
else:
enable_jinja = True
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
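
The #histories# rework above also swaps type(...) comparisons for isinstance checks; besides reading better, isinstance is what type checkers narrow on, which is likely why the defensive prompt_content re-assignments could be dropped. A minimal sketch:

    def insert_histories(content: str | list[str], memory_text: str) -> str | list[str]:
        if isinstance(content, str):
            # content is narrowed to str here; no str() coercion required
            if "#histories#" in content:
                return content.replace("#histories#", memory_text)
            return memory_text + "\n" + content
        # ...and narrowed to list[str] here
        return [part.replace("#histories#", memory_text) for part in content]

    print(insert_histories("Q: #histories#", "A: earlier turns"))  # Q: A: earlier turns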

View File

@@ -1,4 +1,4 @@
from typing import Protocol
from typing import Any, Protocol
import httpx
@@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol):
@property
def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
class FileManagerProtocol(Protocol):
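
Typing the kwargs as Any rather than object lets implementations forward arbitrary keyword arguments straight into httpx without per-value narrowing. A reduced sketch (FakeClient is illustrative, not part of the codebase):

    from typing import Any, Protocol

    import httpx

    class HttpClientProtocol(Protocol):
        def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...

    class FakeClient:
        def get(self, url: str, max_retries: int = 3, **kwargs: Any) -> httpx.Response:
            # With Any-typed kwargs the values can be handed on as-is; with
            # object-typed kwargs each value would need narrowing before use.
            return httpx.Response(200, request=httpx.Request("GET", url))

    client: HttpClientProtocol = FakeClient()
    print(client.get("https://example.com", headers={"X-Debug": "1"}).status_code)  # 200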

View File

@@ -513,25 +513,26 @@ class ToolNode(Node[ToolNodeData]):
result: dict[str, Sequence[str]] = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
if isinstance(input.value, list):
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "nested_node":
# Nested node type: extract variable selector from nested_node_config
# The full selector is extractor_node_id + output_selector
if input.nested_node_config is not None:
config = input.nested_node_config
full_selector = [config.extractor_node_id] + list(config.output_selector)
selector_key = ".".join(full_selector)
result[f"#{selector_key}#"] = full_selector
elif input.type == "constant":
pass
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
if isinstance(input.value, list):
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
case "nested_node":
# Nested node type: extract variable selector from nested_node_config
# The full selector is extractor_node_id + output_selector
if input.nested_node_config is not None:
config = input.nested_node_config
full_selector = [config.extractor_node_id] + list(config.output_selector)
selector_key = ".".join(full_selector)
result[f"#{selector_key}#"] = full_selector
case "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}

View File

@@ -6,13 +6,14 @@ import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
from typing import Any, ClassVar, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.sandbox.sandbox import Sandbox
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.runtime.variable_pool import VariablePool
@@ -104,14 +105,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
...
class NodeProtocol(Protocol):
"""Structural interface for graph nodes."""
id: str
state: NodeState
execution_type: NodeExecutionType
node_type: ClassVar[NodeType]
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ...
class EdgeProtocol(Protocol):
id: str
state: NodeState
tail: str
head: str
source_handle: str
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: Mapping[str, object]
edges: Mapping[str, object]
root_node: object
nodes: Mapping[str, NodeProtocol]
edges: Mapping[str, EdgeProtocol]
root_node: NodeProtocol
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
@dataclass(slots=True)
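
Because these are typing.Protocol classes, the concrete Graph, Node, and Edge types satisfy them structurally; nothing has to inherit from the runtime layer. A reduced sketch (NodeType collapsed to a plain string):

    from typing import ClassVar, Protocol

    class NodeProtocol(Protocol):
        id: str
        node_type: ClassVar[str]

        def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ...

    class AnswerLikeNode:  # no NodeProtocol base class required
        node_type: ClassVar[str] = "answer"

        def __init__(self, id: str) -> None:
            self.id = id

        def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
            return False

    node: NodeProtocol = AnswerLikeNode("n1")  # accepted by structure alone
    print(node.node_type)  # answer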

View File

@@ -146,11 +146,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config["data"]
# Get node type
node_type = NodeType(node_config_data.get("type"))
node_type = NodeType(node_config_data["type"])
# init graph init params and runtime state
graph_init_params = GraphInitParams(
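
The switch from .get(...) with fallbacks to direct subscripts works because NodeConfigDict guarantees both keys statically, so a missing key is expected to surface where the config is validated rather than as a silent default here. A tiny sketch with stand-in TypedDicts:

    from typing_extensions import TypedDict

    class NodeData(TypedDict):
        type: str

    class NodeConfig(TypedDict):
        id: str
        data: NodeData

    def node_type_of(config: NodeConfig) -> str:
        # Subscripting is safe: the TypedDict contract says "data" and "type"
        # are always present, so no .get(..., {}) fallback is needed.
        return config["data"]["type"]

    print(node_type_of({"id": "n1", "data": {"type": "llm"}}))  # llm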