Merge commit 'fb41b215' into sandboxed-agent-rebase

Made-with: Cursor

# Conflicts:
#	.devcontainer/post_create_command.sh
#	api/commands.py
#	api/core/agent/cot_agent_runner.py
#	api/core/agent/fc_agent_runner.py
#	api/core/app/apps/workflow_app_runner.py
#	api/core/app/entities/queue_entities.py
#	api/core/app/entities/task_entities.py
#	api/core/workflow/workflow_entry.py
#	api/dify_graph/enums.py
#	api/dify_graph/graph/graph.py
#	api/dify_graph/graph_events/node.py
#	api/dify_graph/model_runtime/entities/message_entities.py
#	api/dify_graph/node_events/node.py
#	api/dify_graph/nodes/agent/agent_node.py
#	api/dify_graph/nodes/base/__init__.py
#	api/dify_graph/nodes/base/entities.py
#	api/dify_graph/nodes/base/node.py
#	api/dify_graph/nodes/llm/entities.py
#	api/dify_graph/nodes/llm/node.py
#	api/dify_graph/nodes/tool/tool_node.py
#	api/pyproject.toml
#	api/uv.lock
#	web/app/components/base/avatar/__tests__/index.spec.tsx
#	web/app/components/base/avatar/index.tsx
#	web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx
#	web/app/components/base/file-uploader/file-from-link-or-local/index.tsx
#	web/app/components/base/prompt-editor/index.tsx
#	web/app/components/datasets/metadata/edit-metadata-batch/modal.tsx
#	web/app/components/header/account-dropdown/index.spec.tsx
#	web/app/components/share/text-generation/index.tsx
#	web/app/components/workflow/block-selector/tool/action-item.tsx
#	web/app/components/workflow/block-selector/trigger-plugin/action-item.tsx
#	web/app/components/workflow/hooks/use-edges-interactions.ts
#	web/app/components/workflow/hooks/use-nodes-interactions.ts
#	web/app/components/workflow/index.tsx
#	web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx
#	web/app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx
#	web/app/components/workflow/nodes/human-input/components/delivery-method/recipient/email-item.tsx
#	web/app/components/workflow/nodes/loop/use-interactions.ts
#	web/contract/router.ts
#	web/env.ts
#	web/eslint-suppressions.json
#	web/package.json
#	web/pnpm-lock.yaml
This commit is contained in:
Novice
2026-03-23 10:52:06 +08:00
1395 changed files with 167201 additions and 73658 deletions

View File

@ -1,4 +1 @@
from .node_factory import DifyNodeFactory
from .workflow_entry import WorkflowEntry
__all__ = ["DifyNodeFactory", "WorkflowEntry"]
"""Core workflow package."""

View File

@ -1,5 +1,8 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -8,7 +11,6 @@ from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
@ -17,39 +19,37 @@ from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.entities.graph_config import NodeConfigDict
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyPresentationProvider,
PluginAgentStrategyResolver,
)
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
from dify_graph.graph.graph import NodeFactory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
from dify_graph.nodes.code.entities import CodeLanguage
from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.nodes.datasource import DatasourceNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
@ -58,6 +58,138 @@ if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
# Sentinel version key under which each node type's newest implementation is registered.
LATEST_VERSION = "latest"
# Node types that may act as a workflow entry point: the classic start node,
# datasource nodes, and every trigger node type.
_START_NODE_TYPES: frozenset[NodeType] = frozenset(
    (BuiltinNodeTypes.START, BuiltinNodeTypes.DATASOURCE, *TRIGGER_NODE_TYPES)
)
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
    """Recursively import every module under *package_name*.

    Importing a node module triggers its self-registration side effect.
    Modules whose fully-qualified name appears in *excluded_modules* are
    skipped.
    """
    root = importlib.import_module(package_name)
    prefix = f"{root.__name__}."
    for module_info in pkgutil.walk_packages(root.__path__, prefix):
        if module_info.name not in excluded_modules:
            importlib.import_module(module_info.name)
@lru_cache(maxsize=1)
def register_nodes() -> None:
    """Import every production node module so each self-registers with ``Node``.

    Cached so the import sweep runs at most once per process. The built-in
    ``dify_graph.nodes`` package is imported before the workflow-local
    ``core.workflow.nodes`` package, preserving registration order.
    """
    for package_name in ("dify_graph.nodes", "core.workflow.nodes"):
        _import_node_package(package_name)
def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
    """Return a read-only snapshot of the current production node registry.

    Node bootstrap is owned by the workflow layer because it must compose the
    built-in ``dify_graph.nodes.*`` implementations with workflow-local nodes
    under ``core.workflow.nodes.*``. Keeping that import side effect here
    keeps registry bootstrapping out of the lower-level graph primitives.
    """
    register_nodes()
    registry_snapshot = Node.get_node_type_classes_mapping()
    return registry_snapshot
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
    """Resolve the implementation class for *node_type* at *node_version*.

    Falls back to the ``latest`` registration when the exact version has no
    entry.

    :raises ValueError: when the node type is unregistered, or when neither
        the requested version nor a latest version is available
    """
    versions = get_node_type_classes_mapping().get(node_type)
    if not versions:
        raise ValueError(f"No class mapping found for node type: {node_type}")
    resolved = versions.get(node_version) or versions.get(LATEST_VERSION)
    if not resolved:
        raise ValueError(f"No latest version class found for node type: {node_type}")
    return resolved
def is_start_node_type(node_type: NodeType) -> bool:
    """Return True when the node type can serve as a workflow entry point."""
    # NOTE(review): callers such as get_default_root_node_id pass a raw ``str``
    # here; this relies on NodeType members comparing/hashing equal to their
    # string values (StrEnum-style) — confirm against the NodeType definition.
    return node_type in _START_NODE_TYPES
def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str:
    """Pick the default entry node id for a persisted top-level workflow graph.

    Lives next to the node registry (rather than in the raw
    ``dify_graph.entities.graph_config`` schema module) because it depends on
    the start-node semantics defined by ``is_start_node_type``.

    :param graph_config: persisted graph mapping with a ``nodes`` list
    :return: id of the first start-capable node (note entries are skipped)
    :raises ValueError: when ``nodes`` is not a list or no entry node is found
    """
    nodes = graph_config.get("nodes")
    if not isinstance(nodes, list):
        raise ValueError("nodes in workflow graph must be a list")
    for candidate in nodes:
        if not isinstance(candidate, Mapping) or candidate.get("type") == "custom-note":
            continue
        candidate_id = candidate.get("id")
        candidate_data = candidate.get("data")
        if not (isinstance(candidate_id, str) and isinstance(candidate_data, Mapping)):
            continue
        declared_type = candidate_data.get("type")
        if isinstance(declared_type, str) and is_start_node_type(declared_type):
            return candidate_id
    raise ValueError("Unable to determine default root node ID from workflow graph")
class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Node]]]):
    """Mutable dict-like view over the current node registry.

    Reads are served from a snapshot of the live ``Node`` registry that is
    refreshed whenever the registry version changes; writes are recorded as
    local overrides/deletions layered on top of the snapshot, so they never
    mutate the global registry itself.
    """

    def __init__(self) -> None:
        # Snapshot of the registry as of ``_cached_version``.
        self._cached_snapshot: dict[NodeType, Mapping[str, type[Node]]] = {}
        # Registry version the snapshot was taken at; -1 forces a refresh on first read.
        self._cached_version = -1
        # Node types locally deleted from this view (still present in the registry).
        self._deleted: set[NodeType] = set()
        # Node types locally overridden in this view.
        self._overrides: dict[NodeType, Mapping[str, type[Node]]] = {}

    def _snapshot(self) -> dict[NodeType, Mapping[str, type[Node]]]:
        # Refresh the cached snapshot only when the registry version moved on.
        current_version = Node.get_registry_version()
        if self._cached_version != current_version:
            self._cached_snapshot = dict(get_node_type_classes_mapping())
            self._cached_version = current_version
        # Fast path: no local edits, hand back the shared cached dict directly.
        if not self._deleted and not self._overrides:
            return self._cached_snapshot
        # Apply local deletions first, then overlay local overrides.
        snapshot = {key: value for key, value in self._cached_snapshot.items() if key not in self._deleted}
        snapshot.update(self._overrides)
        return snapshot

    def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]:
        return self._snapshot()[key]

    def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None:
        # Setting a key cancels any earlier local deletion of the same key.
        self._deleted.discard(key)
        self._overrides[key] = value

    def __delitem__(self, key: NodeType) -> None:
        # Dropping an override does NOT also shadow the underlying registry
        # entry; a subsequent delete would then mark it deleted.
        if key in self._overrides:
            del self._overrides[key]
            return
        # NOTE(review): membership is tested against the cached snapshot, which
        # may be stale if the registry version changed since the last read —
        # confirm that is acceptable for delete semantics.
        if key in self._cached_snapshot:
            self._deleted.add(key)
            return
        raise KeyError(key)

    def __iter__(self) -> Iterator[NodeType]:
        return iter(self._snapshot())

    def __len__(self) -> int:
        return len(self._snapshot())
# Keep the canonical node-class mapping in the workflow layer that also bootstraps
# legacy `core.workflow.nodes.*` registrations.
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
# Node-data models that share the LLM-style contract (a ``model`` config plus
# optional ``memory``) and can therefore reuse the same kwarg-building helpers.
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
def fetch_memory(
*,
@ -99,10 +231,7 @@ class DefaultWorkflowCodeExecutor:
@final
class DifyNodeFactory(NodeFactory):
"""
Default implementation of NodeFactory that uses the traditional node mapping.
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
and instantiating the appropriate node class.
Default implementation of NodeFactory that resolves node classes from the live registry.
"""
def __init__(
@ -129,7 +258,6 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._http_request_file_manager = file_manager
self._rag_retrieval = DatasetRetrieval()
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
@ -145,6 +273,10 @@ class DifyNodeFactory(NodeFactory):
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
self._agent_strategy_resolver = PluginAgentStrategyResolver()
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
self._agent_runtime_support = AgentRuntimeSupport()
self._agent_message_transformer = AgentMessageTransformer()
@staticmethod
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
@ -156,167 +288,115 @@ class DifyNodeFactory(NodeFactory):
return DifyRunContext.model_validate(raw_ctx)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
:raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
# Get node_id from config
node_id = node_config["id"]
# Get node type from config
node_data = node_config["data"]
try:
node_type = NodeType(node_data["type"])
except ValueError:
raise ValueError(f"Unknown node type: {node_data['type']}")
# Get node class
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
node_version = str(node_data.get("version", "1"))
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
max_output_length=self._template_transform_max_output_length,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_request_config=self._http_request_config,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.HUMAN_INPUT:
return HumanInputNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
)
if node_type == NodeType.KNOWLEDGE_INDEX:
return KnowledgeIndexNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
index_processor=IndexProcessor(),
summary_index_service=SummaryIndex(),
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.DATASOURCE:
return DatasourceNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
datasource_manager=DatasourceManager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
)
if node_type == NodeType.DOCUMENT_EXTRACTOR:
return DocumentExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
unstructured_api_config=self._document_extractor_unstructured_api_config,
http_client=self._http_request_http_client,
)
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
"code_executor": self._code_executor,
"code_limits": self._code_limits,
},
BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: {
"template_renderer": self._template_renderer,
"max_output_length": self._template_transform_max_output_length,
},
BuiltinNodeTypes.HTTP_REQUEST: lambda: {
"http_request_config": self._http_request_config,
"http_client": self._http_request_http_client,
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
"file_manager": self._http_request_file_manager,
},
BuiltinNodeTypes.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
"http_client": self._http_request_http_client,
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=False,
),
BuiltinNodeTypes.TOOL: lambda: {
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
},
BuiltinNodeTypes.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
},
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
)
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
node_data_model = ModelConfig.model_validate(node_data["model"])
@staticmethod
def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
return node_class.validate_node_data(node_data)
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
def _build_llm_compatible_node_init_kwargs(
self,
*,
node_class: type[Node],
node_data: BaseNodeData,
include_http_client: bool,
) -> dict[str, object]:
validated_node_data = cast(
LLMCompatibleNodeData,
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
)
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,
"model_factory": self._llm_model_factory,
"model_instance": model_instance,
"memory": self._build_memory_for_llm_node(
node_data=validated_node_data,
model_instance=model_instance,
),
}
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
return node_init_kwargs
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
node_data_model = node_data.model
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
@ -352,14 +432,12 @@ class DifyNodeFactory(NodeFactory):
def _build_memory_for_llm_node(
self,
*,
node_data: Mapping[str, Any],
node_data: LLMCompatibleNodeData,
model_instance: ModelInstance,
) -> PromptMessageMemory | None:
raw_memory_config = node_data.get("memory")
if raw_memory_config is None:
if node_data.memory is None:
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
@ -369,6 +447,6 @@ class DifyNodeFactory(NodeFactory):
return fetch_memory(
conversation_id=conversation_id,
app_id=self._dify_context.app_id,
node_data_memory=node_memory,
node_data_memory=node_data.memory,
model_instance=model_instance,
)

