Merge commit 'fb41b215' into sandboxed-agent-rebase

Made-with: Cursor

# Conflicts:
#	.devcontainer/post_create_command.sh
#	api/commands.py
#	api/core/agent/cot_agent_runner.py
#	api/core/agent/fc_agent_runner.py
#	api/core/app/apps/workflow_app_runner.py
#	api/core/app/entities/queue_entities.py
#	api/core/app/entities/task_entities.py
#	api/core/workflow/workflow_entry.py
#	api/dify_graph/enums.py
#	api/dify_graph/graph/graph.py
#	api/dify_graph/graph_events/node.py
#	api/dify_graph/model_runtime/entities/message_entities.py
#	api/dify_graph/node_events/node.py
#	api/dify_graph/nodes/agent/agent_node.py
#	api/dify_graph/nodes/base/__init__.py
#	api/dify_graph/nodes/base/entities.py
#	api/dify_graph/nodes/base/node.py
#	api/dify_graph/nodes/llm/entities.py
#	api/dify_graph/nodes/llm/node.py
#	api/dify_graph/nodes/tool/tool_node.py
#	api/pyproject.toml
#	api/uv.lock
#	web/app/components/base/avatar/__tests__/index.spec.tsx
#	web/app/components/base/avatar/index.tsx
#	web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx
#	web/app/components/base/file-uploader/file-from-link-or-local/index.tsx
#	web/app/components/base/prompt-editor/index.tsx
#	web/app/components/datasets/metadata/edit-metadata-batch/modal.tsx
#	web/app/components/header/account-dropdown/index.spec.tsx
#	web/app/components/share/text-generation/index.tsx
#	web/app/components/workflow/block-selector/tool/action-item.tsx
#	web/app/components/workflow/block-selector/trigger-plugin/action-item.tsx
#	web/app/components/workflow/hooks/use-edges-interactions.ts
#	web/app/components/workflow/hooks/use-nodes-interactions.ts
#	web/app/components/workflow/index.tsx
#	web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx
#	web/app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx
#	web/app/components/workflow/nodes/human-input/components/delivery-method/recipient/email-item.tsx
#	web/app/components/workflow/nodes/loop/use-interactions.ts
#	web/contract/router.ts
#	web/env.ts
#	web/eslint-suppressions.json
#	web/package.json
#	web/pnpm-lock.yaml
Novice committed 2026-03-23 10:52:06 +08:00
1395 changed files with 167201 additions and 73658 deletions

View File

@ -113,7 +113,7 @@ The codebase enforces strict layering via import-linter:
1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement `_run()` method
1. Register in `nodes/node_mapping.py`
1. Ensure the node module is importable under `nodes/<node_type>/`
1. Add tests in `tests/unit_tests/dify_graph/nodes/`
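A minimal end-to-end sketch of these steps (the `uppercase` node kind, its data model, and its fields are all hypothetical; `Node`, `BaseNodeData`, `NodeRunResult`, and `WorkflowNodeExecutionStatus` appear elsewhere in this diff):

```python
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node


class UppercaseNodeData(BaseNodeData):
    type: NodeType = "uppercase"  # NodeType is a plain str alias in this diff
    source: str = ""


class UppercaseNode(Node[UppercaseNodeData]):
    # Registered automatically via Node.__init_subclass__ when the module
    # lives under a canonical workflow namespace.
    node_type: NodeType = "uppercase"

    def _run(self) -> NodeRunResult:
        # Echo the configured source string in upper case.
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"text": self._node_data.source.upper()},
        )

    @classmethod
    def version(cls) -> str:
        return "1"
```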
### Implementing a Custom Layer

View File

@ -1,4 +1,3 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
@ -6,7 +5,6 @@ from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",

View File

@ -1,8 +0,0 @@
from pydantic import BaseModel
class AgentNodeStrategyInit(BaseModel):
"""Agent node strategy initialization data."""
name: str
icon: str | None = None

View File

@ -0,0 +1,184 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
from dify_graph.entities.exc import DefaultValueTypeError
from dify_graph.enums import ErrorStrategy, NodeType
# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
# `type` therefore accepts downstream string node kinds; unknown node implementations
# are rejected later when the node factory resolves the node registry.
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
# and persisted templates/workflows also carry undeclared compatibility keys such as
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
# here until graph parsing becomes discriminated by node type or those legacy payloads
# are normalized.
model_config = ConfigDict(extra="allow")
type: NodeType
title: str = ""
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = Field(default_factory=RetryConfig)
parent_node_id: str | None = None
@property
def is_extractor_node(self) -> bool:
return self.parent_node_id is not None
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
def __getitem__(self, key: str) -> Any:
"""
Dict-style access without calling model_dump() on every lookup.
Prefer using model fields and Pydantic's extra storage.
"""
# First, check declared model fields
if key in self.__class__.model_fields:
return getattr(self, key)
# Then, check undeclared compatibility fields stored in Pydantic's extra dict.
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras[key]
raise KeyError(key)
def get(self, key: str, default: Any = None) -> Any:
"""
Dict-style .get() without calling model_dump() on every lookup.
"""
if key in self.__class__.model_fields:
return getattr(self, key)
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras.get(key, default)
return default
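A short usage sketch of the two models above (all values illustrative): string inputs are JSON-parsed for non-string types, and undeclared keys stay reachable through the dict-style accessors backed by Pydantic's extras.

```python
from dify_graph.entities.base_node_data import (
    BaseNodeData,
    DefaultValue,
    DefaultValueType,
)

# A string value for a non-string type goes through the JSON converter,
# then the element validation above.
dv = DefaultValue(key="scores", type=DefaultValueType.ARRAY_NUMBER, value="[1, 2.5]")
assert dv.value == [1, 2.5]


class _DemoNodeData(BaseNodeData):
    """Concrete subclass purely for this sketch; BaseNodeData is abstract."""


# Undeclared compatibility keys (here `selected`) land in Pydantic's extras
# and stay reachable without model_dump().
data = _DemoNodeData(type="llm", title="Demo", selected=True)
assert data["selected"] is True
assert data.get("missing", "fallback") == "fallback"
```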

View File

@ -4,21 +4,20 @@ import sys
from pydantic import TypeAdapter, with_config
from dify_graph.entities.base_node_data import BaseNodeData
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigData(TypedDict):
type: str
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
data: NodeConfigData
# This is the permissive raw graph boundary. Node factories re-validate `data`
# with the concrete `NodeData` subtype after resolving the node implementation.
data: BaseNodeData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)

View File

@ -48,7 +48,7 @@ class WorkflowNodeExecution(BaseModel):
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: str | None = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
node_type: NodeType # Type of node (e.g., start, llm, downstream response node)
title: str # Display title of the node
# Execution data

View File

@ -1,4 +1,5 @@
from enum import StrEnum
from typing import ClassVar, TypeAlias
class NodeState(StrEnum):
@ -33,59 +34,85 @@ class SystemVariableKey(StrEnum):
INVOKE_FROM = "invoke_from"
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
TRIGGER_WEBHOOK = "trigger-webhook"
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
COMMAND = "command"
FILE_UPLOAD = "file-upload"
GROUP = "group"
NodeType: TypeAlias = str
@property
def is_trigger_node(self) -> bool:
"""Check if this node type is a trigger node."""
return self in [
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
@property
def is_start_node(self) -> bool:
"""Check if this node type can serve as a workflow entry point."""
return self in [
NodeType.START,
NodeType.DATASOURCE,
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class BuiltinNodeTypes:
"""Built-in node type string constants.
`node_type` values are plain strings throughout the graph runtime. This namespace
only exposes the built-in values shipped by `dify_graph`; downstream packages can
use additional strings without extending this class.
"""
START: ClassVar[NodeType] = "start"
END: ClassVar[NodeType] = "end"
ANSWER: ClassVar[NodeType] = "answer"
LLM: ClassVar[NodeType] = "llm"
KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
IF_ELSE: ClassVar[NodeType] = "if-else"
CODE: ClassVar[NodeType] = "code"
TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
HTTP_REQUEST: ClassVar[NodeType] = "http-request"
TOOL: ClassVar[NodeType] = "tool"
DATASOURCE: ClassVar[NodeType] = "datasource"
VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
LOOP: ClassVar[NodeType] = "loop"
LOOP_START: ClassVar[NodeType] = "loop-start"
LOOP_END: ClassVar[NodeType] = "loop-end"
ITERATION: ClassVar[NodeType] = "iteration"
ITERATION_START: ClassVar[NodeType] = "iteration-start"
PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
AGENT: ClassVar[NodeType] = "agent"
KNOWLEDGE_INDEX: ClassVar[NodeType] = "knowledge-index"
TRIGGER_WEBHOOK: ClassVar[NodeType] = "trigger-webhook"
TRIGGER_SCHEDULE: ClassVar[NodeType] = "trigger-schedule"
TRIGGER_PLUGIN: ClassVar[NodeType] = "trigger-plugin"
HUMAN_INPUT: ClassVar[NodeType] = "human-input"
COMMAND: ClassVar[NodeType] = "command"
FILE_UPLOAD: ClassVar[NodeType] = "file-upload"
GROUP: ClassVar[NodeType] = "group"
BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
BuiltinNodeTypes.START,
BuiltinNodeTypes.END,
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
BuiltinNodeTypes.IF_ELSE,
BuiltinNodeTypes.CODE,
BuiltinNodeTypes.TEMPLATE_TRANSFORM,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
BuiltinNodeTypes.HTTP_REQUEST,
BuiltinNodeTypes.TOOL,
BuiltinNodeTypes.DATASOURCE,
BuiltinNodeTypes.VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LOOP,
BuiltinNodeTypes.LOOP_START,
BuiltinNodeTypes.LOOP_END,
BuiltinNodeTypes.ITERATION,
BuiltinNodeTypes.ITERATION_START,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.VARIABLE_ASSIGNER,
BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
BuiltinNodeTypes.LIST_OPERATOR,
BuiltinNodeTypes.AGENT,
BuiltinNodeTypes.KNOWLEDGE_INDEX,
BuiltinNodeTypes.TRIGGER_WEBHOOK,
BuiltinNodeTypes.TRIGGER_SCHEDULE,
BuiltinNodeTypes.TRIGGER_PLUGIN,
BuiltinNodeTypes.HUMAN_INPUT,
BuiltinNodeTypes.COMMAND,
BuiltinNodeTypes.FILE_UPLOAD,
BuiltinNodeTypes.GROUP,
)
class NodeExecutionType(StrEnum):
@ -239,7 +266,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
TRIGGER_INFO = "trigger_info"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
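Since `NodeType` is now a plain `str` alias, downstream packages can introduce their own node kinds without touching `BuiltinNodeTypes`; `BUILT_IN_NODE_TYPES` only enumerates what `dify_graph` ships. A small sketch (the `audit-log` kind is invented):

```python
from dify_graph.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes, NodeType

# Hypothetical downstream node kind; valid because NodeType is just `str`.
CUSTOM_AUDIT: NodeType = "audit-log"

assert BuiltinNodeTypes.LLM in BUILT_IN_NODE_TYPES
assert CUSTOM_AUDIT not in BUILT_IN_NODE_TYPES  # still a valid NodeType value
```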

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, model_validator
@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel):
number_limits: int = 0
class ToolFile(BaseModel):
id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file")
user_id: UUID = Field(..., description="ID of the user who owns this file")
tenant_id: UUID = Field(..., description="ID of the tenant/organization")
conversation_id: UUID | None = Field(None, description="ID of the associated conversation")
file_key: str = Field(..., max_length=255, description="Storage key for the file")
mimetype: str = Field(..., max_length=255, description="MIME type of the file")
original_url: str | None = Field(
None, max_length=2048, description="Original URL if file was fetched from external source"
)
name: str = Field(default="", max_length=255, description="Display name of the file")
size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)")
class Config:
from_attributes = True # Enable ORM mode for SQLAlchemy compatibility
populate_by_name = True
class File(BaseModel):
# NOTE: dify_model_identity is a special identifier used to distinguish between
# new and old data formats during serialization and deserialization.
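`from_attributes = True` on `ToolFile` above lets the model be validated straight from an ORM row. A minimal sketch, where the import path and the stand-in row are assumptions:

```python
from types import SimpleNamespace
from uuid import uuid4

from dify_graph.file import ToolFile  # assumed re-export alongside File

# Stand-in for a SQLAlchemy ToolFile row; attribute names mirror the model.
orm_row = SimpleNamespace(
    id=uuid4(),
    user_id=uuid4(),
    tenant_id=uuid4(),
    conversation_id=None,
    file_key="tools/abc123.png",
    mimetype="image/png",
    original_url=None,
    name="abc123.png",
    size=2048,
)
tool_file = ToolFile.model_validate(orm_row)  # attribute access, not dict keys
assert tool_file.file_key == "tools/abc123.png"
```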

View File

@ -8,7 +8,7 @@ from typing import Protocol, cast, final
from pydantic import TypeAdapter
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
from dify_graph.nodes.base.node import Node
from libs.typing import is_str
@ -34,7 +34,8 @@ class NodeFactory(Protocol):
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
"""
...
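A hedged sketch of a factory satisfying this protocol: the registry lookup leans on `Node.get_node_type_classes_mapping()` from later in this diff, and the constructor arguments and `"latest"` fallback mirror what the base `Node.__init__` and registry code show. The class itself and its attribute names are illustrative:

```python
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.nodes.base.node import Node


class SimpleNodeFactory:
    """Illustrative NodeFactory; real factories also own bootstrap imports."""

    def __init__(self, graph_init_params, graph_runtime_state):
        self._init_params = graph_init_params
        self._runtime_state = graph_runtime_state

    def create_node(self, node_config: NodeConfigDict) -> Node:
        # `data` is already a BaseNodeData at this boundary.
        node_type = node_config["data"].type
        versions = Node.get_node_type_classes_mapping().get(node_type)
        if not versions:
            raise ValueError(f"Unknown node type: {node_type}")
        node_cls = versions.get(node_config["data"].version) or versions["latest"]
        return node_cls(
            id=node_config["id"],
            config=node_config,
            graph_init_params=self._init_params,
            graph_runtime_state=self._runtime_state,
        )
```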
@ -82,53 +83,6 @@ class Graph:
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, NodeConfigDict],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
:param node_configs_map: mapping of node ID to node config
:param edge_configs: list of edge configurations
:param root_node_id: explicitly specified root node ID
:return: determined root node ID
"""
if root_node_id:
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
return root_node_id
# Find nodes with no incoming edges
nodes_with_incoming: set[str] = set()
for edge_config in edge_configs:
target = edge_config.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid]["data"]
node_type = node_data["type"]
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
start_node_id = nid
break
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
if not root_node_id:
raise ValueError("Unable to determine root node ID")
return root_node_id
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, object]]
@ -203,6 +157,23 @@ class Graph:
return GraphBuilder(graph_cls=cls)
@staticmethod
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
"""
Remove editor-only nodes before `NodeConfigDict` validation.
Persisted note widgets use a top-level `type == "custom-note"` but leave
`data.type` empty because they are never executable graph nodes. Filter
them while configs are still raw dicts so Pydantic does not validate
their placeholder payloads against `BaseNodeData.type: NodeType`.
"""
filtered_node_configs: list[dict[str, object]] = []
for node_config in node_configs:
if node_config.get("type", "") == "custom-note":
continue
filtered_node_configs.append(dict(node_config))
return filtered_node_configs
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
@ -286,15 +257,15 @@ class Graph:
*,
graph_config: Mapping[str, object],
node_factory: NodeFactory,
root_node_id: str | None = None,
root_node_id: str,
skip_validation: bool = False,
) -> Graph:
"""
Initialize graph
Initialize a graph with an explicit execution entry point.
:param graph_config: graph config containing nodes and edges
:param node_factory: factory for creating node instances from config data
:param root_node_id: root node id
:param root_node_id: active root node id
:return: graph instance
"""
# Parse configs
@ -302,25 +273,25 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
node_configs = cls._filter_canvas_only_nodes(node_configs)
node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
# Filter out UI-only node types:
# - custom-note: top-level type (node_config.type == "custom-note")
# - group: data-level type (node_config.data.type == "group")
node_configs = [
node_config
for node_config in node_configs
if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
if node_config.get("data", {}).get("type", "") != "group"
]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)
# Find root node
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Build edges
edges, in_edges, out_edges = cls._build_edges(edge_configs)
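With the implicit root discovery removed, callers must now name the entry point. A sketch, assuming the classmethod shown here is `Graph.init` (its name falls outside the hunk) and using an invented two-node config:

```python
from dify_graph.graph.graph import Graph  # path per the conflict list above


def build_graph(node_factory) -> Graph:
    # node_factory: any NodeFactory implementation; the entry node must now
    # be named explicitly instead of being discovered from edge topology.
    return Graph.init(
        graph_config={
            "nodes": [
                {"id": "start_1", "data": {"type": "start", "title": "Start"}},
                {"id": "answer_1", "data": {"type": "answer", "title": "Answer", "answer": "hi"}},
            ],
            "edges": [{"source": "start_1", "target": "answer_1"}],
        },
        node_factory=node_factory,
        root_node_id="start_1",  # required; no fallback discovery
    )
```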

View File

@ -4,7 +4,7 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from dify_graph.enums import NodeExecutionType, NodeType
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@ -71,7 +71,7 @@ class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
@ -86,7 +86,7 @@ class _RootNodeValidator:
)
return issues
node_type = getattr(root_node, "node_type", None)
node_type = root_node.node_type
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
@ -114,45 +114,9 @@ class GraphValidator:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)
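The default rule tuple above is the extension point: a rule is any object exposing `validate(graph) -> Sequence[GraphValidationIssue]`. A hedged sketch of a downstream rule in the same `@dataclass(frozen=True, slots=True)` style, written as if it lived alongside the defaults (the duplicate-title check itself is invented):

```python
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class _UniqueTitleValidator:
    """Illustrative downstream rule: flag nodes that share a display title."""

    duplicate_code: str = "DUPLICATE_NODE_TITLE"

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        seen: dict[str, str] = {}
        issues: list[GraphValidationIssue] = []
        for node in graph.nodes.values():
            # How nodes expose their title is hedged via getattr here.
            title = getattr(node, "title", "")
            if title and title in seen:
                issues.append(
                    GraphValidationIssue(
                        code=self.duplicate_code,
                        message=f"Title {title!r} is also used by node '{seen[title]}'.",
                        node_id=node.id,
                    )
                )
            elif title:
                seen[title] = node.id
        return issues
```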

View File

@ -6,5 +6,6 @@ of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
from .session import RESPONSE_SESSION_NODE_TYPES
__all__ = ["ResponseStreamCoordinator"]
__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"]

View File

@ -3,19 +3,34 @@ Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications
can opt additional response-capable node types into session creation without
patching the coordinator.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, cast
from dify_graph.nodes.answer.answer_node import AnswerNode
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.template import Template
from dify_graph.nodes.end.end_node import EndNode
from dify_graph.nodes.knowledge_index import KnowledgeIndexNode
from dify_graph.runtime.graph_runtime_state import NodeProtocol
class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
"""Structural contract required from nodes that can open a response session."""
def get_streaming_template(self) -> Template: ...
RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.END,
]
@dataclass
class ResponseSession:
"""
@ -33,10 +48,9 @@ class ResponseSession:
"""
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES`
and which implements `get_streaming_template()`.
Args:
node: Node from the materialized workflow graph.
@ -47,11 +61,22 @@ class ResponseSession:
Raises:
TypeError: If node is not a supported response node type.
"""
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
if node.node_type not in RESPONSE_SESSION_NODE_TYPES:
supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES)
raise TypeError(
"ResponseSession.from_node only supports node types in "
f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}"
)
response_node = cast(_ResponseSessionNodeProtocol, node)
try:
template = response_node.get_streaming_template()
except AttributeError as exc:
raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc
return cls(
node_id=node.id,
template=node.get_streaming_template(),
template=template,
)
def is_complete(self) -> bool:
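Per the mutability note in the module docstring, a downstream package can opt a node type into session creation without patching the coordinator. A sketch; the import path is abbreviated because only the relative module is visible in this diff, and the appended kind stands in for any node implementing `get_streaming_template()`:

```python
# Import path abbreviated: this diff only shows the coordinator-relative
# module (`.session`) and the package __init__ re-export.
from .session import RESPONSE_SESSION_NODE_TYPES

# Any node type appended here must implement get_streaming_template();
# the from_node() guard above enforces that at session creation time.
RESPONSE_SESSION_NODE_TYPES.append("my-streaming-node")  # hypothetical kind
```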

View File

@ -5,7 +5,7 @@ from enum import StrEnum
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from dify_graph.entities import ToolCall, ToolResult
from dify_graph.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@ -14,8 +14,8 @@ from .base import GraphNodeEventBase
class NodeRunStartedEvent(GraphNodeEventBase):
node_title: str
predecessor_node_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
start_at: datetime = Field(..., description="node start time")
extras: dict[str, object] = Field(default_factory=dict)
# FIXME(-LAN-): only for ToolNode
provider_type: str = ""

View File

@ -279,5 +279,4 @@ class ToolPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
# ToolPromptMessage is not empty if it has content OR has a tool_call_id
return super().is_empty() and not self.tool_call_id

View File

@ -4,7 +4,8 @@ class InvokeError(ValueError):
description: str | None = None
def __init__(self, description: str | None = None):
self.description = description
if description is not None:
self.description = description
def __str__(self):
return self.description or self.__class__.__name__
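The guard preserves a class-level default that `None` would otherwise clobber, since `description` is looked up on the instance first. A sketch with an invented subclass:

```python
class QuotaExceededError(InvokeError):
    description = "Provider quota exceeded"  # class-level default, illustrative


# Previously __init__ always assigned self.description, so passing None
# shadowed the class attribute and __str__ fell back to the class name.
assert str(QuotaExceededError()) == "Provider quota exceeded"
assert str(QuotaExceededError("rate limited")) == "rate limited"
```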

View File

@ -282,7 +282,8 @@ class ModelProviderFactory:
all_model_type_models.append(model_schema)
simple_provider_schema = provider_schema.to_simple_provider()
simple_provider_schema.models.extend(all_model_type_models)
if model_type:
simple_provider_schema.models = all_model_type_models
providers.append(simple_provider_schema)

View File

@ -1,10 +1,10 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import Any
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import ToolCall, ToolResult
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.file import File
@ -15,7 +15,7 @@ from .base import NodeEventBase
class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
context_files: list[File] | None = Field(default=None, description="context files")

View File

@ -1,3 +1,3 @@
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
__all__ = ["NodeType"]
__all__ = ["BuiltinNodeTypes"]

View File

@ -1,3 +0,0 @@
from .agent_node import AgentNode
__all__ = ["AgentNode"]

View File

@ -1,45 +0,0 @@
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.nodes.base.entities import BaseNodeData
class AgentNodeData(BaseNodeData):
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version
# and requires using the legacy parameter parsing rules.
tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(IntEnum):
CLOSE = 0
OPEN = 1
class AgentOldVersionModelFeatures(StrEnum):
"""
Enum class for old SDK version llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

View File

@ -1,132 +0,0 @@
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: str | None = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
strategy_name,
provider_name,
)
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: str | None = None):
self.parameter_name = parameter_name
super().__init__(message)
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: str | None = None):
self.variable_name = variable_name
super().__init__(message)
class AgentVariableNotFoundError(AgentVariableError):
"""Exception raised when a variable is not found in the variable pool."""
def __init__(self, variable_name: str):
super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
class AgentInputTypeError(AgentNodeError):
"""Exception raised when an unknown agent input type is encountered."""
def __init__(self, input_type: str):
super().__init__(f"Unknown agent input type '{input_type}'")
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: str | None = None):
self.file_id = file_id
super().__init__(message)
class ToolFileNotFoundError(ToolFileError):
"""Exception raised when a tool file is not found."""
def __init__(self, file_id: str):
super().__init__(f"Tool file '{file_id}' does not exist", file_id)
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: str | None = None):
self.conversation_id = conversation_id
super().__init__(message)
class AgentVariableTypeError(AgentNodeError):
"""Exception raised when a variable has an unexpected type."""
def __init__(
self,
message: str,
variable_name: str | None = None,
expected_type: str | None = None,
actual_type: str | None = None,
):
self.variable_name = variable_name
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.answer.entities import AnswerNodeData
from dify_graph.nodes.base.node import Node
@ -11,7 +11,7 @@ from dify_graph.variables import ArrayFileSegment, FileSegment, Segment
class AnswerNode(Node[AnswerNodeData]):
node_type = NodeType.ANSWER
node_type = BuiltinNodeTypes.ANSWER
execution_type = NodeExecutionType.RESPONSE
@classmethod
@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AnswerNodeData.model_validate(node_data)
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
_ = graph_config # Explicitly mark as unused
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}

View File

@ -3,7 +3,8 @@ from enum import StrEnum, auto
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
class AnswerNodeData(BaseNodeData):
@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
Answer Node Data.
"""
type: NodeType = BuiltinNodeTypes.ANSWER
answer: str = Field(..., description="answer template string")

View File

@ -1,10 +1,4 @@
from .entities import (
BaseIterationNodeData,
BaseIterationState,
BaseLoopNodeData,
BaseLoopState,
BaseNodeData,
)
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [
@ -12,6 +6,5 @@ __all__ = [
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
]

View File

@ -1,31 +1,12 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
from typing import Any
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, field_validator
from dify_graph.enums import ErrorStrategy
from .exc import DefaultValueTypeError
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
from dify_graph.entities.base_node_data import BaseNodeData
class VariableSelector(BaseModel):
@ -76,120 +57,6 @@ class OutputVariableEntity(BaseModel):
return v
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
# Parent node ID when this node is used as an extractor.
# If set, this node is an "attached" extractor node that extracts values
# from list[PromptMessage] for the parent node's parameters.
parent_node_id: str | None = None
@property
def is_extractor_node(self) -> bool:
"""Check if this node is an extractor node (has parent_node_id)."""
return self.parent_node_id is not None
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
class BaseIterationNodeData(BaseNodeData):

View File

@ -1,9 +1,7 @@
from __future__ import annotations
import importlib
import logging
import operator
import pkgutil
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
@ -11,7 +9,9 @@ from types import MappingProxyType
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
from dify_graph.entities import GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
@ -65,8 +65,6 @@ from dify_graph.node_events import (
from dify_graph.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
@ -156,15 +154,15 @@ class Node(Generic[NodeDataT]):
Later, in __init__:
::
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
CodeNodeData instance
(stored in self._node_data)
config["data"] ──► _node_data_type.model_validate(..., from_attributes=True)
CodeNodeData instance
(stored in self._node_data)
Example:
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
node_type = NodeType.CODE
node_type = BuiltinNodeTypes.CODE
# No need to implement _get_title, _get_error_strategy, etc.
"""
super().__init_subclass__(**kwargs)
@ -182,7 +180,8 @@ class Node(Generic[NodeDataT]):
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under dify_graph.nodes.*
# Only register production node implementations defined under the
# canonical workflow namespaces.
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
@ -190,7 +189,7 @@ class Node(Generic[NodeDataT]):
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith("dify_graph.nodes."):
if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
@ -206,6 +205,7 @@ class Node(Generic[NodeDataT]):
else:
latest_key = max(version_keys) if version_keys else version
bucket["latest"] = bucket[latest_key]
Node._registry_version += 1
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
@ -240,11 +240,16 @@ class Node(Generic[NodeDataT]):
# Global registry populated via __init_subclass__
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
_registry_version: ClassVar[int] = 0
@classmethod
def get_registry_version(cls) -> int:
return cls._registry_version
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
@ -257,22 +262,25 @@ class Node(Generic[NodeDataT]):
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
node_id = config["id"]
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
self._node_data = self.validate_node_data(config["data"])
self.post_init()
@classmethod
def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT:
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None:
"""Hydrate `_node_data` for legacy callers that bypass `__init__`."""
self._node_data = self.validate_node_data(cast(BaseNodeData, data))
def post_init(self) -> None:
"""Optional hook for subclasses requiring extra initialization."""
return
@ -345,9 +353,6 @@ class Node(Generic[NodeDataT]):
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
@ -357,12 +362,6 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
"""
Find all extractor node configurations that have parent_node_id == self._node_id.
Returns:
List of node configuration dicts for extractor nodes
"""
nodes = self.graph_config.get("nodes", [])
extractor_configs = []
for node_config in nodes:
@ -372,12 +371,6 @@ class Node(Generic[NodeDataT]):
return extractor_configs
def _execute_nested_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
"""
Execute all nested nodes associated with this node.
Nested nodes are nodes with parent_node_id == self._node_id.
They are executed before the main node to extract values from list[PromptMessage].
"""
from core.workflow.node_factory import DifyNodeFactory
extractor_configs = self._find_extractor_node_configs()
@ -411,6 +404,10 @@ class Node(Generic[NodeDataT]):
if not isinstance(event, NodeRunStreamChunkEvent):
yield event
def populate_start_event(self, event: NodeRunStartedEvent) -> None:
"""Allow subclasses to enrich the started event without cross-node imports in the base class."""
_ = event
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@ -427,41 +424,10 @@ class Node(Generic[NodeDataT]):
in_iteration_id=None,
start_at=self._start_at,
)
# === FIXME(-LAN-): Needs to refactor.
from dify_graph.nodes.tool.tool_node import ToolNode
if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from dify_graph.nodes.datasource.datasource_node import DatasourceNode
if isinstance(self, DatasourceNode):
plugin_id = getattr(self.node_data, "plugin_id", "")
provider_name = getattr(self.node_data, "provider_name", "")
start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
if isinstance(self, TriggerEventNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from typing import cast
from dify_graph.nodes.agent.agent_node import AgentNode
from dify_graph.nodes.agent.entities import AgentNodeData
if isinstance(self, AgentNode):
start_event.agent_strategy = AgentNodeStrategyInit(
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
icon=self.agent_strategy_icon,
)
# ===
try:
self.populate_start_event(start_event)
except Exception:
logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True)
yield start_event
try:
@ -503,7 +469,7 @@ class Node(Generic[NodeDataT]):
cls,
*,
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
config: NodeConfigDict,
) -> Mapping[str, Sequence[str]]:
"""Extracts references variable selectors from node configuration.
@ -541,13 +507,12 @@ class Node(Generic[NodeDataT]):
:param config: node config
:return:
"""
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
# Pass raw dict data instead of creating NodeData instance
node_id = config["id"]
node_data = cls.validate_node_data(config["data"])
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
graph_config=graph_config,
node_id=node_id,
node_data=node_data,
)
return data
@ -557,7 +522,7 @@ class Node(Generic[NodeDataT]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: NodeDataT,
) -> Mapping[str, Sequence[str]]:
return {}
@ -581,30 +546,20 @@ class Node(Generic[NodeDataT]):
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/dify_graph/nodes/__init__.py`.
# NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so
# registry lookups can resolve numeric versions and `latest`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
"""Return a read-only view of the currently registered node classes.
Import all modules under dify_graph.nodes so subclasses register themselves on import.
Then we return a readonly view of the registry to avoid accidental mutation.
This accessor intentionally performs no imports. The embedding layer that
owns bootstrap (for example `core.workflow.node_factory`) must import any
extension node packages before calling it so their subclasses register via
`__init_subclass__`.
"""
# Import all node modules to ensure they are loaded (thus registered)
import dify_graph.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports.
if _modname == "dify_graph.nodes.node_mapping":
continue
importlib.import_module(_modname)
# Return a readonly view so callers can't mutate the registry by accident
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()}
@property
def retry(self) -> bool:
@ -941,11 +896,16 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
retriever_resources = [
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
]
return NodeRunRetrieverResourceEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=event.retriever_resources,
retriever_resources=retriever_resources,
context=event.context,
node_version=self.version(),
)
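The removed `isinstance` chain in `run()` is replaced by the `populate_start_event` hook, so provider metadata now lives with each node class instead of the base class importing every concrete node. A sketch of a tool-style override mirroring the fields the deleted block set; the subclass, its data model, and the event import path (inferred from the graph_events hunk earlier in this diff) are assumptions:

```python
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.graph_events import NodeRunStartedEvent  # path inferred
from dify_graph.nodes.base.node import Node


class MyToolLikeNodeData(BaseNodeData):  # hypothetical
    provider_id: str = ""
    provider_type: str = ""


class MyToolLikeNode(Node[MyToolLikeNodeData]):  # hypothetical
    node_type = "my-tool"

    def populate_start_event(self, event: NodeRunStartedEvent) -> None:
        # The same enrichment run() used to apply via isinstance(self, ToolNode):
        event.provider_id = self._node_data.provider_id
        event.provider_type = self._node_data.provider_type
```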

View File

@ -3,7 +3,8 @@ from decimal import Decimal
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData
@ -71,13 +72,13 @@ _DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = {
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
node_type = BuiltinNodeTypes.CODE
_limits: CodeNodeLimits
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: CodeNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}
@property

View File

@ -3,7 +3,8 @@ from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.variables.types import SegmentType
@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = BuiltinNodeTypes.CODE
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, "CodeNodeData.Output"] | None = None

View File

@ -1,3 +0,0 @@
from .datasource_node import DatasourceNode
__all__ = ["DatasourceNode"]

View File

@ -1,217 +0,0 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.repositories.datasource_manager_protocol import (
DatasourceManagerProtocol,
DatasourceParameter,
OnlineDriveDownloadFileParam,
)
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
class DatasourceNode(Node[DatasourceNodeData]):
"""
Datasource Node
"""
node_type = NodeType.DATASOURCE
execution_type = NodeExecutionType.ROOT
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
datasource_manager: DatasourceManagerProtocol,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.datasource_manager = datasource_manager
def _run(self) -> Generator:
"""
Run the datasource node
"""
dify_ctx = self.require_dify_context()
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segment:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segment:
raise DatasourceNodeError("Datasource info is not set")
datasource_info_value = datasource_info_segment.value
if not isinstance(datasource_info_value, dict):
raise DatasourceNodeError("Invalid datasource info format")
datasource_info: dict[str, Any] = datasource_info_value
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
datasource_info["icon"] = self.datasource_manager.get_icon_url(
provider_id=provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=dify_ctx.tenant_id,
datasource_type=datasource_type.value,
)
parameters_for_log = datasource_info
try:
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
# Build typed request objects
datasource_parameters = None
if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_parameters = DatasourceParameter(
workspace_id=datasource_info.get("workspace_id", ""),
page_id=datasource_info.get("page", {}).get("page_id", ""),
type=datasource_info.get("page", {}).get("type", ""),
)
online_drive_request = None
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
online_drive_request = OnlineDriveDownloadFileParam(
id=datasource_info.get("id", ""),
bucket=datasource_info.get("bucket", ""),
)
credential_id = datasource_info.get("credential_id", "")
yield from self.datasource_manager.stream_node_events(
node_id=self._node_id,
user_id=dify_ctx.user_id,
datasource_name=node_data.datasource_name or "",
datasource_type=datasource_type.value,
provider_id=provider_id,
tenant_id=dify_ctx.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=credential_id,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
variable_pool=variable_pool,
datasource_param=datasource_parameters,
online_drive_request=online_drive_request,
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
**datasource_info,
"datasource_type": datasource_type,
},
)
)
case DatasourceProviderType.LOCAL_FILE:
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError("File is not exist")
file_info = self.datasource_manager.get_upload_file_by_id(
file_id=related_id, tenant_id=dify_ctx.tenant_id
)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file_info,
"datasource_type": datasource_type,
},
)
)
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
)
except DatasourceNodeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
typed_node_data = DatasourceNodeData.model_validate(node_data)
result = {}
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name]
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
case "constant":
pass
case None:
pass
result = {node_id + "." + key: value for key, value in result.items()}
return result
@classmethod
def version(cls) -> str:
return "1"

View File

@ -1,41 +0,0 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel):
plugin_id: str
provider_name: str # redundancy
provider_type: str
datasource_name: str | None = "local_file"
datasource_configurations: dict[str, Any] | None = None
plugin_unique_identifier: str | None = None # redundancy
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"] | None = None
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
return typ
datasource_parameters: dict[str, DatasourceInput] | None = None
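
The validator above checks one field against another. A runnable sketch of the same pattern, assuming pydantic v2; note that value must be declared before type, since ValidationInfo.data only exposes fields that have already been parsed.

from typing import Any, Literal, Union

from pydantic import BaseModel, ValidationError, field_validator
from pydantic_core.core_schema import ValidationInfo

class Input(BaseModel):
    value: Union[Any, list[str]]
    type: Literal["mixed", "variable", "constant"] | None = None

    @field_validator("type", mode="before")
    @classmethod
    def check_type(cls, typ, info: ValidationInfo):
        # info.data holds previously parsed fields, so "value" is visible here
        value = info.data.get("value")
        if typ == "mixed" and not isinstance(value, str):
            raise ValueError("value must be a string")
        return typ

print(Input(value="{{#a.b#}}", type="mixed"))  # accepted
try:
    Input(value=[1, 2], type="mixed")
except ValidationError as e:
    print(e.errors()[0]["msg"])  # Value error, value must be a string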

View File

@ -1,16 +0,0 @@
class DatasourceNodeError(ValueError):
"""Base exception for datasource node errors."""
pass
class DatasourceParameterError(DatasourceNodeError):
"""Exception raised for errors in datasource parameters."""
pass
class DatasourceFileError(DatasourceNodeError):
"""Exception raised for errors related to datasource files."""
pass

View File

@ -1,10 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
class DocumentExtractorNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR
variable_selector: Sequence[str]

View File

@ -4,6 +4,7 @@ import json
import logging
import os
import tempfile
import zipfile
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
@ -20,7 +21,8 @@ from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod, file_manager
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@ -44,7 +46,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
Supports plain text, PDF, and DOC/DOCX files.
"""
node_type = NodeType.DOCUMENT_EXTRACTOR
node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR
@classmethod
def version(cls) -> str:
@ -53,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -82,8 +84,18 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
value = variable.value
inputs = {"variable_selector": variable_selector}
if isinstance(value, list):
value = list(filter(lambda x: x, value))
process_data = {"documents": value if isinstance(value, list) else [value]}
if not value:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": ArrayStringSegment(value=[])},
)
try:
if isinstance(value, list):
extracted_text_list = [
@ -111,6 +123,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
else:
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
except DocumentExtractorError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
@ -124,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
return {node_id + ".files": typed_node_data.variable_selector}
_ = graph_config # Explicitly mark as unused
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(
@ -385,6 +396,32 @@ def parser_docx_part(block, doc: Document, content_items, i):
content_items.append((i, "table", Table(block, doc)))
def _normalize_docx_zip(file_content: bytes) -> bytes:
"""
Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
ZIP entry names use backslash (\\) as path separator instead of the forward
slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
"word\\document.xml" is never found when python-docx looks for
"word/document.xml", which triggers a KeyError about a missing relationship.
This function rewrites the ZIP in-memory, normalizing all entry names to
use forward slashes without touching any actual document content.
"""
try:
with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
out_buf = io.BytesIO()
with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
for item in zin.infolist():
data = zin.read(item.filename)
# Normalize backslash path separators to forward slash
item.filename = item.filename.replace("\\", "/")
zout.writestr(item, data)
return out_buf.getvalue()
except zipfile.BadZipFile:
# Not a valid zip — return as-is and let python-docx report the real error
return file_content
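
The normalization above is easy to exercise in isolation. A minimal sketch, assuming _normalize_docx_zip from the hunk above is in scope: build a ZIP whose entry name uses a backslash, then confirm the rewrite.

import io
import zipfile

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    # Malformed entry name, as produced by some Windows exporters
    zf.writestr("word\\document.xml", b"<w:document/>")

fixed = _normalize_docx_zip(buf.getvalue())
with zipfile.ZipFile(io.BytesIO(fixed)) as zf:
    print(zf.namelist())  # ['word/document.xml']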
def _extract_text_from_docx(file_content: bytes) -> str:
"""
Extract text from a DOCX file.
@ -392,7 +429,15 @@ def _extract_text_from_docx(file_content: bytes) -> str:
"""
try:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
try:
doc = docx.Document(doc_file)
except Exception as e:
logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
# Some DOCX files exported by tools like Evernote on Windows use
# backslash path separators in ZIP entries and/or single-quoted XML
# attributes, both of which break python-docx on Linux. Normalize and retry.
file_content = _normalize_docx_zip(file_content)
doc = docx.Document(io.BytesIO(file_content))
text = []
# Keep track of paragraph and table positions

View File

@ -1,4 +1,4 @@
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.template import Template
@ -6,7 +6,7 @@ from dify_graph.nodes.end.entities import EndNodeData
class EndNode(Node[EndNodeData]):
node_type = NodeType.END
node_type = BuiltinNodeTypes.END
execution_type = NodeExecutionType.RESPONSE
@classmethod

View File

@ -1,6 +1,8 @@
from pydantic import BaseModel, Field
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.entities import OutputVariableEntity
class EndNodeData(BaseNodeData):
@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
type: NodeType = BuiltinNodeTypes.END
outputs: list[OutputVariableEntity]

View File

@ -8,7 +8,8 @@ import charset_normalizer
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = BuiltinNodeTypes.HTTP_REQUEST
method: Literal[
"get",
"post",

View File

@ -3,7 +3,8 @@ import mimetypes
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import variable_template_parser
@ -32,12 +33,12 @@ if TYPE_CHECKING:
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
node_type = BuiltinNodeTypes.HTTP_REQUEST
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData.model_validate(node_data)
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if typed_node_data.body:
body_type = typed_node_data.body.type
data = typed_node_data.body.data
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
match body_type:
case "none":
pass
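
Each template contributes selectors independently. A self-contained sketch of the flattening across url, headers, and params, again with a regex standing in for variable_template_parser (an assumption) and illustrative variable names:

import re

TEMPLATE_REF = re.compile(r"\{\{#([\w.]+)#\}\}")

def selectors_from(*templates: str) -> list[list[str]]:
    return [ref.split(".") for t in templates for ref in TEMPLATE_REF.findall(t)]

print(selectors_from(
    "https://api.example.com/{{#start.path#}}",   # url
    "Authorization: Bearer {{#env.token#}}",      # headers
    "page={{#loop_1.index#}}",                    # params
))
# [['start', 'path'], ['env', 'token'], ['loop_1', 'index']]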

View File

@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.runtime import VariablePool
from dify_graph.variables.consts import SELECTORS_LENGTH
@ -71,8 +72,8 @@ class EmailDeliveryConfig(BaseModel):
body: str
debug_mode: bool = False
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
if not user_id:
def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig":
if user_id is None:
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
return self.model_copy(update={"recipients": debug_recipients})
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
@ -140,7 +141,7 @@ def apply_debug_email_recipient(
method: DeliveryChannelConfig,
*,
enabled: bool,
user_id: str,
user_id: str | None,
) -> DeliveryChannelConfig:
if not enabled:
return method
@ -148,7 +149,7 @@ def apply_debug_email_recipient(
return method
if not method.config.debug_mode:
return method
debug_config = method.config.with_debug_recipient(user_id or "")
debug_config = method.config.with_debug_recipient(user_id)
return method.model_copy(update={"config": debug_config})
@ -214,6 +215,7 @@ class UserAction(BaseModel):
class HumanInputNodeData(BaseNodeData):
"""Human Input node data."""
type: NodeType = BuiltinNodeTypes.HUMAN_INPUT
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
form_content: str = ""
inputs: list[FormInput] = Field(default_factory=list)

View File

@ -3,8 +3,9 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
@ -39,7 +40,7 @@ logger = logging.getLogger(__name__)
class HumanInputNode(Node[HumanInputNodeData]):
node_type = NodeType.HUMAN_INPUT
node_type = BuiltinNodeTypes.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH
_BRANCH_SELECTION_KEYS: tuple[str, ...] = (
@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository,
@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HumanInputNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
validated_node_data = HumanInputNodeData.model_validate(node_data)
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
return node_data.extract_variable_selector_to_variable_mapping(node_id)

View File

@ -2,7 +2,8 @@ from typing import Literal
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.utils.condition.entities import Condition
@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
If Else Node Data.
"""
type: NodeType = BuiltinNodeTypes.IF_ELSE
class Case(BaseModel):
"""
Case entity representing a single logical condition group

View File

@ -3,7 +3,7 @@ from typing import Any, Literal
from typing_extensions import deprecated
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.if_else.entities import IfElseNodeData
@ -13,7 +13,7 @@ from dify_graph.utils.condition.processor import ConditionProcessor
class IfElseNode(Node[IfElseNodeData]):
node_type = NodeType.IF_ELSE
node_type = BuiltinNodeTypes.IF_ELSE
execution_type = NodeExecutionType.BRANCH
@classmethod
@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData.model_validate(node_data)
var_mapping: dict[str, list[str]] = {}
for case in typed_node_data.cases or []:
_ = graph_config # Explicitly mark as unused
for case in node_data.cases or []:
for condition in case.conditions:
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
var_mapping[key] = condition.variable_selector

View File

@ -3,7 +3,9 @@ from typing import Any
from pydantic import Field
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState
class ErrorHandleMode(StrEnum):
@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
type: NodeType = BuiltinNodeTypes.ITERATION
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData):
Iteration Start Node Data.
"""
pass
type: NodeType = BuiltinNodeTypes.ITERATION_START
class IterationState(BaseIterationState):

View File

@ -7,9 +7,10 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
BuiltinNodeTypes,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -61,7 +62,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
Iteration Node.
"""
node_type = NodeType.ITERATION
node_type = BuiltinNodeTypes.ITERATION
execution_type = NodeExecutionType.CONTAINER
@classmethod
@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IterationNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
f"{node_id}.input_selector": node_data.iterator_selector,
}
iteration_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
node_config_data = node.get("data", {})
if node_config_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)
@ -487,17 +485,16 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# variable selector to variable mapping
try:
# Get node class
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
if node_type not in NODE_TYPE_CLASSES_MAPPING:
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
node_type = typed_sub_node_config["data"].type
node_mapping = Node.get_node_type_classes_mapping()
if node_type not in node_mapping:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
node_version = str(typed_sub_node_config["data"].version)
node_cls = node_mapping[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=typed_sub_node_config
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
@ -563,7 +560,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
current_index = index_variable.value
for event in rst:
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START:
continue
if isinstance(event, GraphNodeEventBase):
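
The lookup above walks a two-level registry, node type then version. A minimal illustration of that shape, with hypothetical entries rather than the real mapping returned by Node.get_node_type_classes_mapping():

# type -> version -> node class name (illustrative values only)
registry: dict[str, dict[str, str]] = {
    "llm": {"1": "LLMNode", "2": "LLMNodeV2"},
    "if-else": {"1": "IfElseNode"},
}

def resolve(node_type: str, version: str = "1") -> str | None:
    versions = registry.get(node_type)
    if versions is None:
        return None  # unknown type is skipped, as in the loop above
    return versions.get(version)

print(resolve("llm", "2"))  # LLMNodeV2
print(resolve("agent"))     # None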

View File

@ -1,4 +1,4 @@
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.iteration.entities import IterationStartNodeData
@ -9,7 +9,7 @@ class IterationStartNode(Node[IterationStartNodeData]):
Iteration Start Node.
"""
node_type = NodeType.ITERATION_START
node_type = BuiltinNodeTypes.ITERATION_START
@classmethod
def version(cls) -> str:

View File

@ -1,3 +0,0 @@
from .knowledge_index_node import KnowledgeIndexNode
__all__ = ["KnowledgeIndexNode"]

View File

@ -1,162 +0,0 @@
from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.nodes.base import BaseNodeData
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
reranking_provider_name: str
reranking_model_name: str
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
"""
search_method: RetrievalMethod
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class FileInfo(BaseModel):
"""
File Info.
"""
file_id: str
class OnlineDocumentIcon(BaseModel):
"""
Document Icon.
"""
icon_url: str
icon_type: str
icon_emoji: str
class OnlineDocumentInfo(BaseModel):
"""
Online document info.
"""
provider: str
workspace_id: str | None = None
page_id: str
page_type: str
icon: OnlineDocumentIcon | None = None
class WebsiteInfo(BaseModel):
"""
Website import info.
"""
provider: str
url: str
class GeneralStructureChunk(BaseModel):
"""
General Structure Chunk.
"""
general_chunks: list[str]
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
class ParentChildChunk(BaseModel):
"""
Parent Child Chunk.
"""
parent_content: str
child_contents: list[str]
class ParentChildStructureChunk(BaseModel):
"""
Parent Child Structure Chunk.
"""
parent_child_chunks: list[ParentChildChunk]
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
class KnowledgeIndexNodeData(BaseNodeData):
"""
Knowledge index Node Data.
"""
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None
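
For reference, the now-removed chunk models composed like this; a short sketch using the classes above, with illustrative values. data_source_info accepts any of the three source types via the Union.

chunk = GeneralStructureChunk(
    general_chunks=["first paragraph", "second paragraph"],
    data_source_info=FileInfo(file_id="upload-123"),
)
print(chunk.data_source_info)  # file_id='upload-123'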

View File

@ -1,22 +0,0 @@
class KnowledgeIndexNodeError(ValueError):
"""Base class for KnowledgeIndexNode errors."""
class ModelNotExistError(KnowledgeIndexNodeError):
"""Raised when the model does not exist."""
class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
"""Raised when the model credentials are not initialized."""
class ModelNotSupportedError(KnowledgeIndexNodeError):
"""Raised when the model is not supported."""
class ModelQuotaExceededError(KnowledgeIndexNodeError):
"""Raised when the model provider quota is exceeded."""
class InvalidModelTypeError(KnowledgeIndexNodeError):
"""Raised when the model is not a Large Language Model."""

View File

@ -1,153 +0,0 @@
import logging
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.template import Template
from dify_graph.repositories.index_processor_protocol import IndexProcessorProtocol
from dify_graph.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
_INVOKE_FROM_DEBUGGER = "debugger"
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
node_type = NodeType.KNOWLEDGE_INDEX
execution_type = NodeExecutionType.RESPONSE
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
index_processor: IndexProcessorProtocol,
summary_index_service: SummaryIndexServiceProtocol,
) -> None:
super().__init__(id, config, graph_init_params, graph_runtime_state)
self.index_processor = index_processor
self.summary_index_service = summary_index_service
def _run(self) -> NodeRunResult: # type: ignore
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
# get dataset id as string
dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id_segment:
raise KnowledgeIndexNodeError("Dataset ID is required.")
dataset_id: str = dataset_id_segment.value
# get document id as string (may be empty when not provided)
document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
document_id: str = document_id_segment.value if document_id_segment else ""
# extract variables
variable = variable_pool.get(node_data.index_chunk_variable_selector)
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
invoke_from_value = str(invoke_from.value) if invoke_from else None
is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER
chunks = variable.value
variables = {"chunks": chunks}
if not chunks:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
try:
summary_index_setting = node_data.summary_index_setting
if is_preview:
# Preview mode: generate summaries for chunks directly without saving to database
# Format preview and generate summaries on-the-fly
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
# or fallback to dataset if not available in node_data
outputs = self.index_processor.get_preview_output(
chunks, dataset_id, document_id, node_data.chunk_structure, summary_index_setting
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
outputs=outputs.model_dump(exclude_none=True),
)
original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
results = self._invoke_knowledge_index(
dataset_id=dataset_id,
document_id=document_id,
original_document_id=original_document_id_segment.value if original_document_id_segment else "",
is_preview=is_preview,
batch=batch.value,
chunks=chunks,
summary_index_setting=summary_index_setting,
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
except KnowledgeIndexNodeError as e:
logger.warning("Error when running knowledge index node", exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
except Exception as e:
logger.error(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
def _invoke_knowledge_index(
self,
dataset_id: str,
document_id: str,
original_document_id: str,
is_preview: bool,
batch: Any,
chunks: Mapping[str, Any],
summary_index_setting: dict | None = None,
):
if not document_id:
raise KnowledgeIndexNodeError("document_id is required.")
rst = self.index_processor.index_and_clean(
dataset_id, document_id, original_document_id, chunks, batch, summary_index_setting
)
self.summary_index_service.generate_and_vectorize_summary(
dataset_id, document_id, is_preview, summary_index_setting
)
return rst
@classmethod
def version(cls) -> str:
return "1"
def get_streaming_template(self) -> Template:
"""
Get the template for streaming.
Returns:
Template instance for this knowledge index node
"""
return Template(segments=[])

View File

@ -1,3 +0,0 @@
from .knowledge_retrieval_node import KnowledgeRetrievalNode
__all__ = ["KnowledgeRetrievalNode"]

View File

@ -1,135 +0,0 @@
from collections.abc import Sequence
from typing import Literal
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
provider: str
model: str
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting
class MultipleRetrievalConfig(BaseModel):
"""
Multiple Retrieval Config.
"""
top_k: int
score_threshold: float | None = None
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class SingleRetrievalConfig(BaseModel):
"""
Single Retrieval Config.
"""
model: ModelConfig
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
"""
type: str = "knowledge-retrieval"
query_variable_selector: list[str] | None | str = None
query_attachment_selector: list[str] | None | str = None
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: MultipleRetrievalConfig | None = None
single_retrieval_config: SingleRetrievalConfig | None = None
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property
def structured_output_enabled(self) -> bool:
# NOTE(QuantumGhost): Temporary workaround for issue #20725
# (https://github.com/langgenius/dify/issues/20725).
#
# The proper fix would be to make `KnowledgeRetrievalNode` inherit
# from `BaseNode` instead of `LLMNode`.
return False
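
With the operator literals restored, a manual metadata filter built from these entities looks like this (a sketch with illustrative values; note conditions is marked deprecated above but still validates):

cond = Condition(name="rating", comparison_operator="≥", value=9)
filters = MetadataFilteringCondition(logical_operator="and", conditions=[cond])
print(filters.model_dump())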

View File

@ -1,26 +0,0 @@
class KnowledgeRetrievalNodeError(ValueError):
"""Base class for KnowledgeRetrievalNode errors."""
class ModelNotExistError(KnowledgeRetrievalNodeError):
"""Raised when the model does not exist."""
class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
"""Raised when the model credentials are not initialized."""
class ModelNotSupportedError(KnowledgeRetrievalNodeError):
"""Raised when the model is not supported."""
class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
"""Raised when the model provider quota is exceeded."""
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
"""Raised when the model is not a Large Language Model."""
class RateLimitExceededError(KnowledgeRetrievalNodeError):
"""Raised when the rate limit is exceeded."""

View File

@ -1,276 +0,0 @@
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from dify_graph.entities import GraphInitParams
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from dify_graph.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from dify_graph.variables.segments import ArrayObjectSegment
from .entities import KnowledgeRetrievalNodeData
from .exc import (
KnowledgeRetrievalNodeError,
RateLimitExceededError,
)
if TYPE_CHECKING:
from dify_graph.file.models import File
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
# Instance attributes specific to KnowledgeRetrievalNode.
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
*,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver
@classmethod
def version(cls):
return "1"
def _run(self) -> NodeRunResult:
usage = LLMUsage.empty_usage()
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={},
llm_usage=usage,
)
variables: dict[str, Any] = {}
# extract variables
if self._node_data.query_variable_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables["query"] = query
if self._node_data.query_attachment_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Attachments variable is not array file or file type.",
)
if isinstance(variable, ArrayFileSegment):
variables["attachments"] = variable.value
else:
variables["attachments"] = [variable.value]
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data={"usage": jsonable_encoder(usage)},
outputs=outputs, # type: ignore
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
except RateLimitExceededError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node", exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
# Temporarily handle all exceptions from the DatasetRetrieval class here.
except Exception as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
dify_ctx = self.require_dify_context()
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
retrieval_resource_list = []
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
if node_data.metadata_filtering_mode is not None:
metadata_filtering_mode = node_data.metadata_filtering_mode
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required for single retrieval mode")
model = node_data.single_retrieval_config.model
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
user_from=dify_ctx.user_from.value,
dataset_ids=dataset_ids,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
completion_params=model.completion_params,
model_provider=model.provider,
model_mode=model.mode,
model_name=model.name,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
query=query,
)
)
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
reranking_model = None
weights = None
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
case "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
case _:
# Handle any other reranking_mode values
reranking_model = None
weights = None
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=dify_ctx.app_id,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
user_from=dify_ctx.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
)
usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
if typed_node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping
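
A sketch of the mapping this classmethod produces, assuming only the fields shown are required by the removed KnowledgeRetrievalNodeData model; the node id and selectors are illustrative.

mapping = KnowledgeRetrievalNode._extract_variable_selector_to_variable_mapping(
    graph_config={},
    node_id="kr_1",
    node_data={
        "title": "retrieve",
        "dataset_ids": ["ds-1"],
        "retrieval_mode": "multiple",
        "query_variable_selector": ["start", "query"],
    },
)
print(mapping)  # {'kr_1.query': ['start', 'query']}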

View File

@ -1,66 +0,0 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description
You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value
### Task
Your task is to ONLY extract the metadata that exists in the input text from the provided metadata list and use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return the result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501
METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which companys email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""
METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""
METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""
METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value
### Task
Your task is to ONLY extract the metadata that exists in the input text from the provided metadata list and use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return the result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which companys email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501

View File

@ -3,7 +3,8 @@ from enum import StrEnum
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
class FilterOperator(StrEnum):
@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.LIST_OPERATOR
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderByConfig

View File

@ -1,7 +1,7 @@
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.file import File
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@ -35,7 +35,7 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
class ListOperatorNode(Node[ListOperatorNodeData]):
node_type = NodeType.LIST_OPERATOR
node_type = BuiltinNodeTypes.LIST_OPERATOR
@classmethod
def version(cls) -> str:

View File

@ -8,11 +8,12 @@ from core.agent.entities import AgentLog, AgentResult
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities import ToolCall, ToolCallResult
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.file import File
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import AgentLogEvent
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.base.entities import VariableSelector
@ -367,6 +368,7 @@ class ToolSetting(BaseModel):
class LLMNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.LLM
model: ModelConfig
prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)

View File

@ -1,14 +1,11 @@
import mimetypes
import typing as tp
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db as global_db
from dify_graph.nodes.protocols import HttpClientProtocol
class LLMFileSaver(tp.Protocol):
@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol):
raise NotImplementedError()
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None:
def _factory():
return global_db.engine
engine_factory = _factory
self._engine_factory = engine_factory
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
self._user_id = user_id
self._tenant_id = tenant_id
self._http_client = http_client
def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory())
return ToolFileManager()
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response = self._http_client.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")
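
The injection seam introduced above lets tests satisfy HttpClientProtocol without network I/O. A sketch of a stub; its surface (a .get() returning an object with raise_for_status, content, and headers) is an assumption matching the calls save_remote_url makes.

class StubResponse:
    content = b"\x89PNG..."
    headers = {"Content-Type": "image/png"}

    def raise_for_status(self) -> None:
        pass

class StubHttpClient:
    def get(self, url: str, **kwargs):
        return StubResponse()

saver = FileSaverImpl(user_id="user-1", tenant_id="tenant-1", http_client=StubHttpClient())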

View File

@ -46,8 +46,10 @@ from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.tool_entities import ToolCallResult
from dify_graph.enums import (
BuiltinNodeTypes,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
@ -95,6 +97,7 @@ from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
@ -146,7 +149,7 @@ logger = logging.getLogger(__name__)
class LLMNode(Node[LLMNodeData]):
node_type = NodeType.LLM
node_type = BuiltinNodeTypes.LLM
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
@ -164,13 +167,14 @@ class LLMNode(Node[LLMNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
http_client: HttpClientProtocol,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -193,6 +197,7 @@ class LLMNode(Node[LLMNodeData]):
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver
@ -1220,7 +1225,7 @@ class LLMNode(Node[LLMNodeData]):
)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource: list[RetrievalSourceMetadata] = []
original_retriever_resource: list[dict[str, Any]] = []
context_files: list[File] = []
for item in context_value_variable.value:
if isinstance(item, str):
@ -1236,11 +1241,14 @@ class LLMNode(Node[LLMNodeData]):
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
segment_id = retriever_resource.get("segment_id")
if not segment_id:
continue
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
SegmentAttachmentBinding.segment_id == segment_id,
)
).all()
if attachments_with_bindings:
@ -1266,7 +1274,7 @@ class LLMNode(Node[LLMNodeData]):
context_files=context_files,
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None:
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@ -1274,28 +1282,26 @@ class LLMNode(Node[LLMNodeData]):
):
metadata = context_dict.get("metadata", {})
source = RetrievalSourceMetadata(
position=metadata.get("position"),
dataset_id=metadata.get("dataset_id"),
dataset_name=metadata.get("dataset_name"),
document_id=metadata.get("document_id"),
document_name=metadata.get("document_name"),
data_source_type=metadata.get("data_source_type"),
segment_id=metadata.get("segment_id"),
retriever_from=metadata.get("retriever_from"),
score=metadata.get("score"),
hit_count=metadata.get("segment_hit_count"),
word_count=metadata.get("segment_word_count"),
segment_position=metadata.get("segment_position"),
index_node_hash=metadata.get("segment_index_node_hash"),
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
summary=context_dict.get("summary"),
)
return source
return {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
"files": context_dict.get("files"),
"summary": context_dict.get("summary"),
}
return None
@ -1503,14 +1509,11 @@ class LLMNode(Node[LLMNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LLMNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template
prompt_template = node_data.prompt_template
variable_selectors = []
prompt_context_selectors: list[Sequence[str]] = []
if isinstance(prompt_template, list):
@ -1538,7 +1541,7 @@ class LLMNode(Node[LLMNodeData]):
variable_key = f"#{'.'.join(context_selector)}#"
variable_mapping[variable_key] = list(context_selector)
memory = typed_node_data.memory
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@ -1546,16 +1549,16 @@ class LLMNode(Node[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if typed_node_data.context.enabled:
variable_mapping["#context#"] = typed_node_data.context.variable_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if typed_node_data.vision.enabled:
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if typed_node_data.memory:
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
if typed_node_data.prompt_config:
if node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
@ -1567,7 +1570,7 @@ class LLMNode(Node[LLMNodeData]):
enable_jinja = True
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
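
The retriever resource now travels as a plain dict instead of a RetrievalSourceMetadata model, so callers switch from attribute access to .get(), as the segment_id guard above does. A minimal runnable sketch of that contract (segment_id_of is an invented helper, not code from this change):

from typing import Any

def segment_id_of(resource: dict[str, Any] | None) -> str | None:
    # resource.segment_id would raise AttributeError on a dict; keys may be
    # absent, so .get() is the safe accessor before any DB lookup.
    if not resource:
        return None
    return resource.get("segment_id")

assert segment_id_of({"segment_id": "seg-1", "content": "..."}) == "seg-1"
assert segment_id_of({"content": "no segment id"}) is None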

View File

@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
@ -39,6 +41,7 @@ class LoopVariableData(BaseModel):
class LoopNodeData(BaseLoopNodeData):
type: NodeType = BuiltinNodeTypes.LOOP
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData):
Loop Start Node Data.
"""
pass
type: NodeType = BuiltinNodeTypes.LOOP_START
class LoopEndNodeData(BaseNodeData):
@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData):
Loop End Node Data.
"""
pass
type: NodeType = BuiltinNodeTypes.LOOP_END
class LoopState(BaseLoopState):
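
Node data models now declare their own discriminator: type defaults to a BuiltinNodeTypes member, replacing the pass-only class bodies. A simplified runnable sketch; the stand-in base class and enum values are assumptions, not the real definitions:

from enum import StrEnum
from pydantic import BaseModel

class BuiltinNodeTypes(StrEnum):   # stand-in; real members live in dify_graph.enums
    LOOP = "loop"
    LOOP_START = "loop-start"

class BaseNodeData(BaseModel):     # stand-in for dify_graph.entities.base_node_data
    title: str = ""
    type: str

class LoopStartNodeData(BaseNodeData):
    type: str = BuiltinNodeTypes.LOOP_START

data = LoopStartNodeData.model_validate({"title": "loop start"})
assert data.type == "loop-start"   # the model supplies its own type default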

View File

@ -1,4 +1,4 @@
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopEndNodeData
@ -9,7 +9,7 @@ class LoopEndNode(Node[LoopEndNodeData]):
Loop End Node.
"""
node_type = NodeType.LOOP_END
node_type = BuiltinNodeTypes.LOOP_END
@classmethod
def version(cls) -> str:

View File

@ -5,9 +5,10 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
BuiltinNodeTypes,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -45,7 +46,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
Loop Node.
"""
node_type = NodeType.LOOP
node_type = BuiltinNodeTypes.LOOP
execution_type = NodeExecutionType.CONTAINER
@classmethod
@ -249,11 +250,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if isinstance(event, GraphNodeEventBase):
self._append_loop_info_to_event(event=event, loop_run_index=current_index)
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START:
continue
if isinstance(event, GraphNodeEventBase):
yield event
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END:
reach_break_node = True
if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error)
@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = LoopNodeData.model_validate(node_data)
variable_mapping = {}
# Extract loop node IDs statically from graph_config
@ -317,17 +315,16 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
# variable selector to variable mapping
try:
# Get node class
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
if node_type not in NODE_TYPE_CLASSES_MAPPING:
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
node_type = typed_sub_node_config["data"].type
node_mapping = Node.get_node_type_classes_mapping()
if node_type not in node_mapping:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
node_version = str(typed_sub_node_config["data"].version)
node_cls = node_mapping[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=typed_sub_node_config
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
@ -342,7 +339,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in typed_node_data.loop_variables or []:
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
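
The validate_python call suggests NodeConfigDictAdapter is a pydantic TypeAdapter over NodeConfigDict; here is a self-contained approximation of the validate-then-dispatch step above, with stand-in shapes (the real schema lives in dify_graph.entities.graph_config and may differ):

from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict

class _Data(BaseModel):            # stand-in for the per-node data model
    type: str
    version: str = "1"

class _NodeConfig(TypedDict):      # stand-in for NodeConfigDict
    id: str
    data: _Data

adapter = TypeAdapter(_NodeConfig)             # plays NodeConfigDictAdapter's role
registry = {"code": {"1": "CodeNodeClass"}}    # type -> version -> node class

typed = adapter.validate_python({"id": "n1", "data": {"type": "code"}})
node_cls = registry[typed["data"].type][str(typed["data"].version)]
assert node_cls == "CodeNodeClass"             # unknown types are skipped instead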

View File

@ -1,4 +1,4 @@
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopStartNodeData
@ -9,7 +9,7 @@ class LoopStartNode(Node[LoopStartNodeData]):
Loop Start Node.
"""
node_type = NodeType.LOOP_START
node_type = BuiltinNodeTypes.LOOP_START
@classmethod
def version(cls) -> str:

View File

@ -1,9 +0,0 @@
from collections.abc import Mapping
from dify_graph.enums import NodeType
from dify_graph.nodes.base.node import Node
LATEST_VERSION = "latest"
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
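
With the static module gone, the registry is assembled by Node.get_node_type_classes_mapping() itself. One plausible shape of such a builder, sketched with toy classes; per the deleted comment, the real method also imports and walks dify_graph.nodes so every subclass is loaded first:

class Node:
    node_type: str = ""

    @classmethod
    def version(cls) -> str:
        return "1"

    @classmethod
    def get_node_type_classes_mapping(cls) -> dict[str, dict[str, type["Node"]]]:
        mapping: dict[str, dict[str, type[Node]]] = {}
        stack = list(cls.__subclasses__())
        while stack:                 # walk the whole class hierarchy
            sub = stack.pop()
            stack.extend(sub.__subclasses__())
            if sub.node_type:
                mapping.setdefault(sub.node_type, {})[sub.version()] = sub
        return mapping

class StartNode(Node):
    node_type = "start"

assert Node.get_node_type_classes_mapping()["start"]["1"] is StartNode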

View File

@ -8,7 +8,8 @@ from pydantic import (
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
from dify_graph.variables.types import SegmentType
@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
Parameter Extractor Node Data.
"""
type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]

View File

@ -10,8 +10,9 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -96,7 +97,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
Parameter Extractor Node.
"""
node_type = NodeType.PARAMETER_EXTRACTOR
node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -842,15 +843,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ParameterExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
if node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector
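
The same signature change repeats across the node classes: the per-node model_validate call disappears because the hook now receives already-typed node data, presumably validated once by the public extract_variable_selector_to_variable_mapping classmethod. A runnable before/after sketch with a stand-in model:

from collections.abc import Mapping, Sequence
from pydantic import BaseModel

class ParamsData(BaseModel):       # stand-in node data model
    query: list[str]

def extract_old(node_data: Mapping) -> Mapping[str, Sequence[str]]:
    typed = ParamsData.model_validate(node_data)   # every node re-validated
    return {"query": typed.query}

def extract_new(node_data: ParamsData) -> Mapping[str, Sequence[str]]:
    return {"query": node_data.query}              # caller validated once

raw = {"query": ["sys", "query"]}
assert extract_old(raw) == extract_new(ParamsData.model_validate(raw))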

View File

@ -1,8 +1,10 @@
from collections.abc import Generator
from typing import Any, Protocol
import httpx
from dify_graph.file import File
from dify_graph.file.models import ToolFile
class HttpClientProtocol(Protocol):
@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol):
mimetype: str,
filename: str | None = None,
) -> Any: ...
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...
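
Because this is a structural Protocol, anything exposing these methods satisfies it, so the new get_file_generator_by_tool_file_id member can be faked without a database. A hedged test-double sketch; FakeToolFile stands in for the real ToolFile model:

from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class FakeToolFile:                # stand-in for dify_graph.file.models.ToolFile
    id: str
    mimetype: str = "text/plain"

class FakeToolFileManager:
    def __init__(self, files: dict[str, bytes]):
        self._files = files

    def get_file_generator_by_tool_file_id(
        self, tool_file_id: str
    ) -> tuple[Generator | None, FakeToolFile | None]:
        blob = self._files.get(tool_file_id)
        if blob is None:
            return None, None      # ToolNode raises ToolFileError on this
        def stream() -> Generator[bytes, None, None]:
            yield blob
        return stream(), FakeToolFile(id=tool_file_id)

mgr = FakeToolFileManager({"abc": b"hello"})
gen, meta = mgr.get_file_generator_by_tool_file_id("abc")
assert meta is not None and b"".join(gen) == b"hello"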

View File

@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.llm import ModelConfig, VisionConfig
@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
class QuestionClassifierNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]

View File

@ -7,9 +7,10 @@ from core.model_manager import ModelInstance
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
BuiltinNodeTypes,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -28,6 +29,7 @@ from dify_graph.nodes.llm import (
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@ -48,7 +50,7 @@ if TYPE_CHECKING:
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = NodeType.QUESTION_CLASSIFIER
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_file_outputs: list["File"]
@ -61,13 +63,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
http_client: HttpClientProtocol,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -90,6 +93,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver
@ -252,16 +256,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: QuestionClassifierNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors: list[VariableSelector] = []
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)

View File

@ -2,7 +2,8 @@ from collections.abc import Sequence
from pydantic import Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.variables.input_entities import VariableEntity
@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
Start Node Data
"""
type: NodeType = BuiltinNodeTypes.START
variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@ -3,7 +3,7 @@ from typing import Any
from jsonschema import Draft7Validator, ValidationError
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.start.entities import StartNodeData
@ -11,7 +11,7 @@ from dify_graph.variables.input_entities import VariableEntityType
class StartNode(Node[StartNodeData]):
node_type = NodeType.START
node_type = BuiltinNodeTypes.START
execution_type = NodeExecutionType.ROOT
@classmethod

View File

@ -1,4 +1,5 @@
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.entities import VariableSelector
@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
Template Transform Node Data.
"""
type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM
variables: list[VariableSelector]
template: str

View File

@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData
@ -18,14 +19,14 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
_max_output_length: int
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -86,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}

View File

@ -6,7 +6,8 @@ from pydantic import BaseModel, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
# Pattern to match mention format: {{@node.context@}}instruction
MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)
@ -69,6 +70,8 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]

View File

@ -2,19 +2,15 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
logger = logging.getLogger(__name__)
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
BuiltinNodeTypes,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
@ -24,11 +20,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
from dify_graph.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData, is_variable_format
@ -39,7 +34,8 @@ from .exc import (
)
if TYPE_CHECKING:
from dify_graph.runtime import VariablePool
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState, VariablePool
class ToolNode(Node[ToolNodeData]):
@ -47,12 +43,33 @@ class ToolNode(Node[ToolNodeData]):
Tool Node
"""
node_type = NodeType.TOOL
node_type = BuiltinNodeTypes.TOOL
def __init__(
self,
id: str,
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
tool_file_manager_factory: ToolFileManagerProtocol,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._tool_file_manager_factory = tool_file_manager_factory
@classmethod
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
event.provider_id = self.node_data.provider_id
event.provider_type = self.node_data.provider_type
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node
@ -296,11 +313,9 @@ class ToolNode(Node[ToolNodeData]):
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not found")
mapping = {
"tool_file_id": tool_file_id,
@ -319,11 +334,9 @@ class ToolNode(Node[ToolNodeData]):
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,
@ -499,7 +512,7 @@ class ToolNode(Node[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping.
@ -514,9 +527,7 @@ class ToolNode(Node[ToolNodeData]):
:param node_data: node data
:return: mapping of variable key to variable selector
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
typed_node_data = node_data
result: dict[str, Sequence[str]] = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
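
ToolNode now takes a tool_file_manager_factory at construction and resolves tool files through it rather than querying ToolFile with a Session directly, which keeps the database out of unit tests. The id it passes comes from the message URL exactly as shown above; a runnable restatement of that parsing:

def tool_file_id_from_url(url: str) -> str:
    # last path segment with its extension stripped -- mirrors
    # str(url).split("/")[-1].split(".")[0] in the node body
    return url.split("/")[-1].split(".")[0]

assert tool_file_id_from_url("https://host/files/tools/abc123.png") == "abc123"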

View File

@ -1,3 +0,0 @@
from .trigger_event_node import TriggerEventNode
__all__ = ["TriggerEventNode"]

View File

@ -1,77 +0,0 @@
from collections.abc import Mapping
from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.entities.entities import EventParameter
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
"""Plugin trigger node data"""
class TriggerEventInput(BaseModel):
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
type = value
value = validation_info.data.get("value")
if value is None:
return type
if type == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
if type == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
if type == "constant" and not isinstance(value, str | int | float | bool | dict | list):
raise ValueError("value must be a string, int, float, bool or dict")
return type
title: str
desc: str | None = None
plugin_id: str = Field(..., description="Plugin ID")
provider_id: str = Field(..., description="Provider ID")
event_name: str = Field(..., description="Event name")
subscription_id: str = Field(..., description="Subscription ID")
plugin_unique_identifier: str = Field(..., description="Plugin unique identifier")
event_parameters: Mapping[str, TriggerEventInput] = Field(default_factory=dict, description="Trigger parameters")
def resolve_parameters(
self,
*,
parameter_schemas: Mapping[str, EventParameter],
) -> Mapping[str, Any]:
"""
Generate parameters based on the given plugin trigger parameters.
Args:
parameter_schemas (Mapping[str, EventParameter]): The mapping of parameter schemas.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
result: dict[str, Any] = {}
for parameter_name in self.event_parameters:
parameter: EventParameter | None = parameter_schemas.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
event_input = self.event_parameters[parameter_name]
# trigger node only supports constant input
if event_input.type != "constant":
raise TriggerEventParameterError(f"Unknown plugin trigger input type '{event_input.type}'")
result[parameter_name] = event_input.value
return result

View File

@ -1,10 +0,0 @@
class TriggerEventNodeError(ValueError):
"""Base exception for plugin trigger node errors."""
pass
class TriggerEventParameterError(TriggerEventNodeError):
"""Exception raised for errors in plugin trigger parameters."""
pass

View File

@ -1,64 +0,0 @@
from collections.abc import Mapping
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from .entities import TriggerEventNodeData
class TriggerEventNode(Node[TriggerEventNodeData]):
node_type = NodeType.TRIGGER_PLUGIN
execution_type = NodeExecutionType.ROOT
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "plugin",
"config": {
"title": "",
"plugin_id": "",
"provider_id": "",
"event_name": "",
"subscription_id": "",
"plugin_unique_identifier": "",
"event_parameters": {},
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the plugin trigger node.
This node invokes the trigger to convert request data into events
and makes them available to downstream nodes.
"""
# Get trigger data passed when workflow was triggered
metadata = {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self.node_data.provider_id,
"event_name": self.node_data.event_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
},
}
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
outputs=outputs,
metadata=metadata,
)

View File

@ -1,3 +0,0 @@
from dify_graph.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
__all__ = ["TriggerScheduleNode"]

View File

@ -1,49 +0,0 @@
from typing import Literal, Union
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
class TriggerScheduleNodeData(BaseNodeData):
"""
Trigger Schedule Node Data
"""
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
visual_config: dict | None = Field(default=None, description="Visual configuration details")
timezone: str = Field(default="UTC", description="Timezone for schedule execution")
class ScheduleConfig(BaseModel):
node_id: str
cron_expression: str
timezone: str = "UTC"
class SchedulePlanUpdate(BaseModel):
node_id: str | None = None
cron_expression: str | None = None
timezone: str | None = None
class VisualConfig(BaseModel):
"""Visual configuration for schedule trigger"""
# For hourly frequency
on_minute: int | None = Field(default=0, ge=0, le=59, description="Minute of the hour (0-59)")
# For daily, weekly, monthly frequencies
time: str | None = Field(default="12:00 AM", description="Time in 12-hour format (e.g., '2:30 PM')")
# For weekly frequency
weekdays: list[Literal["sun", "mon", "tue", "wed", "thu", "fri", "sat"]] | None = Field(
default=None, description="List of weekdays to run on"
)
# For monthly frequency
monthly_days: list[Union[int, Literal["last"]]] | None = Field(
default=None, description="Days of month to run on (1-31 or 'last')"
)

View File

@ -1,31 +0,0 @@
from dify_graph.nodes.base.exc import BaseNodeError
class ScheduleNodeError(BaseNodeError):
"""Base schedule node error."""
pass
class ScheduleNotFoundError(ScheduleNodeError):
"""Schedule not found error."""
pass
class ScheduleConfigError(ScheduleNodeError):
"""Schedule configuration error."""
pass
class ScheduleExecutionError(ScheduleNodeError):
"""Schedule execution error."""
pass
class TenantOwnerNotFoundError(ScheduleExecutionError):
"""Tenant owner not found error for schedule execution."""
pass

View File

@ -1,44 +0,0 @@
from collections.abc import Mapping
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.trigger_schedule.entities import TriggerScheduleNodeData
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
node_type = NodeType.TRIGGER_SCHEDULE
execution_type = NodeExecutionType.ROOT
@classmethod
def version(cls) -> str:
return "1"
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "trigger-schedule",
"config": {
"mode": "visual",
"frequency": "daily",
"visual_config": {"time": "12:00 AM", "on_minute": 0, "weekdays": ["sun"], "monthly_days": [1]},
"timezone": "UTC",
},
}
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
outputs=outputs,
)

View File

@ -1,3 +0,0 @@
from .node import TriggerWebhookNode
__all__ = ["TriggerWebhookNode"]

View File

@ -1,79 +0,0 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from dify_graph.nodes.base import BaseNodeData
class Method(StrEnum):
GET = "get"
POST = "post"
HEAD = "head"
PATCH = "patch"
PUT = "put"
DELETE = "delete"
class ContentType(StrEnum):
JSON = "application/json"
FORM_DATA = "multipart/form-data"
FORM_URLENCODED = "application/x-www-form-urlencoded"
TEXT = "text/plain"
BINARY = "application/octet-stream"
class WebhookParameter(BaseModel):
"""Parameter definition for headers, query params, or body."""
name: str
required: bool = False
class WebhookBodyParameter(BaseModel):
"""Body parameter with type information."""
name: str
type: Literal[
"string",
"number",
"boolean",
"object",
"array[string]",
"array[number]",
"array[boolean]",
"array[object]",
"file",
] = "string"
required: bool = False
class WebhookData(BaseNodeData):
"""
Webhook Node Data.
"""
class SyncMode(StrEnum):
SYNC = "async" # only support
method: Method = Method.GET
content_type: ContentType = Field(default=ContentType.JSON)
headers: Sequence[WebhookParameter] = Field(default_factory=list)
params: Sequence[WebhookParameter] = Field(default_factory=list) # query parameters
body: Sequence[WebhookBodyParameter] = Field(default_factory=list)
@field_validator("method", mode="before")
@classmethod
def normalize_method(cls, v) -> str:
"""Normalize HTTP method to lowercase to support both uppercase and lowercase input."""
if isinstance(v, str):
return v.lower()
return v
status_code: int = 200 # Expected status code for response
response_body: str = "" # Template for response body
# Webhook specific fields (not from client data, set internally)
webhook_id: str | None = None # Set when webhook trigger is created
timeout: int = 30 # Timeout in seconds to wait for webhook response

View File

@ -1,25 +0,0 @@
from dify_graph.nodes.base.exc import BaseNodeError
class WebhookNodeError(BaseNodeError):
"""Base webhook node error."""
pass
class WebhookTimeoutError(WebhookNodeError):
"""Webhook timeout error."""
pass
class WebhookNotFoundError(WebhookNodeError):
"""Webhook not found error."""
pass
class WebhookConfigError(WebhookNodeError):
"""Webhook configuration error."""
pass

View File

@ -1,176 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType
from dify_graph.file import FileTransferMethod
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.variables.types import SegmentType
from dify_graph.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
logger = logging.getLogger(__name__)
class TriggerWebhookNode(Node[WebhookData]):
node_type = NodeType.TRIGGER_WEBHOOK
execution_type = NodeExecutionType.ROOT
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "webhook",
"config": {
"method": "get",
"content_type": "application/json",
"headers": [],
"params": [],
"body": [],
"async_mode": True,
"status_code": 200,
"response_body": "",
"timeout": 30,
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the webhook node.
Like the start node, this simply takes the webhook data from the variable pool
and makes it available to downstream nodes. The actual webhook handling
happens in the trigger controller.
"""
# Get webhook data from variable pool (injected by Celery task)
webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
# Extract webhook-specific outputs based on node configuration
outputs = self._extract_configured_outputs(webhook_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=webhook_inputs,
outputs=outputs,
)
def generate_file_var(self, param_name: str, file: dict):
dify_ctx = self.require_dify_context()
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
file["upload_file_id"] = related_id
case FileTransferMethod.TOOL_FILE:
file["tool_file_id"] = related_id
case FileTransferMethod.DATASOURCE_FILE:
file["datasource_file_id"] = related_id
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=dify_ctx.tenant_id,
)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
except ValueError:
logger.error(
"Failed to build FileVariable for webhook file parameter %s",
param_name,
exc_info=True,
)
return None
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
"""Extract outputs based on node configuration from webhook inputs."""
outputs = {}
# Get the raw webhook data (should be injected by Celery task)
webhook_data = webhook_inputs.get("webhook_data", {})
def _to_sanitized(name: str) -> str:
return name.replace("-", "_")
def _get_normalized(mapping: dict[str, Any], key: str) -> Any:
if not isinstance(mapping, dict):
return None
if key in mapping:
return mapping[key]
alternate = key.replace("-", "_") if "-" in key else key.replace("_", "-")
if alternate in mapping:
return mapping[alternate]
return None
# Extract configured headers (case-insensitive)
webhook_headers = webhook_data.get("headers", {})
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
for header in self.node_data.headers:
header_name = header.name
value = _get_normalized(webhook_headers, header_name)
if value is None:
value = _get_normalized(webhook_headers_lower, header_name.lower())
sanitized_name = _to_sanitized(header_name)
outputs[sanitized_name] = value
# Extract configured query parameters
for param in self.node_data.params:
param_name = param.name
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
# Extract configured body parameters
for body_param in self.node_data.body:
param_name = body_param.name
param_type = body_param.type
if self.node_data.content_type == ContentType.TEXT:
# For text/plain, the entire body is a single string parameter
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self.node_data.content_type == ContentType.BINARY:
raw_data: dict = webhook_data.get("body", {}).get("raw", {})
file_var = self.generate_file_var(param_name, raw_data)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = raw_data
continue
if param_type == "file":
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data
return outputs

View File

@ -1,6 +1,7 @@
from pydantic import BaseModel
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.variables.types import SegmentType
@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData):
Variable Aggregator Node Data.
"""
type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR
output_type: str
variables: list[list[str]]
advanced_settings: AdvancedSettings | None = None

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData
@ -8,7 +8,7 @@ from dify_graph.variables.segments import Segment
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
node_type = NodeType.VARIABLE_AGGREGATOR
node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR
@classmethod
def version(cls) -> str:

View File

@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
@ -17,12 +18,12 @@ if TYPE_CHECKING:
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: VariableAssignerData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerData.model_validate(node_data)
mapping = {}
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
assigned_variable_node_id = node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(typed_node_data.assigned_variable_selector)
selector_key = ".".join(node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.assigned_variable_selector
mapping[key] = node_data.assigned_variable_selector
selector_key = ".".join(typed_node_data.input_variable_selector)
selector_key = ".".join(node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.input_variable_selector
mapping[key] = node_data.input_variable_selector
return mapping
def _run(self) -> NodeRunResult:

View File

@ -1,7 +1,8 @@
from collections.abc import Sequence
from enum import StrEnum
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
class WriteMode(StrEnum):
@ -11,6 +12,7 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

View File

@ -3,7 +3,8 @@ from typing import Any
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from .enums import InputType, Operation
@ -22,5 +23,6 @@ class VariableOperationItem(BaseModel):
class VariableAssignerNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER
version: str = "2"
items: Sequence[VariableOperationItem] = Field(default_factory=list)

View File

@ -3,7 +3,8 @@ from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
@ -51,12 +52,12 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
@ -94,13 +95,10 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: VariableAssignerNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData.model_validate(node_data)
var_mapping: dict[str, Sequence[str]] = {}
for item in typed_node_data.items:
for item in node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping

View File

@ -1,50 +0,0 @@
from collections.abc import Generator
from typing import Any, Protocol
from pydantic import BaseModel
from dify_graph.file import File
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
class DatasourceParameter(BaseModel):
workspace_id: str
page_id: str
type: str
class OnlineDriveDownloadFileParam(BaseModel):
id: str
bucket: str
class DatasourceFinal(BaseModel):
data: dict[str, Any] | None = None
class DatasourceManagerProtocol(Protocol):
@classmethod
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str: ...
@classmethod
def stream_node_events(
cls,
*,
node_id: str,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: Any,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]: ...
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: ...

View File

@ -1,41 +0,0 @@
from collections.abc import Mapping
from typing import Any, Protocol
from pydantic import BaseModel, Field
class PreviewItem(BaseModel):
content: str | None = Field(None)
child_chunks: list[str] | None = Field(None)
summary: str | None = Field(None)
class QaPreview(BaseModel):
answer: str | None = Field(None)
question: str | None = Field(None)
class Preview(BaseModel):
chunk_structure: str
parent_mode: str | None = Field(None)
preview: list[PreviewItem] = Field([])
qa_preview: list[QaPreview] = Field([])
total_segments: int
class IndexProcessorProtocol(Protocol):
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview: ...
def index_and_clean(
self,
dataset_id: str,
document_id: str,
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
) -> dict[str, Any]: ...
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
) -> Preview: ...

View File

@ -1,108 +0,0 @@
from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from dify_graph.nodes.llm.entities import ModelConfig
class SourceChildChunk(BaseModel):
id: str = Field(default="", description="Child chunk ID")
content: str = Field(default="", description="Child chunk content")
position: int = Field(default=0, description="Child chunk position")
score: float = Field(default=0.0, description="Child chunk relevance score")
class SourceMetadata(BaseModel):
source: str = Field(
default="knowledge",
serialization_alias="_source",
description="Data source identifier",
)
dataset_id: str = Field(description="Dataset unique identifier")
dataset_name: str = Field(description="Dataset display name")
document_id: str = Field(description="Document unique identifier")
document_name: str = Field(description="Document display name")
data_source_type: str = Field(description="Type of data source")
segment_id: str | None = Field(default=None, description="Segment unique identifier")
retriever_from: str = Field(default="workflow", description="Retriever source context")
score: float = Field(default=0.0, description="Retrieval relevance score")
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
segment_position: int | None = Field(default=0, description="Position of segment in document")
segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
position: int | None = Field(default=0, description="Position of the document in the dataset")
class Config:
populate_by_name = True
class Source(BaseModel):
metadata: SourceMetadata = Field(description="Source metadata information")
title: str = Field(description="Document title")
files: list[Any] | None = Field(default=None, description="Associated file references")
content: str | None = Field(description="Segment content text")
summary: str | None = Field(default=None, description="Content summary if available")
class KnowledgeRetrievalRequest(BaseModel):
tenant_id: str = Field(description="Tenant unique identifier")
user_id: str = Field(description="User unique identifier")
app_id: str = Field(description="Application unique identifier")
user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
query: str | None = Field(default=None, description="Query text for knowledge retrieval")
retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
completion_params: dict[str, Any] | None = Field(
default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
)
model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
metadata_model_config: ModelConfig | None = Field(
default=None, description="Model config for metadata-based filtering"
)
metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
default=None, description="Conditions for filtering by metadata"
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
)
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
class RAGRetrievalProtocol(Protocol):
"""Protocol for RAG-based knowledge retrieval implementations.
Implementations of this protocol handle knowledge retrieval from datasets
including rate limiting, dataset filtering, and document retrieval.
"""
@property
def llm_usage(self) -> LLMUsage:
"""Return accumulated LLM usage for retrieval operations."""
...
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
"""Retrieve knowledge from datasets based on the provided request.
Args:
request: Knowledge retrieval request with search parameters
Returns:
List of sources matching the search criteria
Raises:
RateLimitExceededError: If rate limit is exceeded
ModelNotExistError: If specified model doesn't exist
"""
...

View File

@ -1,7 +0,0 @@
from typing import Protocol
class SummaryIndexServiceProtocol(Protocol):
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
): ...

View File

@ -65,9 +65,15 @@ class VariablePool(BaseModel):
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
# Add conversation variables to the variable pool. When restoring from a serialized
# snapshot, `variable_dictionary` already carries the latest runtime values.
# In that case, keep existing entries instead of overwriting them with the
# bootstrap list.
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
selector = (CONVERSATION_VARIABLE_NODE_ID, var.name)
if self._has(selector):
continue
self.add(selector, var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
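
A self-contained restatement of the conversation-variable restore rule above: bootstrap defaults must not clobber values already present in a deserialized pool. A plain dict stands in for the real selector-keyed pool and its _has() check:

CONVERSATION = "conversation"

def rebuild(pool: dict[tuple[str, str], str], bootstrap: dict[str, str]) -> None:
    for name, value in bootstrap.items():
        selector = (CONVERSATION, name)
        if selector in pool:       # mirrors self._has(selector)
            continue               # keep the runtime value from the snapshot
        pool[selector] = value

pool = {(CONVERSATION, "topic"): "updated-at-runtime"}
rebuild(pool, {"topic": "default", "lang": "en"})
assert pool[(CONVERSATION, "topic")] == "updated-at-runtime"
assert pool[(CONVERSATION, "lang")] == "en"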