mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox
# Conflicts: # api/core/file/file_manager.py # api/core/workflow/graph_engine/response_coordinator/coordinator.py # api/core/workflow/nodes/llm/node.py # api/core/workflow/nodes/tool/tool_node.py # api/pyproject.toml # web/package.json # web/pnpm-lock.yaml
This commit is contained in:
24
api/core/workflow/entities/graph_config.py
Normal file
24
api/core/workflow/entities/graph_config.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from pydantic import TypeAdapter, with_config
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigData(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
data: NodeConfigData
|
||||
|
||||
|
||||
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
|
||||
@ -5,15 +5,20 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
from libs.typing import is_str
|
||||
|
||||
from .edge import Edge
|
||||
from .validation import get_graph_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
|
||||
|
||||
|
||||
class NodeFactory(Protocol):
|
||||
"""
|
||||
@ -23,7 +28,7 @@ class NodeFactory(Protocol):
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
@ -63,28 +68,24 @@ class Graph:
|
||||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
|
||||
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, dict[str, object]] = {}
|
||||
node_configs_map: dict[str, NodeConfigDict] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
continue
|
||||
|
||||
node_configs_map[node_id] = node_config
|
||||
node_configs_map[node_config["id"]] = node_config
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: Mapping[str, Mapping[str, object]],
|
||||
node_configs_map: Mapping[str, NodeConfigDict],
|
||||
edge_configs: Sequence[Mapping[str, object]],
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
@ -113,10 +114,8 @@ class Graph:
|
||||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid].get("data")
|
||||
if not is_str_dict(node_data):
|
||||
continue
|
||||
node_type = node_data.get("type")
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
node_type = node_data["type"]
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
@ -176,7 +175,7 @@ class Graph:
|
||||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, object]],
|
||||
node_configs_map: dict[str, NodeConfigDict],
|
||||
node_factory: NodeFactory,
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
@ -303,7 +302,7 @@ class Graph:
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .ready_queue import ReadyQueue
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -90,7 +89,7 @@ class GraphEngine:
|
||||
self._graph_execution.workflow_id = workflow_id
|
||||
|
||||
# === Execution Queues ===
|
||||
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
||||
self._ready_queue = self._graph_runtime_state.ready_queue
|
||||
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
@ -25,6 +25,7 @@ from core.workflow.graph_events import (
|
||||
)
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
@ -81,7 +82,7 @@ class ResponseStreamCoordinator:
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
|
||||
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
|
||||
"""
|
||||
Initialize coordinator with variable pool.
|
||||
|
||||
|
||||
@ -10,10 +10,10 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from core.workflow.runtime.graph_runtime_state import NodeProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -29,21 +29,26 @@ class ResponseSession:
|
||||
index: int = 0 # Current position in the template segments
|
||||
|
||||
@classmethod
|
||||
def from_node(cls, node: Node) -> ResponseSession:
|
||||
def from_node(cls, node: NodeProtocol) -> ResponseSession:
|
||||
"""
|
||||
Create a ResponseSession from an AnswerNode or EndNode.
|
||||
Create a ResponseSession from a response-capable node.
|
||||
|
||||
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
|
||||
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
|
||||
- `id: str`
|
||||
- `get_streaming_template() -> Template`
|
||||
|
||||
Args:
|
||||
node: Must be either an AnswerNode or EndNode instance
|
||||
node: Node from the materialized workflow graph.
|
||||
|
||||
Returns:
|
||||
ResponseSession configured with the node's streaming template
|
||||
|
||||
Raises:
|
||||
TypeError: If node is not an AnswerNode or EndNode
|
||||
TypeError: If node is not a supported response node type.
|
||||
"""
|
||||
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
|
||||
raise TypeError
|
||||
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
|
||||
return cls(
|
||||
node_id=node.id,
|
||||
template=node.get_streaming_template(),
|
||||
|
||||
@ -205,32 +205,33 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
if agent_input.type == "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
elif agent_input.type in {"mixed", "constant"}:
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
@ -387,12 +388,13 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
if input.type in ["mixed", "constant"]:
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@ -115,7 +115,7 @@ class DefaultValue(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal, Self
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, Self] | None = None
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_type = DatasourceProviderType.value_of(datasource_type)
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
@ -268,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
case "constant":
|
||||
pass
|
||||
case None:
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
@ -306,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
match message.type:
|
||||
case (
|
||||
DatasourceMessage.MessageType.IMAGE_LINK
|
||||
| DatasourceMessage.MessageType.BINARY_LINK
|
||||
| DatasourceMessage.MessageType.IMAGE
|
||||
):
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
files.append(file)
|
||||
case DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
case DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
case DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
case DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case (
|
||||
DatasourceMessage.MessageType.BLOB_CHUNK
|
||||
| DatasourceMessage.MessageType.LOG
|
||||
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
|
||||
):
|
||||
pass
|
||||
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
|
||||
@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlencode, urlparse
|
||||
@ -11,9 +11,9 @@ import httpx
|
||||
from json_repair import repair_json
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -79,8 +79,8 @@ class Executor:
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
http_client: HttpClientProtocol | None = None,
|
||||
file_manager: FileManagerProtocol | None = None,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
@ -107,8 +107,8 @@ class Executor:
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
self._http_client = http_client
|
||||
self._file_manager = file_manager
|
||||
self._http_client = http_client or ssrf_proxy
|
||||
self._file_manager = file_manager or default_file_manager
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
@ -336,7 +336,7 @@ class Executor:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
_METHOD_MAP = {
|
||||
_METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
|
||||
"get": self._http_client.get,
|
||||
"head": self._http_client.head,
|
||||
"post": self._http_client.post,
|
||||
@ -348,7 +348,7 @@ class Executor:
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
request_args: dict[str, Any] = {
|
||||
"data": self.data,
|
||||
"files": self.files,
|
||||
"json": self.json,
|
||||
@ -361,14 +361,13 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](
|
||||
response = _METHOD_MAP[method_lc](
|
||||
url=self.url,
|
||||
**request_args,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
|
||||
@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
http_client: HttpClientProtocol | None = None,
|
||||
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
file_manager: FileManagerProtocol | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._http_client = http_client
|
||||
self._http_client = http_client or ssrf_proxy
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
self._file_manager = file_manager
|
||||
self._file_manager = file_manager or default_file_manager
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
||||
@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
return outputs
|
||||
|
||||
# Check if all non-None outputs are lists
|
||||
non_none_outputs = [output for output in outputs if output is not None]
|
||||
non_none_outputs: list[object] = [output for output in outputs if output is not None]
|
||||
if not non_none_outputs:
|
||||
return outputs
|
||||
|
||||
|
||||
@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
# Try to get document language if document_id is available
|
||||
doc_language = None
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document and document.doc_language:
|
||||
doc_language = document.doc_language
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
doc_language=doc_language,
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None,
|
||||
doc_language: str | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
|
||||
@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
match node_data.multiple_retrieval_config.reranking_mode:
|
||||
case "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
case "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
case _:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None, None, usage
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
match node_data.metadata_filtering_mode:
|
||||
case "disabled":
|
||||
return None, None, usage
|
||||
case "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
case "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
case _:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if (
|
||||
node_data.metadata_filtering_conditions
|
||||
|
||||
@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
case "name":
|
||||
return lambda x: x.filename or ""
|
||||
case "type":
|
||||
return lambda x: x.type
|
||||
return lambda x: str(x.type)
|
||||
case "extension":
|
||||
return lambda x: x.extension or ""
|
||||
case "mime_type":
|
||||
return lambda x: x.mime_type or ""
|
||||
case "transfer_method":
|
||||
return lambda x: x.transfer_method
|
||||
return lambda x: str(x.transfer_method)
|
||||
case "url":
|
||||
return lambda x: x.remote_url or ""
|
||||
case "related_id":
|
||||
@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
elif key == "size" and isinstance(value, str):
|
||||
extract_func = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
|
||||
extract_number = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
@ -1288,18 +1288,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Insert histories into the prompt
|
||||
prompt_content = prompt_messages[0].content
|
||||
# For issue #11247 - Check if prompt content is a string or a list
|
||||
prompt_content_type = type(prompt_content)
|
||||
if prompt_content_type == str:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_content)
|
||||
if "#histories#" in prompt_content:
|
||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||
else:
|
||||
prompt_content = memory_text + "\n" + prompt_content
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
if "#histories#" in content_item.data:
|
||||
content_item.data = content_item.data.replace("#histories#", memory_text)
|
||||
else:
|
||||
@ -1309,13 +1307,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# Add current query to the prompt message
|
||||
if sys_query:
|
||||
if prompt_content_type == str:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
content_item.data = sys_query + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
@ -1481,13 +1478,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if typed_node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for item in prompt_template:
|
||||
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
if prompt_template.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
else:
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
break
|
||||
else:
|
||||
enable_jinja = True
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol):
|
||||
@property
|
||||
def request_error(self) -> type[Exception]: ...
|
||||
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
|
||||
class FileManagerProtocol(Protocol):
|
||||
|
||||
@ -513,25 +513,26 @@ class ToolNode(Node[ToolNodeData]):
|
||||
result: dict[str, Sequence[str]] = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
if isinstance(input.value, list):
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "nested_node":
|
||||
# Nested node type: extract variable selector from nested_node_config
|
||||
# The full selector is extractor_node_id + output_selector
|
||||
if input.nested_node_config is not None:
|
||||
config = input.nested_node_config
|
||||
full_selector = [config.extractor_node_id] + list(config.output_selector)
|
||||
selector_key = ".".join(full_selector)
|
||||
result[f"#{selector_key}#"] = full_selector
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
if isinstance(input.value, list):
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
case "nested_node":
|
||||
# Nested node type: extract variable selector from nested_node_config
|
||||
# The full selector is extractor_node_id + output_selector
|
||||
if input.nested_node_config is not None:
|
||||
config = input.nested_node_config
|
||||
full_selector = [config.extractor_node_id] + list(config.output_selector)
|
||||
selector_key = ".".join(full_selector)
|
||||
result[f"#{selector_key}#"] = full_selector
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@ -6,13 +6,14 @@ import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, ClassVar, Protocol
|
||||
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
|
||||
@ -104,14 +105,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class NodeProtocol(Protocol):
|
||||
"""Structural interface for graph nodes."""
|
||||
|
||||
id: str
|
||||
state: NodeState
|
||||
execution_type: NodeExecutionType
|
||||
node_type: ClassVar[NodeType]
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ...
|
||||
|
||||
|
||||
class EdgeProtocol(Protocol):
|
||||
id: str
|
||||
state: NodeState
|
||||
tail: str
|
||||
head: str
|
||||
source_handle: str
|
||||
|
||||
|
||||
class GraphProtocol(Protocol):
|
||||
"""Structural interface required from graph instances attached to the runtime state."""
|
||||
|
||||
nodes: Mapping[str, object]
|
||||
edges: Mapping[str, object]
|
||||
root_node: object
|
||||
nodes: Mapping[str, NodeProtocol]
|
||||
edges: Mapping[str, EdgeProtocol]
|
||||
root_node: NodeProtocol
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
||||
@ -146,11 +146,11 @@ class WorkflowEntry:
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
node_config = dict(workflow.get_node_config_by_id(node_id))
|
||||
node_config_data = node_config.get("data", {})
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config_data = node_config["data"]
|
||||
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_type = NodeType(node_config_data["type"])
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
Reference in New Issue
Block a user