View File

@ -0,0 +1 @@
"""Workflow node implementations that remain under the legacy core.workflow namespace."""

View File

@ -0,0 +1,4 @@
from .agent_node import AgentNode
from .entities import AgentNodeData
__all__ = ["AgentNode", "AgentNodeData"]

View File

@ -0,0 +1,188 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from .entities import AgentNodeData
from .exceptions import (
AgentInvocationError,
AgentMessageTransformError,
)
from .message_transformer import AgentMessageTransformer
from .runtime_support import AgentRuntimeSupport
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
class AgentNode(Node[AgentNodeData]):
node_type = BuiltinNodeTypes.AGENT
_strategy_resolver: AgentStrategyResolver
_presentation_provider: AgentStrategyPresentationProvider
_runtime_support: AgentRuntimeSupport
_message_transformer: AgentMessageTransformer
def __init__(
self,
id: str,
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._strategy_resolver = strategy_resolver
self._presentation_provider = presentation_provider
self._runtime_support = runtime_support
self._message_transformer = message_transformer
@classmethod
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
dify_ctx = self.require_dify_context()
event.extras["agent_strategy"] = {
"name": self.node_data.agent_strategy_name,
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
}
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = self._strategy_resolver.resolve(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
parameters = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
)
parameters_for_log = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
for_log=True,
)
credentials = self._runtime_support.build_credentials(parameters=parameters)
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
try:
yield from self._message_transformer.transform(
messages=message_stream,
tool_info={
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result

View File

@ -0,0 +1,47 @@
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
class AgentNodeData(BaseNodeData):
    """Graph payload for an agent node: strategy identity plus its parameter map."""

    type: NodeType = BuiltinNodeTypes.AGENT
    # Plugin provider the strategy belongs to.
    agent_strategy_provider_name: str
    # Strategy name within that provider.
    agent_strategy_name: str
    # Human-readable label shown in the UI.
    agent_strategy_label: str
    memory: MemoryConfig | None = None
    # The version of the tool parameter.
    # If this value is None, it indicates this is a previous version
    # and requires using the legacy parameter parsing rules.
    tool_node_version: str | None = None

    class AgentInput(BaseModel):
        # ``value`` shape depends on ``type``: template string(s) for "mixed",
        # a variable selector for "variable", or a literal for "constant" —
        # NOTE(review): confirm against the runtime-support parameter builder.
        value: Union[list[str], list[ToolSelector], Any]
        type: Literal["mixed", "variable", "constant"]

    # Per-parameter inputs keyed by the strategy parameter name.
    agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(IntEnum):
    """Whether parameter auto-generation is enabled; persisted as 0/1."""

    CLOSE = 0
    OPEN = 1
class AgentOldVersionModelFeatures(StrEnum):
"""
Enum class for old SDK version llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

View File

@ -0,0 +1,121 @@
class AgentNodeError(Exception):
    """Root of the agent-node exception hierarchy; carries ``message``."""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
class AgentStrategyError(AgentNodeError):
    """Raised when the agent strategy fails or is misconfigured."""

    def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
        super().__init__(message)
        self.strategy_name = strategy_name
        self.provider_name = provider_name
class AgentStrategyNotFoundError(AgentStrategyError):
    """Raised when the requested agent strategy cannot be resolved."""

    def __init__(self, strategy_name: str, provider_name: str | None = None):
        message = f"Agent strategy '{strategy_name}' not found"
        if provider_name:
            message += f" for provider '{provider_name}'"
        super().__init__(message, strategy_name, provider_name)
class AgentInvocationError(AgentNodeError):
    """Raised when invoking the agent fails; keeps the underlying exception."""

    def __init__(self, message: str, original_error: Exception | None = None):
        super().__init__(message)
        self.original_error = original_error
class AgentParameterError(AgentNodeError):
    """Raised for problems with agent parameters; remembers which one."""

    def __init__(self, message: str, parameter_name: str | None = None):
        super().__init__(message)
        self.parameter_name = parameter_name
class AgentVariableError(AgentNodeError):
    """Raised for problems with variables used by the agent node."""

    def __init__(self, message: str, variable_name: str | None = None):
        super().__init__(message)
        self.variable_name = variable_name
class AgentVariableNotFoundError(AgentVariableError):
    """Raised when a referenced variable is missing from the variable pool."""

    def __init__(self, variable_name: str):
        message = f"Variable '{variable_name}' does not exist"
        super().__init__(message, variable_name)
class AgentInputTypeError(AgentNodeError):
    """Raised when an agent input declares an unrecognized type."""

    def __init__(self, input_type: str):
        message = f"Unknown agent input type '{input_type}'"
        super().__init__(message)
class ToolFileError(AgentNodeError):
    """Raised for problems with a tool file; remembers the file id."""

    def __init__(self, message: str, file_id: str | None = None):
        super().__init__(message)
        self.file_id = file_id
class ToolFileNotFoundError(ToolFileError):
    """Raised when a tool file id has no matching record."""

    def __init__(self, file_id: str):
        message = f"Tool file '{file_id}' does not exist"
        super().__init__(message, file_id)
class AgentMessageTransformError(AgentNodeError):
    """Raised when transforming agent messages fails; keeps the root cause."""

    def __init__(self, message: str, original_error: Exception | None = None):
        super().__init__(message)
        self.original_error = original_error
class AgentModelError(AgentNodeError):
    """Raised for problems with the model the agent uses."""

    def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
        super().__init__(message)
        self.model_name = model_name
        self.provider = provider
class AgentMemoryError(AgentNodeError):
    """Raised for problems with the agent's conversation memory."""

    def __init__(self, message: str, conversation_id: str | None = None):
        super().__init__(message)
        self.conversation_id = conversation_id
class AgentVariableTypeError(AgentNodeError):
    """Raised when a variable carries a value of an unexpected type."""

    def __init__(
        self,
        message: str,
        variable_name: str | None = None,
        expected_type: str | None = None,
        actual_type: str | None = None,
    ):
        super().__init__(message)
        self.variable_name = variable_name
        self.expected_type = expected_type
        self.actual_type = actual_type

View File

@ -0,0 +1,292 @@
from __future__ import annotations
from collections.abc import Generator, Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from dify_graph.variables.segments import ArrayFileSegment
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
class AgentMessageTransformer:
    """Converts the ToolInvokeMessage stream produced by an agent strategy
    plugin into graph node events (stream chunks, agent logs, completion)."""

    def transform(
        self,
        *,
        messages: Generator[ToolInvokeMessage, None, None],
        tool_info: Mapping[str, Any],
        parameters_for_log: dict[str, Any],
        user_id: str,
        tenant_id: str,
        node_type: NodeType,
        node_id: str,
        node_execution_id: str,
    ) -> Generator[NodeEventBase, None, None]:
        """Consume *messages* and yield node events.

        Text/link content is streamed as chunk events and accumulated; files,
        JSON payloads, variables, usage metadata, and agent logs are collected
        and emitted in a final StreamCompletedEvent.

        Raises:
            ToolFileNotFoundError: a referenced tool file id has no DB record.
            AgentVariableTypeError: a streamed variable chunk is not a str.
            AgentNodeError: a FILE message has a missing or malformed 'file' meta.
        """
        # Imported lazily so the module does not depend on the plugin client at load time.
        from core.plugin.impl.plugin import PluginInstaller

        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
        )
        # Accumulators that become the outputs of the final completion event.
        text = ""
        files: list[File] = []
        json_list: list[dict | list] = []
        agent_logs: list[AgentLogEvent] = []
        agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
        llm_usage = LLMUsage.empty_usage()
        variables: dict[str, Any] = {}
        for message in message_stream:
            if message.type in {
                ToolInvokeMessage.MessageType.IMAGE_LINK,
                ToolInvokeMessage.MessageType.BINARY_LINK,
                ToolInvokeMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                url = message.message.text
                if message.meta:
                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
                else:
                    transfer_method = FileTransferMethod.TOOL_FILE
                # The tool file id is encoded in the last path segment of the URL.
                tool_file_id = str(url).split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)
                mapping = {
                    "tool_file_id": tool_file_id,
                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
                files.append(file)
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                assert message.meta
                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)
                mapping = {
                    "tool_file_id": tool_file_id,
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }
                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=tenant_id,
                    )
                )
            elif message.type == ToolInvokeMessage.MessageType.TEXT:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                text += message.message.text
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=message.message.text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.JSON:
                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
                # Only agent nodes carry execution metadata embedded in JSON payloads;
                # it is popped out so it does not leak into the "json" output.
                if node_type == BuiltinNodeTypes.AGENT:
                    if isinstance(message.message.json_object, dict):
                        msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
                        llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
                        agent_execution_metadata = {
                            WorkflowNodeExecutionMetadataKey(key): value
                            for key, value in msg_metadata.items()
                            if key in WorkflowNodeExecutionMetadataKey.__members__.values()
                        }
                    else:
                        llm_usage = LLMUsage.empty_usage()
                        agent_execution_metadata = {}
                if message.message.json_object:
                    json_list.append(message.message.json_object)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=stream_text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    # Streamed variables are accumulated chunk-by-chunk and
                    # echoed out as stream chunk events.
                    if not isinstance(variable_value, str):
                        raise AgentVariableTypeError(
                            "When 'stream' is True, 'variable_value' must be a string.",
                            variable_name=variable_name,
                            expected_type="str",
                            actual_type=type(variable_value).__name__,
                        )
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value
                    yield StreamChunkEvent(
                        selector=[node_id, variable_name],
                        chunk=variable_value,
                        is_final=False,
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == ToolInvokeMessage.MessageType.FILE:
                assert message.meta is not None
                assert isinstance(message.meta, dict)
                if "file" not in message.meta:
                    raise AgentNodeError("File message is missing 'file' key in meta")
                if not isinstance(message.meta["file"], File):
                    raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
                files.append(message.meta["file"])
            elif message.type == ToolInvokeMessage.MessageType.LOG:
                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
                if message.message.metadata:
                    # Enrich the log metadata with a provider icon when one can be
                    # resolved (installed plugin first, then builtin tools).
                    icon = tool_info.get("icon", "")
                    dict_metadata = dict(message.message.metadata)
                    if dict_metadata.get("provider"):
                        manager = PluginInstaller()
                        plugins = manager.list_plugins(tenant_id)
                        try:
                            current_plugin = next(
                                plugin
                                for plugin in plugins
                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                            )
                            icon = current_plugin.declaration.icon
                        except StopIteration:
                            pass
                        icon_dark = None
                        try:
                            builtin_tool = next(
                                provider
                                for provider in BuiltinToolManageService.list_builtin_tools(
                                    user_id,
                                    tenant_id,
                                )
                                if provider.name == dict_metadata["provider"]
                            )
                            icon = builtin_tool.icon
                            icon_dark = builtin_tool.icon_dark
                        except StopIteration:
                            pass
                        dict_metadata["icon"] = icon
                        dict_metadata["icon_dark"] = icon_dark
                        message.message.metadata = dict_metadata
                agent_log = AgentLogEvent(
                    message_id=message.message.id,
                    node_execution_id=node_execution_id,
                    parent_id=message.message.parent_id,
                    error=message.message.error,
                    status=message.message.status.value,
                    data=message.message.data,
                    label=message.message.label,
                    metadata=message.message.metadata,
                    node_id=node_id,
                )
                # A log with a known message_id updates the existing entry in
                # place; otherwise it is appended (for/else).
                for log in agent_logs:
                    if log.message_id == agent_log.message_id:
                        log.data = agent_log.data
                        log.status = agent_log.status
                        log.error = agent_log.error
                        log.label = agent_log.label
                        log.metadata = agent_log.metadata
                        break
                else:
                    agent_logs.append(agent_log)
                yield agent_log
        # Build the "json" output: serialized agent logs first, then the raw
        # JSON payloads; a placeholder keeps the shape when no JSON arrived.
        json_output: list[dict[str, Any] | list[Any]] = []
        if agent_logs:
            for log in agent_logs:
                json_output.append(
                    {
                        "id": log.message_id,
                        "parent_id": log.parent_id,
                        "error": log.error,
                        "status": log.status,
                        "data": log.data,
                        "label": log.label,
                        "metadata": log.metadata,
                        "node_id": log.node_id,
                    }
                )
        if json_list:
            json_output.extend(json_list)
        else:
            json_output.append({"data": []})
        # Close every open stream (text plus each streamed variable) before
        # emitting the completion event.
        yield StreamChunkEvent(
            selector=[node_id, "text"],
            chunk="",
            is_final=True,
        )
        for var_name in variables:
            yield StreamChunkEvent(
                selector=[node_id, var_name],
                chunk="",
                is_final=True,
            )
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={
                    "text": text,
                    "usage": jsonable_encoder(llm_usage),
                    "files": ArrayFileSegment(value=files),
                    "json": json_output,
                    **variables,
                },
                metadata={
                    **agent_execution_metadata,
                    WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
                    WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
                },
                inputs=parameters_for_log,
                llm_usage=llm_usage,
            )
        )

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from factories.agent_factory import get_plugin_agent_strategy
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
class PluginAgentStrategyResolver(AgentStrategyResolver):
    """Resolves agent strategies that are provided by installed plugins."""

    def resolve(
        self,
        *,
        tenant_id: str,
        agent_strategy_provider_name: str,
        agent_strategy_name: str,
    ) -> ResolvedAgentStrategy:
        """Look up the plugin-backed strategy for the given tenant."""
        strategy = get_plugin_agent_strategy(
            tenant_id=tenant_id,
            agent_strategy_provider_name=agent_strategy_provider_name,
            agent_strategy_name=agent_strategy_name,
        )
        return strategy
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
    """Supplies presentation data (icon) for plugin agent strategy providers."""

    def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
        """Return the provider's icon, or None when it cannot be resolved."""
        from core.plugin.impl.plugin import PluginInstaller

        installer = PluginInstaller()
        try:
            installed = installer.list_plugins(tenant_id)
        except Exception:
            # Icon lookup is best-effort: any failure to list plugins yields no icon.
            return None
        for plugin in installed:
            if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name:
                return plugin.declaration.icon
        return None

View File

@ -0,0 +1,276 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from dify_graph.enums import SystemVariableKey
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
class AgentRuntimeSupport:
    """Helpers used by the agent node to assemble a strategy invocation:
    resolved parameter values, tool credentials, conversation memory, and
    the model instance/schema for model-selector parameters."""

    def build_parameters(
        self,
        *,
        agent_parameters: Sequence[AgentStrategyParameter],
        variable_pool: VariablePool,
        node_data: AgentNodeData,
        strategy: ResolvedAgentStrategy,
        tenant_id: str,
        app_id: str,
        invoke_from: Any,
        for_log: bool = False,
    ) -> dict[str, Any]:
        """Resolve the node's configured agent parameters into concrete values.

        When *for_log* is True, templates render in their log form and the
        expensive enrichment (tool runtime lookup, model/memory fetch) is
        skipped.

        Raises:
            AgentVariableNotFoundError: a referenced variable is missing.
            AgentInputTypeError: an unknown agent input type is configured.
        """
        agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
        result: dict[str, Any] = {}
        for parameter_name in node_data.agent_parameters:
            parameter = agent_parameters_dictionary.get(parameter_name)
            if not parameter:
                # Unknown parameter names resolve to None rather than failing.
                result[parameter_name] = None
                continue
            agent_input = node_data.agent_parameters[parameter_name]
            match agent_input.type:
                case "variable":
                    variable = variable_pool.get(agent_input.value)  # type: ignore[arg-type]
                    if variable is None:
                        raise AgentVariableNotFoundError(str(agent_input.value))
                    parameter_value = variable.value
                case "mixed" | "constant":
                    # Non-string values round-trip through JSON so templates
                    # can be rendered on a string form and parsed back.
                    try:
                        if not isinstance(agent_input.value, str):
                            parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
                        else:
                            parameter_value = str(agent_input.value)
                    except TypeError:
                        parameter_value = str(agent_input.value)
                    segment_group = variable_pool.convert_template(parameter_value)
                    parameter_value = segment_group.log if for_log else segment_group.text
                    try:
                        if not isinstance(agent_input.value, str):
                            parameter_value = json.loads(parameter_value)
                    except json.JSONDecodeError:
                        parameter_value = parameter_value
                case _:
                    raise AgentInputTypeError(agent_input.type)
            value = parameter_value
            if parameter.type == "array[tools]":
                value = cast(list[dict[str, Any]], value)
                # Keep only enabled tools; MCP tools are filtered out for old strategies.
                value = [tool for tool in value if tool.get("enabled", False)]
                value = self._filter_mcp_type_tool(strategy, value)
                for tool in value:
                    if "schemas" in tool:
                        tool.pop("schemas")
                    parameters = tool.get("parameters", {})
                    if all(isinstance(v, dict) for _, v in parameters.items()):
                        # Parameter values are {auto, value} descriptors: resolve
                        # manually-set ones, null out auto-generated ones.
                        params = {}
                        for key, param in parameters.items():
                            if param.get("auto", ParamsAutoGenerated.OPEN) in (
                                ParamsAutoGenerated.CLOSE,
                                0,
                            ):
                                value_param = param.get("value", {})
                                if value_param and value_param.get("type", "") == "variable":
                                    variable_selector = value_param.get("value")
                                    if not variable_selector:
                                        raise ValueError("Variable selector is missing for a variable-type parameter.")
                                    variable = variable_pool.get(variable_selector)
                                    if variable is None:
                                        raise AgentVariableNotFoundError(str(variable_selector))
                                    params[key] = variable.value
                                else:
                                    params[key] = value_param.get("value", "") if value_param is not None else None
                            else:
                                params[key] = None
                        parameters = params
                    tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
                    tool["parameters"] = parameters
            if not for_log:
                if parameter.type == "array[tools]":
                    # Replace each tool dict with its fully-resolved runtime form.
                    value = cast(list[dict[str, Any]], value)
                    tool_value = []
                    for tool in value:
                        provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
                        setting_params = tool.get("settings", {})
                        parameters = tool.get("parameters", {})
                        manual_input_params = [key for key, value in parameters.items() if value is not None]
                        parameters = {**parameters, **setting_params}
                        entity = AgentToolEntity(
                            provider_id=tool.get("provider_name", ""),
                            provider_type=provider_type,
                            tool_name=tool.get("tool_name", ""),
                            tool_parameters=parameters,
                            plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
                            credential_id=tool.get("credential_id", None),
                        )
                        extra = tool.get("extra", {})
                        runtime_variable_pool: VariablePool | None = None
                        # Newer node versions pass the variable pool through to the tool runtime.
                        if node_data.version != "1" or node_data.tool_node_version is not None:
                            runtime_variable_pool = variable_pool
                        tool_runtime = ToolManager.get_agent_tool_runtime(
                            tenant_id,
                            app_id,
                            entity,
                            invoke_from,
                            runtime_variable_pool,
                        )
                        # A user-supplied description overrides the tool's own LLM description.
                        if tool_runtime.entity.description:
                            tool_runtime.entity.description.llm = (
                                extra.get("description", "") or tool_runtime.entity.description.llm
                            )
                        # Manually-set parameters are presented as FORM inputs.
                        for tool_runtime_params in tool_runtime.entity.parameters:
                            tool_runtime_params.form = (
                                ToolParameter.ToolParameterForm.FORM
                                if tool_runtime_params.name in manual_input_params
                                else tool_runtime_params.form
                            )
                        manual_input_value = {}
                        if tool_runtime.entity.parameters:
                            manual_input_value = {
                                key: value for key, value in parameters.items() if key in manual_input_params
                            }
                        runtime_parameters = {
                            **tool_runtime.runtime.runtime_parameters,
                            **manual_input_value,
                        }
                        tool_value.append(
                            {
                                **tool_runtime.entity.model_dump(mode="json"),
                                "runtime_parameters": runtime_parameters,
                                "credential_id": tool.get("credential_id", None),
                                "provider_type": provider_type.value,
                            }
                        )
                    value = tool_value
                if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
                    # Augment the model selector with history messages and the model schema.
                    value = cast(dict[str, Any], value)
                    model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
                    history_prompt_messages = []
                    if node_data.memory:
                        memory = self.fetch_memory(
                            variable_pool=variable_pool,
                            app_id=app_id,
                            model_instance=model_instance,
                        )
                        if memory:
                            prompt_messages = memory.get_history_prompt_messages(
                                message_limit=node_data.memory.window.size or None
                            )
                            history_prompt_messages = [
                                prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
                            ]
                    value["history_prompt_messages"] = history_prompt_messages
                    if model_schema:
                        model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
                        value["entity"] = model_schema.model_dump(mode="json")
                    else:
                        value["entity"] = None
            result[parameter_name] = value
        return result

    def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
        """Collect per-provider tool credential ids from resolved parameters.

        Tools without a credential id or with an invalid identity are skipped.
        """
        credentials = InvokeCredentials()
        credentials.tool_credentials = {}
        for tool in parameters.get("tools", []):
            if not tool.get("credential_id"):
                continue
            try:
                identity = ToolIdentity.model_validate(tool.get("identity", {}))
            except ValidationError:
                continue
            credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
        return credentials

    def fetch_memory(
        self,
        *,
        variable_pool: VariablePool,
        app_id: str,
        model_instance: ModelInstance,
    ) -> TokenBufferMemory | None:
        """Build conversation memory from sys.conversation_id, or None when
        there is no conversation id or no matching conversation row."""
        conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
        if not isinstance(conversation_id_variable, StringSegment):
            return None
        conversation_id = conversation_id_variable.value
        # expire_on_commit=False lets the Conversation be used after the session closes.
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
            conversation = session.scalar(stmt)
        if not conversation:
            return None
        return TokenBufferMemory(conversation=conversation, model_instance=model_instance)

    def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
        """Resolve a model-selector value into (model_instance, model_schema)."""
        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=tenant_id,
            provider=value.get("provider", ""),
            model_type=ModelType.LLM,
        )
        model_name = value.get("model", "")
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM,
            model=model_name,
        )
        provider_name = provider_model_bundle.configuration.provider.provider
        model_type_instance = provider_model_bundle.model_type_instance
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            provider=provider_name,
            model_type=ModelType(value.get("model_type", "")),
            model=model_name,
        )
        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        return model_instance, model_schema

    @staticmethod
    def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity:
        """Strip model features old agent versions don't understand (in place)."""
        if model_schema.features:
            # Iterate over a copy since we mutate the list while filtering.
            for feature in model_schema.features[:]:
                try:
                    AgentOldVersionModelFeatures(feature.value)
                except ValueError:
                    model_schema.features.remove(feature)
        return model_schema

    @staticmethod
    def _filter_mcp_type_tool(
        strategy: ResolvedAgentStrategy,
        tools: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Drop MCP tools for strategies whose meta version predates MCP support."""
        meta_version = strategy.meta_version
        if meta_version and Version(meta_version) > Version("0.0.1"):
            return tools
        return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

View File

@ -0,0 +1,39 @@
from __future__ import annotations
from collections.abc import Generator, Sequence
from typing import Any, Protocol
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
from core.tools.entities.tool_entities import ToolInvokeMessage
class ResolvedAgentStrategy(Protocol):
    """Structural type for a loaded agent strategy implementation."""

    # Strategy plugin meta version; used elsewhere to gate features by version.
    meta_version: str | None

    def get_parameters(self) -> Sequence[AgentStrategyParameter]:
        """Return the parameter declarations this strategy accepts."""
        ...

    def invoke(
        self,
        *,
        params: dict[str, Any],
        user_id: str,
        conversation_id: str | None = None,
        app_id: str | None = None,
        message_id: str | None = None,
        credentials: InvokeCredentials | None = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        """Run the strategy with *params* and stream back tool invoke messages."""
        ...
class AgentStrategyResolver(Protocol):
    """Structural type for resolving a named strategy to an implementation."""

    def resolve(
        self,
        *,
        tenant_id: str,
        agent_strategy_provider_name: str,
        agent_strategy_name: str,
    ) -> ResolvedAgentStrategy:
        """Resolve the named strategy for the given tenant and provider."""
        ...
class AgentStrategyPresentationProvider(Protocol):
    """Structural type for presentation data (icon) of a strategy provider."""

    def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
        """Return the provider's icon, or None when it cannot be resolved."""
        ...

View File

@ -0,0 +1 @@
"""Datasource workflow node package."""

View File

@ -0,0 +1,215 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam
from .exc import DatasourceNodeError
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
class DatasourceNode(Node[DatasourceNodeData]):
"""
Datasource Node
"""
node_type = BuiltinNodeTypes.DATASOURCE
execution_type = NodeExecutionType.ROOT
def __init__(
self,
id: str,
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.datasource_manager = DatasourceManager
def populate_start_event(self, event) -> None:
event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}"
event.provider_type = self.node_data.provider_type
def _run(self) -> Generator:
"""
Run the datasource node
"""
dify_ctx = self.require_dify_context()
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segment:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segment:
raise DatasourceNodeError("Datasource info is not set")
datasource_info_value = datasource_info_segment.value
if not isinstance(datasource_info_value, dict):
raise DatasourceNodeError("Invalid datasource info format")
datasource_info: dict[str, Any] = datasource_info_value
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
datasource_info["icon"] = self.datasource_manager.get_icon_url(
provider_id=provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=dify_ctx.tenant_id,
datasource_type=datasource_type.value,
)
parameters_for_log = datasource_info
try:
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
# Build typed request objects
datasource_parameters = None
if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_parameters = DatasourceParameter(
workspace_id=datasource_info.get("workspace_id", ""),
page_id=datasource_info.get("page", {}).get("page_id", ""),
type=datasource_info.get("page", {}).get("type", ""),
)
online_drive_request = None
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
online_drive_request = OnlineDriveDownloadFileParam(
id=datasource_info.get("id", ""),
bucket=datasource_info.get("bucket", ""),
)
credential_id = datasource_info.get("credential_id", "")
yield from self.datasource_manager.stream_node_events(
node_id=self._node_id,
user_id=dify_ctx.user_id,
datasource_name=node_data.datasource_name or "",
datasource_type=datasource_type.value,
provider_id=provider_id,
tenant_id=dify_ctx.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=credential_id,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
variable_pool=variable_pool,
datasource_param=datasource_parameters,
online_drive_request=online_drive_request,
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
**datasource_info,
"datasource_type": datasource_type,
},
)
)
case DatasourceProviderType.LOCAL_FILE:
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError("File is not exist")
file_info = self.datasource_manager.get_upload_file_by_id(
file_id=related_id, tenant_id=dify_ctx.tenant_id
)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file_info,
"datasource_type": datasource_type,
},
)
)
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
)
except DatasourceNodeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DatasourceNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
result = {}
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
input = node_data.datasource_parameters[parameter_name]
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
case "constant":
pass
case None:
pass
result = {node_id + "." + key: value for key, value in result.items()}
return result
@classmethod
def version(cls) -> str:
return "1"

View File

@ -0,0 +1,55 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
class DatasourceEntity(BaseModel):
    """Identity of the datasource plugin/provider backing a datasource node."""

    plugin_id: str
    provider_name: str  # redundancy
    provider_type: str
    datasource_name: str | None = "local_file"  # defaults to the local-file datasource
    datasource_configurations: dict[str, Any] | None = None
    plugin_unique_identifier: str | None = None  # redundancy
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
    """Node data for the datasource node: provider identity plus per-parameter inputs."""

    type: NodeType = BuiltinNodeTypes.DATASOURCE

    class DatasourceInput(BaseModel):
        # TODO: check this type
        # Literal/template value, or a variable selector path (list of str).
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant"] | None = None

        @field_validator("type", mode="before")
        @classmethod
        def check_type(cls, value, validation_info: ValidationInfo):
            """Validate that the already-parsed 'value' field matches the declared type."""
            typ = value
            value = validation_info.data.get("value")
            if typ == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            elif typ == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
                for val in value:
                    if not isinstance(val, str):
                        raise ValueError("value must be a list of strings")
            elif typ == "constant" and not isinstance(value, str | int | float | bool):
                raise ValueError("value must be a string, int, float, or bool")
            return typ

    datasource_parameters: dict[str, DatasourceInput] | None = None
class DatasourceParameter(BaseModel):
    """Request parameters for fetching an online document page."""

    workspace_id: str
    page_id: str
    type: str  # page type reported by the datasource — TODO confirm allowed values
class OnlineDriveDownloadFileParam(BaseModel):
    """Request parameters for downloading a file from an online drive."""

    id: str  # file identifier within the drive
    bucket: str  # storage bucket/container name

View File

@ -0,0 +1,16 @@
class DatasourceNodeError(ValueError):
    """Base exception for all datasource node failures."""
class DatasourceParameterError(DatasourceNodeError):
    """Raised when a datasource parameter is invalid."""
class DatasourceFileError(DatasourceNodeError):
    """Raised for problems with files handled by a datasource."""

View File

@ -0,0 +1,35 @@
from collections.abc import Generator
from typing import Any, Protocol
from dify_graph.file import File
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
from .entities import DatasourceParameter, OnlineDriveDownloadFileParam
class DatasourceManagerProtocol(Protocol):
    """Structural interface the datasource node requires from a datasource manager."""

    @classmethod
    def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
        """Return the icon URL for a datasource provider."""
        ...

    @classmethod
    def stream_node_events(
        cls,
        *,
        node_id: str,
        user_id: str,
        datasource_name: str,
        datasource_type: str,
        provider_id: str,
        tenant_id: str,
        provider: str,
        plugin_id: str,
        credential_id: str,
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
        variable_pool: Any,
        datasource_param: DatasourceParameter | None = None,
        online_drive_request: OnlineDriveDownloadFileParam | None = None,
    ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
        """Invoke the datasource and yield node stream events."""
        ...

    @classmethod
    def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
        """Return the File for a previously uploaded local file id."""
        ...

View File

@ -0,0 +1,5 @@
"""Knowledge index workflow node package."""
KNOWLEDGE_INDEX_NODE_TYPE = "knowledge-index"
__all__ = ["KNOWLEDGE_INDEX_NODE_TYPE"]

View File

@ -0,0 +1,164 @@
from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class RerankingModelConfig(BaseModel):
    """Configuration of the reranking model applied after retrieval."""

    reranking_provider_name: str  # model provider identifier
    reranking_model_name: str  # model name within that provider
class VectorSetting(BaseModel):
    """Vector (semantic) component of a weighted-score retrieval setup."""

    vector_weight: float  # weight of the semantic score in the blend
    embedding_provider_name: str
    embedding_model_name: str
class KeywordSetting(BaseModel):
    """Keyword component of a weighted-score retrieval setup."""

    keyword_weight: float  # weight of the keyword score in the blend
class WeightedScoreConfig(BaseModel):
    """Weighted-score configuration blending vector and keyword signals."""

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting
class EmbeddingSetting(BaseModel):
    """Embedding model selection for high-quality indexing."""

    embedding_provider_name: str
    embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
"""
search_method: RetrievalMethod
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class FileInfo(BaseModel):
    """Data-source info for a locally uploaded file."""

    file_id: str


class OnlineDocumentIcon(BaseModel):
    """Icon metadata for an online document."""

    icon_url: str
    icon_type: str
    icon_emoji: str


class OnlineDocumentInfo(BaseModel):
    """Data-source info for an online document (e.g. a provider page)."""

    provider: str
    workspace_id: str | None = None
    page_id: str
    page_type: str
    icon: OnlineDocumentIcon | None = None


class WebsiteInfo(BaseModel):
    """Data-source info for a crawled website page."""

    provider: str
    url: str


class GeneralStructureChunk(BaseModel):
    """Flat list of chunks plus the data source they came from."""

    general_chunks: list[str]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class ParentChildChunk(BaseModel):
    """A parent chunk together with its child chunks."""

    parent_content: str
    child_contents: list[str]


class ParentChildStructureChunk(BaseModel):
    """Parent/child chunk hierarchy plus the data source it came from."""

    parent_child_chunks: list[ParentChildChunk]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
class KnowledgeIndexNodeData(BaseNodeData):
    """
    Knowledge index Node Data.
    """

    # NOTE(review): KNOWLEDGE_INDEX_NODE_TYPE is a plain string assigned to a
    # NodeType-annotated field — confirm NodeType is str-backed (StrEnum) or
    # that pydantic validation coerces it.
    type: NodeType = KNOWLEDGE_INDEX_NODE_TYPE
    # Chunk layout of the incoming data — presumably one of the structures
    # modeled above (general / parent-child); verify against callers.
    chunk_structure: str
    # Variable selector pointing at the extracted chunks to index.
    index_chunk_variable_selector: list[str]
    indexing_technique: str | None = None
    summary_index_setting: dict | None = None

View File

@ -0,0 +1,22 @@
class KnowledgeIndexNodeError(ValueError):
    """Base class for KnowledgeIndexNode errors."""


class ModelNotExistError(KnowledgeIndexNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
    """Raised when the model is not a Large Language Model."""

View File

@ -0,0 +1,153 @@
import logging
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.template import Template
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

# Value of the ``sys.invoke_from`` variable that marks a debugger/preview run
# (preview runs format output only; nothing is persisted).
_INVOKE_FROM_DEBUGGER = "debugger"
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
    """Response-type workflow node that indexes extracted chunks into a dataset.

    In preview ("debugger") runs it only formats a preview of the chunks via
    the index processor; otherwise it persists the chunks and then triggers
    summary generation/vectorization.
    """

    node_type = KNOWLEDGE_INDEX_NODE_TYPE
    execution_type = NodeExecutionType.RESPONSE

    def __init__(
        self,
        id: str,
        config: NodeConfigDict,
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
    ) -> None:
        super().__init__(id, config, graph_init_params, graph_runtime_state)
        # Collaborators: chunk indexing and summary-index generation.
        self.index_processor = IndexProcessor()
        self.summary_index_service = SummaryIndex()

    def _run(self) -> NodeRunResult:  # type: ignore
        """Index the chunks referenced by ``index_chunk_variable_selector``.

        Reads ``sys.dataset_id``, ``sys.document_id``, ``sys.invoke_from`` and
        (for persisted runs) ``sys.batch`` / ``sys.original_document_id`` from
        the variable pool.

        Raises:
            KnowledgeIndexNodeError: when required inputs are missing; this is
                raised before the try-block, so it propagates to the caller.
        """
        node_data = self.node_data
        variable_pool = self.graph_runtime_state.variable_pool

        # get dataset id as string
        dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
        if not dataset_id_segment:
            raise KnowledgeIndexNodeError("Dataset ID is required.")
        dataset_id: str = dataset_id_segment.value

        # get document id as string (may be empty when not provided)
        document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
        document_id: str = document_id_segment.value if document_id_segment else ""

        # extract variables
        variable = variable_pool.get(node_data.index_chunk_variable_selector)
        if not variable:
            raise KnowledgeIndexNodeError("Index chunk variable is required.")

        invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
        invoke_from_value = str(invoke_from.value) if invoke_from else None
        # A "debugger" invocation means preview-only: nothing is written.
        is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER

        chunks = variable.value
        variables = {"chunks": chunks}
        if not chunks:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
            )
        try:
            summary_index_setting = node_data.summary_index_setting
            if is_preview:
                # Preview mode: generate summaries for chunks directly without saving to database
                # Format preview and generate summaries on-the-fly
                # Get indexing_technique and summary_index_setting from node_data (workflow graph config)
                # or fallback to dataset if not available in node_data
                outputs = self.index_processor.get_preview_output(
                    chunks, dataset_id, document_id, node_data.chunk_structure, summary_index_setting
                )
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=variables,
                    outputs=outputs.model_dump(exclude_none=True),
                )

            # Persisted run: batch is mandatory, original document id optional.
            original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
            batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
            if not batch:
                raise KnowledgeIndexNodeError("Batch is required.")
            results = self._invoke_knowledge_index(
                dataset_id=dataset_id,
                document_id=document_id,
                original_document_id=original_document_id_segment.value if original_document_id_segment else "",
                is_preview=is_preview,
                batch=batch.value,
                chunks=chunks,
                summary_index_setting=summary_index_setting,
            )
            return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
        except KnowledgeIndexNodeError as e:
            logger.warning("Error when running knowledge index node", exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )
        except Exception as e:
            # Catch-all: any unexpected failure becomes a FAILED node result.
            logger.error(e, exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )

    def _invoke_knowledge_index(
        self,
        dataset_id: str,
        document_id: str,
        original_document_id: str,
        is_preview: bool,
        batch: Any,
        chunks: Mapping[str, Any],
        summary_index_setting: dict | None = None,
    ):
        """Persist chunks via the index processor, then kick off summary vectorization.

        Returns whatever ``index_and_clean`` returns (used as node outputs).

        Raises:
            KnowledgeIndexNodeError: if ``document_id`` is empty.
        """
        if not document_id:
            raise KnowledgeIndexNodeError("document_id is required.")
        rst = self.index_processor.index_and_clean(
            dataset_id, document_id, original_document_id, chunks, batch, summary_index_setting
        )
        self.summary_index_service.generate_and_vectorize_summary(
            dataset_id, document_id, is_preview, summary_index_setting
        )
        return rst

    @classmethod
    def version(cls) -> str:
        # Node implementation version string.
        return "1"

    def get_streaming_template(self) -> Template:
        """
        Get the template for streaming.

        Returns:
            Template instance for this knowledge index node (no segments —
            this node streams nothing).
        """
        return Template(segments=[])

View File

@ -0,0 +1,47 @@
from collections.abc import Mapping
from typing import Any, Protocol
from pydantic import BaseModel, Field
class PreviewItem(BaseModel):
    """One previewed chunk: content, optional child chunks and optional summary."""

    content: str | None = Field(default=None)
    child_chunks: list[str] | None = Field(default=None)
    summary: str | None = Field(default=None)


class QaPreview(BaseModel):
    """One previewed question/answer pair."""

    answer: str | None = Field(default=None)
    question: str | None = Field(default=None)


class Preview(BaseModel):
    """Preview payload returned by the index processor in debugger runs."""

    chunk_structure: str
    parent_mode: str | None = Field(default=None)
    preview: list[PreviewItem] = Field(default_factory=list)
    qa_preview: list[QaPreview] = Field(default_factory=list)
    total_segments: int
class IndexProcessorProtocol(Protocol):
    """Structural interface for the index processor used by the knowledge-index node."""

    def format_preview(self, chunk_structure: str, chunks: Any) -> Preview: ...

    def index_and_clean(
        self,
        dataset_id: str,
        document_id: str,
        original_document_id: str,
        chunks: Mapping[str, Any],
        batch: Any,
        summary_index_setting: dict | None = None,
    ) -> dict[str, Any]: ...

    def get_preview_output(
        self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
    ) -> Preview: ...


class SummaryIndexServiceProtocol(Protocol):
    """Structural interface for the summary-index service used by the knowledge-index node."""

    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
    ) -> None: ...

View File

@ -0,0 +1 @@
"""Knowledge retrieval workflow node package."""

View File

@ -0,0 +1,136 @@
from collections.abc import Sequence
from typing import Literal
from pydantic import BaseModel, Field
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel):
    """Reranking model selection (provider + model)."""

    provider: str
    model: str


class VectorSetting(BaseModel):
    """Vector (semantic) component of weighted-score reranking."""

    vector_weight: float
    embedding_provider_name: str
    embedding_model_name: str


class KeywordSetting(BaseModel):
    """Keyword component of weighted-score reranking."""

    keyword_weight: float


class WeightedScoreConfig(BaseModel):
    """Combined vector + keyword weighting for weighted-score reranking."""

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting


class MultipleRetrievalConfig(BaseModel):
    """Configuration for 'multiple' retrieval mode (multi-dataset merge + rerank)."""

    top_k: int
    score_threshold: float | None = None
    # "reranking_model" or "weighted_score"; any other value disables reranking inputs.
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True
    reranking_model: RerankingModelConfig | None = None
    weights: WeightedScoreConfig | None = None


class SingleRetrievalConfig(BaseModel):
    """Configuration for 'single' retrieval mode (model-routed retrieval)."""

    model: ModelConfig
# Operators accepted by metadata-filtering conditions. The "≠", "≥" and "≤"
# entries were corrupted to empty strings ("") in a bad merge; they are
# restored here to match the operator set advertised in the metadata-filter
# prompts ("=", "≠", ">", "<", "≥", "≤").
SupportedComparisonOperator = Literal[
    # for string or array
    "contains",
    "not contains",
    "start with",
    "end with",
    "is",
    "is not",
    "empty",
    "not empty",
    "in",
    "not in",
    # for number
    "=",
    "≠",
    ">",
    "<",
    "≥",
    "≤",
    # for time
    "before",
    "after",
]
class Condition(BaseModel):
    """
    Condition detail: a metadata field name, an operator, and an optional value.
    """

    name: str
    comparison_operator: SupportedComparisonOperator
    # Value may be a template string (resolved against the variable pool),
    # a list of strings, a number, or None for unary operators like "empty".
    value: str | Sequence[str] | None | int | float = None


class MetadataFilteringCondition(BaseModel):
    """
    Metadata Filtering Condition: conditions joined by a logical operator.
    """

    logical_operator: Literal["and", "or"] | None = "and"
    conditions: list[Condition] | None = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
    """

    type: NodeType = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
    # Selector for the query string; at least one of query/attachment selectors
    # must be set for the node to do any retrieval.
    query_variable_selector: list[str] | None | str = None
    query_attachment_selector: list[str] | None | str = None
    dataset_ids: list[str]
    retrieval_mode: Literal["single", "multiple"]
    # Required when retrieval_mode == "multiple".
    multiple_retrieval_config: MultipleRetrievalConfig | None = None
    # Required when retrieval_mode == "single".
    single_retrieval_config: SingleRetrievalConfig | None = None
    metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
    metadata_model_config: ModelConfig | None = None
    metadata_filtering_conditions: MetadataFilteringCondition | None = None
    vision: VisionConfig = Field(default_factory=VisionConfig)

    @property
    def structured_output_enabled(self) -> bool:
        # NOTE(QuantumGhost): Temporary workaround for issue #20725
        # (https://github.com/langgenius/dify/issues/20725).
        #
        # The proper fix would be to make `KnowledgeRetrievalNode` inherit
        # from `BaseNode` instead of `LLMNode`.
        return False

View File

@ -0,0 +1,26 @@
class KnowledgeRetrievalNodeError(ValueError):
    """Base class for KnowledgeRetrievalNode errors."""


class ModelNotExistError(KnowledgeRetrievalNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeRetrievalNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeRetrievalNodeError):
    """Raised when the model is not a Large Language Model."""


class RateLimitExceededError(KnowledgeRetrievalNodeError):
    """Raised when the rate limit is exceeded."""

View File

@ -0,0 +1,319 @@
"""Knowledge retrieval workflow node implementation.
This node now lives under ``core.workflow.nodes`` and is discovered directly by
the workflow node registry.
"""
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from dify_graph.variables.segments import ArrayObjectSegment
from .entities import (
Condition,
KnowledgeRetrievalNodeData,
MetadataFilteringCondition,
)
from .exc import (
KnowledgeRetrievalNodeError,
RateLimitExceededError,
)
from .retrieval import KnowledgeRetrievalRequest, Source
if TYPE_CHECKING:
from dify_graph.file.models import File
from dify_graph.runtime import GraphRuntimeState
# Module-level logger for the knowledge retrieval node.
logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
    """Workflow node that retrieves relevant segments from configured datasets.

    Supports "single" (model-routed) and "multiple" retrieval modes, optional
    metadata filtering (with variable-template resolution in manual conditions),
    and reports the LLM usage incurred during retrieval.
    """

    node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL

    # Instance attributes specific to LLMNode.
    # Output variable for file
    _file_outputs: list["File"]

    def __init__(
        self,
        id: str,
        config: NodeConfigDict,
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
    ):
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs = []
        self._rag_retrieval = DatasetRetrieval()

    @classmethod
    def version(cls):
        # Node implementation version string.
        return "1"

    def _run(self) -> NodeRunResult:
        """Execute retrieval and wrap the result (or the error) in a NodeRunResult."""
        usage = LLMUsage.empty_usage()
        # Neither a query nor attachments configured: nothing to retrieve,
        # succeed with empty outputs.
        if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs={},
                process_data={},
                outputs={},
                metadata={},
                llm_usage=usage,
            )

        variables: dict[str, Any] = {}
        # extract variables
        if self._node_data.query_variable_selector:
            variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
            if not isinstance(variable, StringSegment):
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error="Query variable is not string type.",
                )
            query = variable.value
            variables["query"] = query
        if self._node_data.query_attachment_selector:
            variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
            if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment):
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error="Attachments variable is not array file or file type.",
                )
            # Normalize to a list of files regardless of segment type.
            if isinstance(variable, ArrayFileSegment):
                variables["attachments"] = variable.value
            else:
                variables["attachments"] = [variable.value]

        try:
            results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
            outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data={"usage": jsonable_encoder(usage)},
                outputs=outputs,  # type: ignore
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
        except RateLimitExceededError as e:
            logger.warning(e, exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        except KnowledgeRetrievalNodeError as e:
            logger.warning("Error when running knowledge retrieval node", exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        # Temporary handle all exceptions from DatasetRetrieval class here.
        except Exception as e:
            logger.warning(e, exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )

    def _fetch_dataset_retriever(
        self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
    ) -> tuple[list[Source], LLMUsage]:
        """Build a KnowledgeRetrievalRequest from node config and run retrieval.

        Returns:
            (retrieved sources, LLM usage accumulated by the retrieval backend).
        """
        dify_ctx = self.require_dify_context()
        dataset_ids = node_data.dataset_ids
        query = variables.get("query")
        attachments = variables.get("attachments")
        retrieval_resource_list = []
        metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
        if node_data.metadata_filtering_mode is not None:
            metadata_filtering_mode = node_data.metadata_filtering_mode
        # Resolve variable templates embedded in manual filter conditions.
        resolved_metadata_conditions = (
            self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
            if node_data.metadata_filtering_conditions
            else None
        )
        # NOTE(review): the str(...) == enum comparison assumes RetrieveStrategy
        # is str-backed — confirm.
        if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
            # fetch model config
            if node_data.single_retrieval_config is None:
                raise ValueError("single_retrieval_config is required for single retrieval mode")
            model = node_data.single_retrieval_config.model
            retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
                request=KnowledgeRetrievalRequest(
                    tenant_id=dify_ctx.tenant_id,
                    user_id=dify_ctx.user_id,
                    app_id=dify_ctx.app_id,
                    user_from=dify_ctx.user_from.value,
                    dataset_ids=dataset_ids,
                    retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
                    completion_params=model.completion_params,
                    model_provider=model.provider,
                    model_mode=model.mode,
                    model_name=model.name,
                    metadata_model_config=node_data.metadata_model_config,
                    metadata_filtering_conditions=resolved_metadata_conditions,
                    metadata_filtering_mode=metadata_filtering_mode,
                    query=query,
                )
            )
        elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            if node_data.multiple_retrieval_config is None:
                raise ValueError("multiple_retrieval_config is required")
            reranking_model = None
            weights = None
            # Translate the configured reranking mode into backend arguments.
            match node_data.multiple_retrieval_config.reranking_mode:
                case "reranking_model":
                    if node_data.multiple_retrieval_config.reranking_model:
                        reranking_model = {
                            "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
                            "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
                        }
                    else:
                        reranking_model = None
                    weights = None
                case "weighted_score":
                    if node_data.multiple_retrieval_config.weights is None:
                        raise ValueError("weights is required")
                    reranking_model = None
                    vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
                    weights = {
                        "vector_setting": {
                            "vector_weight": vector_setting.vector_weight,
                            "embedding_provider_name": vector_setting.embedding_provider_name,
                            "embedding_model_name": vector_setting.embedding_model_name,
                        },
                        "keyword_setting": {
                            "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
                        },
                    }
                case _:
                    # Handle any other reranking_mode values
                    reranking_model = None
                    weights = None
            retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
                request=KnowledgeRetrievalRequest(
                    app_id=dify_ctx.app_id,
                    tenant_id=dify_ctx.tenant_id,
                    user_id=dify_ctx.user_id,
                    user_from=dify_ctx.user_from.value,
                    dataset_ids=dataset_ids,
                    query=query,
                    retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
                    top_k=node_data.multiple_retrieval_config.top_k,
                    score_threshold=node_data.multiple_retrieval_config.score_threshold
                    if node_data.multiple_retrieval_config.score_threshold is not None
                    else 0.0,
                    reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
                    reranking_model=reranking_model,
                    weights=weights,
                    reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
                    metadata_model_config=node_data.metadata_model_config,
                    metadata_filtering_conditions=resolved_metadata_conditions,
                    metadata_filtering_mode=metadata_filtering_mode,
                    attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
                )
            )
        usage = self._rag_retrieval.llm_usage
        return retrieval_resource_list, usage

    def _resolve_metadata_filtering_conditions(
        self, conditions: MetadataFilteringCondition
    ) -> MetadataFilteringCondition:
        """Return a copy of *conditions* with variable templates in values resolved.

        String values are run through the variable pool's template conversion;
        lists of strings are resolved element-wise; other values pass through.
        """
        if conditions.conditions is None:
            return MetadataFilteringCondition(
                logical_operator=conditions.logical_operator,
                conditions=None,
            )
        variable_pool = self.graph_runtime_state.variable_pool
        resolved_conditions: list[Condition] = []
        for cond in conditions.conditions or []:
            value = cond.value
            if isinstance(value, str):
                segment_group = variable_pool.convert_template(value)
                # A single segment keeps its native type; mixed templates
                # collapse to plain text.
                if len(segment_group.value) == 1:
                    resolved_value = segment_group.value[0].to_object()
                else:
                    resolved_value = segment_group.text
            elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
                resolved_values = []
                for v in value:  # type: ignore
                    segment_group = variable_pool.convert_template(v)
                    if len(segment_group.value) == 1:
                        resolved_values.append(segment_group.value[0].to_object())
                    else:
                        resolved_values.append(segment_group.text)
                resolved_value = resolved_values
            else:
                resolved_value = value
            resolved_conditions.append(
                Condition(
                    name=cond.name,
                    comparison_operator=cond.comparison_operator,
                    value=resolved_value,
                )
            )
        return MetadataFilteringCondition(
            logical_operator=conditions.logical_operator or "and",
            conditions=resolved_conditions,
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: KnowledgeRetrievalNodeData,
    ) -> Mapping[str, Sequence[str]]:
        """Map this node's input keys to the variable selectors they read."""
        # graph_config is not used in this node type
        variable_mapping = {}
        if node_data.query_variable_selector:
            variable_mapping[node_id + ".query"] = node_data.query_variable_selector
        if node_data.query_attachment_selector:
            variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
        return variable_mapping

View File

@ -0,0 +1,88 @@
from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.llm.entities import ModelConfig
from .entities import MetadataFilteringCondition
class SourceChildChunk(BaseModel):
    """A child chunk returned alongside its parent segment."""

    id: str = Field(default="", description="Child chunk ID")
    content: str = Field(default="", description="Child chunk content")
    position: int = Field(default=0, description="Child chunk position")
    score: float = Field(default=0.0, description="Child chunk relevance score")


class SourceMetadata(BaseModel):
    """Metadata describing where a retrieved segment came from."""

    source: str = Field(
        default="knowledge",
        serialization_alias="_source",
        description="Data source identifier",
    )
    dataset_id: str = Field(description="Dataset unique identifier")
    dataset_name: str = Field(description="Dataset display name")
    document_id: str = Field(description="Document unique identifier")
    document_name: str = Field(description="Document display name")
    data_source_type: str = Field(description="Type of data source")
    segment_id: str | None = Field(default=None, description="Segment unique identifier")
    retriever_from: str = Field(default="workflow", description="Retriever source context")
    score: float = Field(default=0.0, description="Retrieval relevance score")
    child_chunks: list[SourceChildChunk] = Field(default_factory=list, description="List of child chunks")
    segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
    segment_word_count: int | None = Field(default=0, description="Word count of the segment")
    segment_position: int | None = Field(default=0, description="Position of segment in document")
    segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
    doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
    position: int | None = Field(default=0, description="Position of the document in the dataset")

    class Config:
        # Allow population by field name as well as serialization alias
        # (e.g. "_source").
        populate_by_name = True


class Source(BaseModel):
    """A single retrieval result: segment content plus its provenance metadata."""

    metadata: SourceMetadata = Field(description="Source metadata information")
    title: str = Field(description="Document title")
    files: list[Any] | None = Field(default=None, description="Associated file references")
    content: str | None = Field(description="Segment content text")
    summary: str | None = Field(default=None, description="Content summary if available")
class KnowledgeRetrievalRequest(BaseModel):
    """Request payload passed to the RAG retrieval backend.

    Single-mode requests populate the model_* fields; multiple-mode requests
    populate top_k / reranking / weights. Field-level descriptions below.
    """

    tenant_id: str = Field(description="Tenant unique identifier")
    user_id: str = Field(description="User unique identifier")
    app_id: str = Field(description="Application unique identifier")
    user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
    dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
    query: str | None = Field(default=None, description="Query text for knowledge retrieval")
    retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
    model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
    completion_params: dict[str, Any] | None = Field(
        default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
    )
    model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
    model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
    metadata_model_config: ModelConfig | None = Field(
        default=None, description="Model config for metadata-based filtering"
    )
    metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
        default=None, description="Conditions for filtering by metadata"
    )
    metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
        default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
    )
    top_k: int = Field(default=0, description="Number of top results to return")
    score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
    reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
    reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
    weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
    reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
    attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
class RAGRetrievalProtocol(Protocol):
    """Structural interface for the retrieval backend used by the node."""

    @property
    def llm_usage(self) -> LLMUsage: ...

    def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: ...

View File

@ -0,0 +1,66 @@
# System prompt for LLM-driven metadata-filter extraction (chat models).
# The "≠", "≥" and "≤" operators were corrupted to empty strings ("") in a bad
# merge; restored here so the advertised operator list matches
# SupportedComparisonOperator.
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
"""  # noqa: E501
# Few-shot example 1: user turn (JSON payload with input_text + metadata_fields).
METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which companys email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

# Few-shot example 1: expected assistant reply (fenced JSON with metadata_map).
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

# Few-shot example 2: user turn.
METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

# Few-shot example 2: expected assistant reply.
# NOTE(review): the example JSON carries a trailing comma before "]" — confirm
# this is intentional (it is model-facing example text, not parsed JSON).
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

# Template for the live user turn; filled via str.format with input_text and
# metadata_fields ({{ / }} are format-escaped braces).
METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""
# Completion-style prompt for metadata-filter extraction; rendered with
# str.format(input_text=..., metadata_fields=...). Fixed: the second example's
# '"comparison_operator": "="}' had a single unescaped '}' which made
# str.format() raise "Single '}' encountered in format string" — it must be
# '}}' like every other literal brace in this template.
METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which companys email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
"""  # noqa: E501

View File

@ -0,0 +1,3 @@
from .trigger_event_node import TriggerEventNode

# Public package API.
__all__ = ["TriggerEventNode"]

View File

@ -0,0 +1,80 @@
from collections.abc import Mapping
from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.entities.entities import EventParameter
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from .exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
    """Plugin trigger node data"""

    # NOTE(review): TRIGGER_PLUGIN_NODE_TYPE is assigned to a NodeType-annotated
    # field — confirm NodeType is str-backed or that validation coerces it.
    type: NodeType = TRIGGER_PLUGIN_NODE_TYPE

    class TriggerEventInput(BaseModel):
        """One configured event parameter: a value plus how to interpret it."""

        value: Union[Any, list[str]]
        # "mixed": template string; "variable": selector path; "constant": literal.
        type: Literal["mixed", "variable", "constant"]

        @field_validator("type", mode="before")
        @classmethod
        def check_type(cls, value, validation_info: ValidationInfo):
            """Validate that ``value`` matches the declared input ``type``."""
            # ``value`` here is the *type* field being validated; the actual
            # input value is read from the already-validated ``value`` field.
            type = value
            value = validation_info.data.get("value")
            if value is None:
                return type
            if type == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            if type == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
                for val in value:
                    if not isinstance(val, str):
                        raise ValueError("value must be a list of strings")
            if type == "constant" and not isinstance(value, str | int | float | bool | dict | list):
                raise ValueError("value must be a string, int, float, bool or dict")
            return type

    plugin_id: str = Field(..., description="Plugin ID")
    provider_id: str = Field(..., description="Provider ID")
    event_name: str = Field(..., description="Event name")
    subscription_id: str = Field(..., description="Subscription ID")
    plugin_unique_identifier: str = Field(..., description="Plugin unique identifier")
    event_parameters: Mapping[str, TriggerEventInput] = Field(default_factory=dict, description="Trigger parameters")

    def resolve_parameters(
        self,
        *,
        parameter_schemas: Mapping[str, EventParameter],
    ) -> Mapping[str, Any]:
        """
        Generate parameters based on the given plugin trigger parameters.

        Args:
            parameter_schemas (Mapping[str, EventParameter]): The mapping of parameter schemas.

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.

        Raises:
            TriggerEventParameterError: if a configured input is not of type
                "constant" (the only type this node supports).
        """
        result: dict[str, Any] = {}
        for parameter_name in self.event_parameters:
            parameter: EventParameter | None = parameter_schemas.get(parameter_name)
            if not parameter:
                # No schema for this parameter: keep the key but emit None.
                result[parameter_name] = None
                continue
            event_input = self.event_parameters[parameter_name]
            # trigger node only supports constant input
            if event_input.type != "constant":
                raise TriggerEventParameterError(f"Unknown plugin trigger input type '{event_input.type}'")
            result[parameter_name] = event_input.value
        return result

View File

@ -0,0 +1,10 @@
class TriggerEventNodeError(ValueError):
    """Base class for all plugin trigger node errors."""


class TriggerEventParameterError(TriggerEventNodeError):
    """Raised when a plugin trigger parameter is invalid or unsupported."""

View File

@ -0,0 +1,69 @@
from collections.abc import Mapping
from typing import Any, cast
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from .entities import TriggerEventNodeData
class TriggerEventNode(Node[TriggerEventNodeData]):
    """Root node that exposes a plugin trigger event's data to the workflow."""

    node_type = TRIGGER_PLUGIN_NODE_TYPE
    execution_type = NodeExecutionType.ROOT

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Return the default (empty) configuration for a plugin trigger node."""
        empty_config: dict[str, object] = {
            "title": "",
            "plugin_id": "",
            "provider_id": "",
            "event_name": "",
            "subscription_id": "",
            "plugin_unique_identifier": "",
            "event_parameters": {},
        }
        return {"type": "plugin", "config": empty_config}

    @classmethod
    def version(cls) -> str:
        return "1"

    def populate_start_event(self, event) -> None:
        # Expose the provider on the start event so consumers can tell which
        # trigger provider produced this run.
        event.provider_id = self.node_data.provider_id

    def _run(self) -> NodeRunResult:
        """
        Run the plugin trigger node.

        This node invokes the trigger to convert request data into events
        and makes them available to downstream nodes.
        """
        # Record which plugin/provider/event produced this run in the
        # execution metadata, keyed by the trigger-info metadata key.
        trigger_info = {
            "provider_id": self.node_data.provider_id,
            "event_name": self.node_data.event_name,
            "plugin_unique_identifier": self.node_data.plugin_unique_identifier,
        }
        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
            cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): trigger_info,
        }

        variable_pool = self.graph_runtime_state.variable_pool
        node_inputs = dict(variable_pool.user_inputs)
        # TODO: System variables should be directly accessible, no need for
        # special handling. For now, flatten them into the inputs/outputs
        # under the system-variable node id prefix.
        system_inputs = variable_pool.system_variables.to_dict()
        node_inputs.update(
            {f"{SYSTEM_VARIABLE_NODE_ID}.{name}": value for name, value in system_inputs.items()}
        )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            outputs=dict(node_inputs),
            metadata=metadata,
        )

View File

@ -0,0 +1,3 @@
# Package entry point for the schedule trigger node.
from .trigger_schedule_node import TriggerScheduleNode

__all__ = ["TriggerScheduleNode"]

View File

@ -0,0 +1,52 @@
from typing import Literal, Union
from pydantic import BaseModel, Field
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class TriggerScheduleNodeData(BaseNodeData):
    """
    Trigger Schedule Node Data.

    Configuration for a time-based trigger: either a visual (frequency-based)
    schedule or a raw cron expression, evaluated in `timezone`.
    """

    type: NodeType = TRIGGER_SCHEDULE_NODE_TYPE
    mode: str = Field(default="visual", description="Schedule mode: visual or cron")
    frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
    cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
    # Shape of this dict presumably matches VisualConfig below — confirm.
    visual_config: dict | None = Field(default=None, description="Visual configuration details")
    timezone: str = Field(default="UTC", description="Timezone for schedule execution")
class ScheduleConfig(BaseModel):
    """A resolved schedule for one node: concrete cron expression plus timezone."""

    node_id: str
    cron_expression: str
    timezone: str = "UTC"
class SchedulePlanUpdate(BaseModel):
    """Update payload for an existing schedule plan.

    All fields are optional; presumably None means "leave unchanged" — confirm
    against the consumer of this model.
    """

    node_id: str | None = None
    cron_expression: str | None = None
    timezone: str | None = None
class VisualConfig(BaseModel):
    """Visual configuration for schedule trigger.

    Which fields are meaningful depends on the selected frequency (hourly,
    daily, weekly, monthly); the others are ignored.
    """

    # For hourly frequency
    on_minute: int | None = Field(default=0, ge=0, le=59, description="Minute of the hour (0-59)")
    # For daily, weekly, monthly frequencies
    time: str | None = Field(default="12:00 AM", description="Time in 12-hour format (e.g., '2:30 PM')")
    # For weekly frequency
    weekdays: list[Literal["sun", "mon", "tue", "wed", "thu", "fri", "sat"]] | None = Field(
        default=None, description="List of weekdays to run on"
    )
    # For monthly frequency
    monthly_days: list[Union[int, Literal["last"]]] | None = Field(
        default=None, description="Days of month to run on (1-31 or 'last')"
    )

View File

@ -0,0 +1,31 @@
from dify_graph.entities.exc import BaseNodeError
class ScheduleNodeError(BaseNodeError):
    """Base class for schedule-node errors."""


class ScheduleNotFoundError(ScheduleNodeError):
    """Raised when the referenced schedule does not exist."""


class ScheduleConfigError(ScheduleNodeError):
    """Raised when the schedule configuration is invalid."""


class ScheduleExecutionError(ScheduleNodeError):
    """Raised when executing a scheduled run fails."""


class TenantOwnerNotFoundError(ScheduleExecutionError):
    """Raised when no tenant owner can be resolved for a scheduled run."""

View File

@ -0,0 +1,46 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from .entities import TriggerScheduleNodeData
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
    """Root node that starts a workflow run on a time-based schedule."""

    node_type = TRIGGER_SCHEDULE_NODE_TYPE
    execution_type = NodeExecutionType.ROOT

    @classmethod
    def version(cls) -> str:
        return "1"

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Return the default schedule config: visual mode, daily at 12:00 AM UTC."""
        visual_defaults = {
            "time": "12:00 AM",
            "on_minute": 0,
            "weekdays": ["sun"],
            "monthly_days": [1],
        }
        return {
            "type": TRIGGER_SCHEDULE_NODE_TYPE,
            "config": {
                "mode": "visual",
                "frequency": "daily",
                "visual_config": visual_defaults,
                "timezone": "UTC",
            },
        }

    def _run(self) -> NodeRunResult:
        """Pass user inputs (plus flattened system variables) through as outputs."""
        pool = self.graph_runtime_state.variable_pool
        inputs = dict(pool.user_inputs)
        # TODO: System variables should be directly accessible, no need for
        # special handling; flatten them under the system-variable prefix for now.
        for name, value in pool.system_variables.to_dict().items():
            inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{name}"] = value
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            outputs=dict(inputs),
        )

View File

@ -0,0 +1,3 @@
# Package entry point for the webhook trigger node.
from .node import TriggerWebhookNode

__all__ = ["TriggerWebhookNode"]

View File

@ -0,0 +1,133 @@
from collections.abc import Sequence
from enum import StrEnum
from pydantic import BaseModel, Field, field_validator
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.types import SegmentType
# Allowed segment types for webhook header parameters: headers are text only.
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset(
    {
        SegmentType.STRING,
    }
)

# Allowed segment types for webhook query-string parameters.
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset(
    {
        SegmentType.STRING,
        SegmentType.NUMBER,
        SegmentType.BOOLEAN,
    }
)

# Union of header and query-parameter types; used by WebhookParameter, which
# models both kinds (field-specific validators on WebhookData narrow further).
_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES

# Allowed segment types for body parameters, including structured and file types.
_WEBHOOK_BODY_ALLOWED_TYPES = frozenset(
    {
        SegmentType.STRING,
        SegmentType.NUMBER,
        SegmentType.BOOLEAN,
        SegmentType.OBJECT,
        SegmentType.ARRAY_STRING,
        SegmentType.ARRAY_NUMBER,
        SegmentType.ARRAY_BOOLEAN,
        SegmentType.ARRAY_OBJECT,
        SegmentType.FILE,
    }
)
class Method(StrEnum):
    """HTTP methods a webhook trigger can accept (values are lowercase)."""

    GET = "get"
    POST = "post"
    HEAD = "head"
    PATCH = "patch"
    PUT = "put"
    DELETE = "delete"
class ContentType(StrEnum):
    """Supported webhook request content types (MIME type values)."""

    JSON = "application/json"
    FORM_DATA = "multipart/form-data"
    FORM_URLENCODED = "application/x-www-form-urlencoded"
    TEXT = "text/plain"
    BINARY = "application/octet-stream"
class WebhookParameter(BaseModel):
    """Parameter definition for headers or query params."""

    name: str
    type: SegmentType = SegmentType.STRING
    required: bool = False

    @field_validator("type", mode="after")
    @classmethod
    def validate_type(cls, v: SegmentType) -> SegmentType:
        # Restrict to the types valid for headers or query parameters.
        if v in _WEBHOOK_PARAMETER_ALLOWED_TYPES:
            return v
        raise ValueError(f"Unsupported webhook parameter type: {v}")
class WebhookBodyParameter(BaseModel):
    """Body parameter with type information."""

    name: str
    type: SegmentType = SegmentType.STRING
    required: bool = False

    @field_validator("type", mode="after")
    @classmethod
    def validate_type(cls, v: SegmentType) -> SegmentType:
        # Body parameters additionally allow structured and file types.
        if v in _WEBHOOK_BODY_ALLOWED_TYPES:
            return v
        raise ValueError(f"Unsupported webhook body parameter type: {v}")
class WebhookData(BaseNodeData):
    """
    Webhook Node Data.

    Declares the expected shape of an incoming webhook request (method,
    content type, headers, query params, body fields) plus response settings.
    """

    class SyncMode(StrEnum):
        # NOTE(review): member is named SYNC but its value is "async", and the
        # original comment says it is the only supported mode — confirm the
        # naming is intentional.
        SYNC = "async"  # only support

    type: NodeType = TRIGGER_WEBHOOK_NODE_TYPE
    method: Method = Method.GET
    content_type: ContentType = Field(default=ContentType.JSON)
    headers: Sequence[WebhookParameter] = Field(default_factory=list)
    params: Sequence[WebhookParameter] = Field(default_factory=list)  # query parameters
    body: Sequence[WebhookBodyParameter] = Field(default_factory=list)

    @field_validator("method", mode="before")
    @classmethod
    def normalize_method(cls, v) -> str:
        """Normalize HTTP method to lowercase to support both uppercase and lowercase input."""
        if isinstance(v, str):
            return v.lower()
        return v

    @field_validator("headers", mode="after")
    @classmethod
    def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
        # Headers are narrower than the shared WebhookParameter type: string only.
        for param in v:
            if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES:
                raise ValueError(f"Unsupported webhook header parameter type: {param.type}")
        return v

    @field_validator("params", mode="after")
    @classmethod
    def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
        # Query parameters may be strings, numbers, or booleans.
        for param in v:
            if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES:
                raise ValueError(f"Unsupported webhook query parameter type: {param.type}")
        return v

    # NOTE: these fields are declared after the validators above; pydantic
    # still registers them as ordinary model fields.
    status_code: int = 200  # Expected status code for response
    response_body: str = ""  # Template for response body
    # Webhook specific fields (not from client data, set internally)
    webhook_id: str | None = None  # Set when webhook trigger is created
    timeout: int = 30  # Timeout in seconds to wait for webhook response

View File

@ -0,0 +1,25 @@
from dify_graph.entities.exc import BaseNodeError
class WebhookNodeError(BaseNodeError):
    """Base class for webhook node errors."""


class WebhookTimeoutError(WebhookNodeError):
    """Raised when waiting on a webhook times out."""


class WebhookNotFoundError(WebhookNodeError):
    """Raised when the referenced webhook does not exist."""


class WebhookConfigError(WebhookNodeError):
    """Raised when the webhook configuration is invalid."""

View File

@ -0,0 +1,177 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType
from dify_graph.file import FileTransferMethod
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.variables.types import SegmentType
from dify_graph.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
logger = logging.getLogger(__name__)
class TriggerWebhookNode(Node[WebhookData]):
    """Root node that surfaces incoming webhook request data to the workflow.

    The node does not receive the HTTP request itself; the trigger controller
    injects the request payload into the variable pool before the run.
    """

    node_type = TRIGGER_WEBHOOK_NODE_TYPE
    execution_type = NodeExecutionType.ROOT

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Return the default node configuration: GET + JSON, no declared parameters."""
        return {
            "type": "webhook",
            "config": {
                "method": "get",
                "content_type": "application/json",
                "headers": [],
                "params": [],
                "body": [],
                # NOTE(review): "async_mode" has no matching field on
                # WebhookData — confirm it is consumed elsewhere.
                "async_mode": True,
                "status_code": 200,
                "response_body": "",
                "timeout": 30,
            },
        }

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """
        Run the webhook node.

        Like the start node, this simply takes the webhook data from the variable pool
        and makes it available to downstream nodes. The actual webhook handling
        happens in the trigger controller.
        """
        # Get webhook data from variable pool (injected by Celery task)
        webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
        # Extract webhook-specific outputs based on node configuration
        outputs = self._extract_configured_outputs(webhook_inputs)
        system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
        # TODO: System variables should be directly accessible, no need for special handling
        # Set system variables as node outputs.
        for var in system_inputs:
            outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=webhook_inputs,
            outputs=outputs,
        )

    def generate_file_var(self, param_name: str, file: dict):
        """Build a FileVariable from a raw file mapping received by the webhook.

        Returns None (after logging) when the mapping cannot be turned into a
        File object, so callers can fall back to the raw data.
        """
        dify_ctx = self.require_dify_context()
        related_id = file.get("related_id")
        transfer_method_value = file.get("transfer_method")
        if transfer_method_value:
            transfer_method = FileTransferMethod.value_of(transfer_method_value)
            # file_factory expects the related id under a method-specific key.
            match transfer_method:
                case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
                    file["upload_file_id"] = related_id
                case FileTransferMethod.TOOL_FILE:
                    file["tool_file_id"] = related_id
                case FileTransferMethod.DATASOURCE_FILE:
                    file["datasource_file_id"] = related_id
        try:
            file_obj = file_factory.build_from_mapping(
                mapping=file,
                tenant_id=dify_ctx.tenant_id,
            )
            file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
            return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
        except ValueError:
            # Best-effort: log with traceback and let the caller use raw data.
            logger.error(
                "Failed to build FileVariable for webhook file parameter %s",
                param_name,
                exc_info=True,
            )
            return None

    def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
        """Extract outputs based on node configuration from webhook inputs."""
        outputs = {}
        # Get the raw webhook data (should be injected by Celery task)
        webhook_data = webhook_inputs.get("webhook_data", {})

        def _to_sanitized(name: str) -> str:
            # Output keys use "_" instead of "-" so they are valid identifiers.
            return name.replace("-", "_")

        def _get_normalized(mapping: dict[str, Any], key: str) -> Any:
            # Look up a key, also trying its "-"/"_"-swapped spelling.
            if not isinstance(mapping, dict):
                return None
            if key in mapping:
                return mapping[key]
            alternate = key.replace("-", "_") if "-" in key else key.replace("_", "-")
            if alternate in mapping:
                return mapping[alternate]
            return None

        # Extract configured headers (case-insensitive)
        webhook_headers = webhook_data.get("headers", {})
        webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
        for header in self.node_data.headers:
            header_name = header.name
            value = _get_normalized(webhook_headers, header_name)
            if value is None:
                # Retry with lowercased keys for case-insensitive matching.
                value = _get_normalized(webhook_headers_lower, header_name.lower())
            sanitized_name = _to_sanitized(header_name)
            outputs[sanitized_name] = value
        # Extract configured query parameters
        for param in self.node_data.params:
            param_name = param.name
            outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
        # Extract configured body parameters
        for body_param in self.node_data.body:
            param_name = body_param.name
            param_type = body_param.type
            if self.node_data.content_type == ContentType.TEXT:
                # For text/plain, the entire body is a single string parameter
                outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
                continue
            elif self.node_data.content_type == ContentType.BINARY:
                # NOTE(review): annotated as dict, but "raw" for binary bodies
                # could plausibly be bytes/str — confirm the controller's shape.
                raw_data: dict = webhook_data.get("body", {}).get("raw", {})
                file_var = self.generate_file_var(param_name, raw_data)
                if file_var:
                    outputs[param_name] = file_var
                else:
                    outputs[param_name] = raw_data
                continue
            if param_type == SegmentType.FILE:
                # Get File object (already processed by webhook controller)
                files = webhook_data.get("files", {})
                if files and isinstance(files, dict):
                    file = files.get(param_name)
                    if file and isinstance(file, dict):
                        file_var = self.generate_file_var(param_name, file)
                        if file_var:
                            outputs[param_name] = file_var
                        else:
                            # Fall back to the whole files mapping when the
                            # FileVariable could not be built.
                            outputs[param_name] = files
                    else:
                        outputs[param_name] = files
                else:
                    outputs[param_name] = files
            else:
                # Get regular body parameter
                outputs[param_name] = webhook_data.get("body", {}).get(param_name)
        # Include raw webhook data for debugging/advanced use
        outputs["_webhook_raw"] = webhook_data
        return outputs

View File

@ -9,10 +9,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.sandbox import Sandbox
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.file.models import File
from dify_graph.graph import Graph
@ -22,9 +22,8 @@ from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLay
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_engine.protocols.command_channel import CommandChannel
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from dify_graph.nodes import NodeType
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
@ -214,7 +213,7 @@ class WorkflowEntry:
node_config_data = node_config["data"]
# Get node type
node_type = NodeType(node_config_data["type"])
node_type = node_config_data.type
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@ -239,8 +238,7 @@ class WorkflowEntry:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
typed_node_config = cast(dict[str, object], node_config)
node = cast(Any, node_factory).create_node(typed_node_config)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
@ -259,7 +257,7 @@ class WorkflowEntry:
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
if node_type != NodeType.DATASOURCE:
if node_type != BuiltinNodeTypes.DATASOURCE:
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@ -309,7 +307,7 @@ class WorkflowEntry:
"height": node_height,
"type": "custom",
"data": {
"type": NodeType.START,
"type": BuiltinNodeTypes.START,
"title": "Start",
"desc": "Start",
},
@ -345,11 +343,11 @@ class WorkflowEntry:
# Create a minimal graph for single node execution
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = NodeType(node_data.get("type", ""))
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
node_type = node_data.get("type", "")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1")
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
@ -376,10 +374,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config: NodeConfigDict = {
"id": node_id,
"data": cast(NodeConfigData, node_data),
}
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,