Merge branch 'feat/model-plugins-implementing' into deploy/dev

This commit is contained in:
yyh
2026-03-12 10:57:39 +08:00
101 changed files with 1831 additions and 1049 deletions

199
.github/dependabot.yml vendored
View File

@ -3,19 +3,210 @@ version: 2
updates:
- package-ecosystem: "pip"
directory: "/api"
open-pull-requests-limit: 2
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
python-dependencies:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
open-pull-requests-limit: 2
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
uv-dependencies:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "github-actions"
directory: "/"
open-pull-requests-limit: 5
schedule:
interval: "weekly"
groups:
github-actions-dependencies:
patterns:
- "*"

View File

@ -27,7 +27,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -39,7 +39,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: "3.12"
@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true

View File

@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: false
python-version: "3.12"

View File

@ -31,7 +31,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -32,6 +32,7 @@ from core.app.entities.queue_entities import (
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
@ -62,7 +63,6 @@ from dify_graph.graph_events import (
NodeRunSucceededEvent,
)
from dify_graph.graph_events.graph import GraphRunAbortedEvent
from dify_graph.nodes import NodeType
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@ -303,9 +303,11 @@ class WorkflowBasedAppRunner:
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# Use the variable pool from graph_runtime_state instead of creating a new one

View File

@ -19,6 +19,7 @@ from core.trigger.debug.events import (
build_plugin_pool_key,
build_webhook_pool_key,
)
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
app_id: str
user_id: str
tenant_id: str
node_config: Mapping[str, Any]
node_config: NodeConfigDict
node_id: str
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
self.tenant_id = tenant_id
self.user_id = user_id
self.app_id = app_id
@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
def poll(self) -> TriggerDebugEvent | None:
from services.trigger.trigger_service import TriggerService
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
pool_key: str = build_plugin_pool_key(
name=plugin_trigger_data.event_name,

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -22,7 +22,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
@ -31,26 +32,19 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
from dify_graph.nodes.code.entities import CodeLanguage
from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.nodes.datasource import DatasourceNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.nodes.tool.tool_node import ToolNode
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
@ -60,6 +54,9 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
def fetch_memory(
*,
conversation_id: str | None,
@ -157,178 +154,128 @@ class DifyNodeFactory(NodeFactory):
return DifyRunContext.model_validate(raw_ctx)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
:raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
# Get node_id from config
node_id = node_config["id"]
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
NodeType.CODE: lambda: {
"code_executor": self._code_executor,
"code_limits": self._code_limits,
},
NodeType.TEMPLATE_TRANSFORM: lambda: {
"template_renderer": self._template_renderer,
"max_output_length": self._template_transform_max_output_length,
},
NodeType.HTTP_REQUEST: lambda: {
"http_request_config": self._http_request_config,
"http_client": self._http_request_http_client,
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
"file_manager": self._http_request_file_manager,
},
NodeType.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
},
NodeType.KNOWLEDGE_INDEX: lambda: {
"index_processor": IndexProcessor(),
"summary_index_service": SummaryIndex(),
},
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
NodeType.DATASOURCE: lambda: {
"datasource_manager": DatasourceManager,
},
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
"rag_retrieval": self._rag_retrieval,
},
NodeType.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
"http_client": self._http_request_http_client,
},
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=False,
),
NodeType.TOOL: lambda: {
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
},
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
id=node_id,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
)
# Get node type from config
node_data = node_config["data"]
try:
node_type = NodeType(node_data["type"])
except ValueError:
raise ValueError(f"Unknown node type: {node_data['type']}")
@staticmethod
def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
return node_class.validate_node_data(node_data)
# Get node class
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
node_version = str(node_data.get("version", "1"))
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
return node_class
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
max_output_length=self._template_transform_max_output_length,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_request_config=self._http_request_config,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.HUMAN_INPUT:
return HumanInputNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
)
if node_type == NodeType.KNOWLEDGE_INDEX:
return KnowledgeIndexNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
index_processor=IndexProcessor(),
summary_index_service=SummaryIndex(),
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.DATASOURCE:
return DatasourceNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
datasource_manager=DatasourceManager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
)
if node_type == NodeType.DOCUMENT_EXTRACTOR:
return DocumentExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
unstructured_api_config=self._document_extractor_unstructured_api_config,
http_client=self._http_request_http_client,
)
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.TOOL:
return ToolNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
)
return node_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
def _build_llm_compatible_node_init_kwargs(
self,
*,
node_class: type[Node],
node_data: BaseNodeData,
include_http_client: bool,
) -> dict[str, object]:
validated_node_data = cast(
LLMCompatibleNodeData,
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
)
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,
"model_factory": self._llm_model_factory,
"model_instance": model_instance,
"memory": self._build_memory_for_llm_node(
node_data=validated_node_data,
model_instance=model_instance,
),
}
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
return node_init_kwargs
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
node_data_model = ModelConfig.model_validate(node_data["model"])
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
node_data_model = node_data.model
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
@ -364,14 +311,12 @@ class DifyNodeFactory(NodeFactory):
def _build_memory_for_llm_node(
self,
*,
node_data: Mapping[str, Any],
node_data: LLMCompatibleNodeData,
model_instance: ModelInstance,
) -> PromptMessageMemory | None:
raw_memory_config = node_data.get("memory")
if raw_memory_config is None:
if node_data.memory is None:
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
@ -381,6 +326,6 @@ class DifyNodeFactory(NodeFactory):
return fetch_memory(
conversation_id=conversation_id,
app_id=self._dify_context.app_id,
node_data_memory=node_memory,
node_data_memory=node_data.memory,
model_instance=model_instance,
)

View File

@ -11,7 +11,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.file.models import File
from dify_graph.graph import Graph
@ -212,7 +212,7 @@ class WorkflowEntry:
node_config_data = node_config["data"]
# Get node type
node_type = NodeType(node_config_data["type"])
node_type = node_config_data.type
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@ -234,8 +234,7 @@ class WorkflowEntry:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
typed_node_config = cast(dict[str, object], node_config)
node = cast(Any, node_factory).create_node(typed_node_config)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
@ -371,10 +370,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config: NodeConfigDict = {
"id": node_id,
"data": cast(NodeConfigData, node_data),
}
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@ -0,0 +1,176 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
from dify_graph.entities.exc import DefaultValueTypeError
from dify_graph.enums import ErrorStrategy, NodeType
# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
# and persisted templates/workflows also carry undeclared compatibility keys such as
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
# here until graph parsing becomes discriminated by node type or those legacy payloads
# are normalized.
model_config = ConfigDict(extra="allow")
type: NodeType
title: str = ""
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = Field(default_factory=RetryConfig)
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
def __getitem__(self, key: str) -> Any:
"""
Dict-style access without calling model_dump() on every lookup.
Prefer using model fields and Pydantic's extra storage.
"""
# First, check declared model fields
if key in self.__class__.model_fields:
return getattr(self, key)
# Then, check undeclared compatibility fields stored in Pydantic's extra dict.
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras[key]
raise KeyError(key)
def get(self, key: str, default: Any = None) -> Any:
"""
Dict-style .get() without calling model_dump() on every lookup.
"""
if key in self.__class__.model_fields:
return getattr(self, key)
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras.get(key, default)
return default

View File

@ -4,21 +4,20 @@ import sys
from pydantic import TypeAdapter, with_config
from dify_graph.entities.base_node_data import BaseNodeData
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigData(TypedDict):
type: str
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
data: NodeConfigData
# This is the permissive raw graph boundary. Node factories re-validate `data`
# with the concrete `NodeData` subtype after resolving the node implementation.
data: BaseNodeData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)

View File

@ -8,7 +8,7 @@ from typing import Protocol, cast, final
from pydantic import TypeAdapter
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
from dify_graph.nodes.base.node import Node
from libs.typing import is_str
@ -34,7 +34,8 @@ class NodeFactory(Protocol):
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
"""
...
@ -115,10 +116,7 @@ class Graph:
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid]["data"]
node_type = node_data["type"]
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
if node_data.type.is_start_node:
start_node_id = nid
break
@ -203,6 +201,23 @@ class Graph:
return GraphBuilder(graph_cls=cls)
@staticmethod
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
"""
Remove editor-only nodes before `NodeConfigDict` validation.
Persisted note widgets use a top-level `type == "custom-note"` but leave
`data.type` empty because they are never executable graph nodes. Filter
them while configs are still raw dicts so Pydantic does not validate
their placeholder payloads against `BaseNodeData.type: NodeType`.
"""
filtered_node_configs: list[dict[str, object]] = []
for node_config in node_configs:
if node_config.get("type", "") == "custom-note":
continue
filtered_node_configs.append(dict(node_config))
return filtered_node_configs
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
@ -302,13 +317,13 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
node_configs = cls._filter_canvas_only_nodes(node_configs)
node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)

View File

@ -374,12 +374,11 @@ class AgentNode(Node[AgentNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AgentNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:

View File

@ -5,10 +5,12 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class AgentNodeData(BaseNodeData):
type: NodeType = NodeType.AGENT
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str # redundancy

View File

@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AnswerNodeData.model_validate(node_data)
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
_ = graph_config # Explicitly mark as unused
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}

View File

@ -3,7 +3,8 @@ from enum import StrEnum, auto
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class AnswerNodeData(BaseNodeData):
@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
Answer Node Data.
"""
type: NodeType = NodeType.ANSWER
answer: str = Field(..., description="answer template string")

View File

@ -1,4 +1,4 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [
@ -6,6 +6,5 @@ __all__ = [
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
]

View File

@ -1,31 +1,12 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
from typing import Any
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, field_validator
from dify_graph.enums import ErrorStrategy
from .exc import DefaultValueTypeError
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
from dify_graph.entities.base_node_data import BaseNodeData
class VariableSelector(BaseModel):
@ -76,112 +57,6 @@ class OutputVariableEntity(BaseModel):
return v
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
class BaseIterationNodeData(BaseNodeData):
start_node_id: str | None = None

View File

@ -12,6 +12,8 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, ge
from uuid import uuid4
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
@ -62,8 +64,6 @@ from dify_graph.node_events import (
from dify_graph.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]):
Later, in __init__:
::
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
CodeNodeData instance
(stored in self._node_data)
config["data"] ──► _node_data_type.model_validate(..., from_attributes=True)
CodeNodeData instance
(stored in self._node_data)
Example:
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
@ -254,22 +254,21 @@ class Node(Generic[NodeDataT]):
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
node_id = config["id"]
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
self._node_data = self.validate_node_data(config["data"])
self.post_init()
@classmethod
def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT:
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
def post_init(self) -> None:
"""Optional hook for subclasses requiring extra initialization."""
return
@ -342,9 +341,6 @@ class Node(Generic[NodeDataT]):
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
@ -389,8 +385,6 @@ class Node(Generic[NodeDataT]):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from typing import cast
from dify_graph.nodes.agent.agent_node import AgentNode
from dify_graph.nodes.agent.entities import AgentNodeData
@ -442,7 +436,7 @@ class Node(Generic[NodeDataT]):
cls,
*,
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
config: NodeConfigDict,
) -> Mapping[str, Sequence[str]]:
"""Extracts references variable selectors from node configuration.
@ -480,13 +474,12 @@ class Node(Generic[NodeDataT]):
:param config: node config
:return:
"""
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
# Pass raw dict data instead of creating NodeData instance
node_id = config["id"]
node_data = cls.validate_node_data(config["data"])
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
graph_config=graph_config,
node_id=node_id,
node_data=node_data,
)
return data
@ -496,7 +489,7 @@ class Node(Generic[NodeDataT]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: NodeDataT,
) -> Mapping[str, Sequence[str]]:
return {}

View File

@ -3,6 +3,7 @@ from decimal import Decimal
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: CodeNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}
@property

View File

@ -3,7 +3,8 @@ from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.variables.types import SegmentType
@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = NodeType.CODE
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, "CodeNodeData.Output"] | None = None

View File

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
datasource_manager: DatasourceManagerProtocol,
@ -181,7 +182,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: DatasourceNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -190,11 +191,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
:param node_data: node data
:return:
"""
typed_node_data = DatasourceNodeData.model_validate(node_data)
result = {}
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name]
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
input = node_data.datasource_parameters[parameter_name]
match input.type:
case "mixed":
assert isinstance(input.value, str)

View File

@ -3,7 +3,8 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class DatasourceEntity(BaseModel):
@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel):
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
type: NodeType = NodeType.DATASOURCE
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]

View File

@ -1,10 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class DocumentExtractorNodeData(BaseNodeData):
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
variable_selector: Sequence[str]

View File

@ -21,6 +21,7 @@ from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod, file_manager
from dify_graph.node_events import NodeRunResult
@ -54,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -136,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
return {node_id + ".files": typed_node_data.variable_selector}
_ = graph_config # Explicitly mark as unused
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(

View File

@ -1,6 +1,8 @@
from pydantic import BaseModel, Field
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import OutputVariableEntity
class EndNodeData(BaseNodeData):
@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
type: NodeType = NodeType.END
outputs: list[OutputVariableEntity]

View File

@ -8,7 +8,8 @@ import charset_normalizer
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = NodeType.HTTP_REQUEST
method: Literal[
"get",
"post",

View File

@ -3,6 +3,7 @@ import mimetypes
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod
from dify_graph.node_events import NodeRunResult
@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData.model_validate(node_data)
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if typed_node_data.body:
body_type = typed_node_data.body.type
data = typed_node_data.body.data
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
match body_type:
case "none":
pass

View File

@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.runtime import VariablePool
from dify_graph.variables.consts import SELECTORS_LENGTH
@ -214,6 +215,7 @@ class UserAction(BaseModel):
class HumanInputNodeData(BaseNodeData):
"""Human Input node data."""
type: NodeType = NodeType.HUMAN_INPUT
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
form_content: str = ""
inputs: list[FormInput] = Field(default_factory=list)

View File

@ -3,6 +3,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository,
@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HumanInputNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
validated_node_data = HumanInputNodeData.model_validate(node_data)
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
return node_data.extract_variable_selector_to_variable_mapping(node_id)

View File

@ -2,7 +2,8 @@ from typing import Literal
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.utils.condition.entities import Condition
@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
If Else Node Data.
"""
type: NodeType = NodeType.IF_ELSE
class Case(BaseModel):
"""
Case entity representing a single logical condition group

View File

@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData.model_validate(node_data)
var_mapping: dict[str, list[str]] = {}
for case in typed_node_data.cases or []:
_ = graph_config # Explicitly mark as unused
for case in node_data.cases or []:
for condition in case.conditions:
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
var_mapping[key] = condition.variable_selector

View File

@ -3,7 +3,9 @@ from typing import Any
from pydantic import Field
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState
class ErrorHandleMode(StrEnum):
@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
type: NodeType = NodeType.ITERATION
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData):
Iteration Start Node Data.
"""
pass
type: NodeType = NodeType.ITERATION_START
class IterationState(BaseIterationState):

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
NodeExecutionType,
NodeType,
@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IterationNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
f"{node_id}.input_selector": node_data.iterator_selector,
}
iteration_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
node_config_data = node.get("data", {})
if node_config_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)
@ -490,14 +488,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Get node class
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
node_type = typed_sub_node_config["data"].type
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_version = str(typed_sub_node_config["data"].version)
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=typed_sub_node_config
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:

View File

@ -3,7 +3,8 @@ from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class RerankingModelConfig(BaseModel):
@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
Knowledge index Node Data.
"""
type: str = "knowledge-index"
type: NodeType = NodeType.KNOWLEDGE_INDEX
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None

View File

@ -2,6 +2,7 @@ import logging
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
@ -30,7 +31,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
index_processor: IndexProcessorProtocol,

View File

@ -3,7 +3,8 @@ from typing import Literal
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
Knowledge retrieval Node Data.
"""
type: str = "knowledge-retrieval"
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
query_variable_selector: list[str] | None | str = None
query_attachment_selector: list[str] | None | str = None
dataset_ids: list[str]

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
@ -49,7 +50,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
@ -301,15 +302,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: KnowledgeRetrievalNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
if typed_node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
if node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
if node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
return variable_mapping

View File

@ -3,7 +3,8 @@ from enum import StrEnum
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class FilterOperator(StrEnum):
@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
type: NodeType = NodeType.LIST_OPERATOR
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderByConfig

View File

@ -4,8 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.base.entities import VariableSelector
@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
class LLMNodeData(BaseNodeData):
type: NodeType = NodeType.LLM
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)

View File

@ -21,6 +21,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
SystemVariableKey,
@ -121,7 +122,7 @@ class LLMNode(Node[LLMNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
@ -954,14 +955,11 @@ class LLMNode(Node[LLMNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LLMNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
@ -979,7 +977,7 @@ class LLMNode(Node[LLMNodeData]):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = typed_node_data.memory
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@ -987,16 +985,16 @@ class LLMNode(Node[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if typed_node_data.context.enabled:
variable_mapping["#context#"] = typed_node_data.context.variable_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if typed_node_data.vision.enabled:
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if typed_node_data.memory:
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
if typed_node_data.prompt_config:
if node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
@ -1009,7 +1007,7 @@ class LLMNode(Node[LLMNodeData]):
break
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

View File

@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
@ -39,6 +41,7 @@ class LoopVariableData(BaseModel):
class LoopNodeData(BaseLoopNodeData):
type: NodeType = NodeType.LOOP
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData):
Loop Start Node Data.
"""
pass
type: NodeType = NodeType.LOOP_START
class LoopEndNodeData(BaseNodeData):
@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData):
Loop End Node Data.
"""
pass
type: NodeType = NodeType.LOOP_END
class LoopState(BaseLoopState):

View File

@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
NodeExecutionType,
NodeType,
@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = LoopNodeData.model_validate(node_data)
variable_mapping = {}
# Extract loop node IDs statically from graph_config
@ -320,14 +318,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
# Get node class
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
node_type = typed_sub_node_config["data"].type
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_version = str(typed_sub_node_config["data"].version)
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=typed_sub_node_config
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
@ -342,7 +341,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in typed_node_data.loop_variables or []:
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping

View File

@ -8,7 +8,8 @@ from pydantic import (
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
from dify_graph.variables.types import SegmentType
@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
Parameter Extractor Node Data.
"""
type: NodeType = NodeType.PARAMETER_EXTRACTOR
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]

View File

@ -10,6 +10,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ParameterExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
if node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector

View File

@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.llm import ModelConfig, VisionConfig
@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
class QuestionClassifierNodeData(BaseNodeData):
type: NodeType = NodeType.QUESTION_CLASSIFIER
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]

View File

@ -7,6 +7,7 @@ from core.model_manager import ModelInstance
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeExecutionType,
NodeType,
@ -62,7 +63,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -251,16 +252,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: QuestionClassifierNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors: list[VariableSelector] = []
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)

View File

@ -2,7 +2,8 @@ from collections.abc import Sequence
from pydantic import Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.input_entities import VariableEntity
@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
Start Node Data
"""
type: NodeType = NodeType.START
variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@ -1,4 +1,5 @@
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import VariableSelector
@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
Template Transform Node Data.
"""
type: NodeType = NodeType.TEMPLATE_TRANSFORM
variables: list[VariableSelector]
template: str

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@ -25,7 +26,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -86,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}

View File

@ -4,7 +4,8 @@ from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class ToolEntity(BaseModel):
@ -32,6 +33,8 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = NodeType.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]

View File

@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
SystemVariableKey,
@ -46,7 +47,7 @@ class ToolNode(Node[ToolNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@ -484,7 +485,7 @@ class ToolNode(Node[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -493,9 +494,8 @@ class ToolNode(Node[ToolNodeData]):
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
typed_node_data = node_data
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]

View File

@ -4,13 +4,16 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.entities.entities import EventParameter
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
"""Plugin trigger node data"""
type: NodeType = NodeType.TRIGGER_PLUGIN
class TriggerEventInput(BaseModel):
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@ -38,8 +41,6 @@ class TriggerEventNodeData(BaseNodeData):
raise ValueError("value must be a string, int, float, bool or dict")
return type
title: str
desc: str | None = None
plugin_id: str = Field(..., description="Plugin ID")
provider_id: str = Field(..., description="Provider ID")
event_name: str = Field(..., description="Event name")

View File

@ -2,7 +2,8 @@ from typing import Literal, Union
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class TriggerScheduleNodeData(BaseNodeData):
@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
Trigger Schedule Node Data
"""
type: NodeType = NodeType.TRIGGER_SCHEDULE
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")

View File

@ -1,4 +1,4 @@
from dify_graph.nodes.base.exc import BaseNodeError
from dify_graph.entities.exc import BaseNodeError
class ScheduleNodeError(BaseNodeError):

View File

@ -1,10 +1,41 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.types import SegmentType
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
}
)
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
}
)
_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES
_WEBHOOK_BODY_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_BOOLEAN,
SegmentType.ARRAY_OBJECT,
SegmentType.FILE,
}
)
class Method(StrEnum):
@ -25,29 +56,34 @@ class ContentType(StrEnum):
class WebhookParameter(BaseModel):
"""Parameter definition for headers, query params, or body."""
"""Parameter definition for headers or query params."""
name: str
type: SegmentType = SegmentType.STRING
required: bool = False
@field_validator("type", mode="after")
@classmethod
def validate_type(cls, v: SegmentType) -> SegmentType:
if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook parameter type: {v}")
return v
class WebhookBodyParameter(BaseModel):
"""Body parameter with type information."""
name: str
type: Literal[
"string",
"number",
"boolean",
"object",
"array[string]",
"array[number]",
"array[boolean]",
"array[object]",
"file",
] = "string"
type: SegmentType = SegmentType.STRING
required: bool = False
@field_validator("type", mode="after")
@classmethod
def validate_type(cls, v: SegmentType) -> SegmentType:
if v not in _WEBHOOK_BODY_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook body parameter type: {v}")
return v
class WebhookData(BaseNodeData):
"""
@ -57,6 +93,7 @@ class WebhookData(BaseNodeData):
class SyncMode(StrEnum):
SYNC = "async" # only support
type: NodeType = NodeType.TRIGGER_WEBHOOK
method: Method = Method.GET
content_type: ContentType = Field(default=ContentType.JSON)
headers: Sequence[WebhookParameter] = Field(default_factory=list)
@ -71,6 +108,22 @@ class WebhookData(BaseNodeData):
return v.lower()
return v
@field_validator("headers", mode="after")
@classmethod
def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
for param in v:
if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook header parameter type: {param.type}")
return v
@field_validator("params", mode="after")
@classmethod
def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
for param in v:
if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook query parameter type: {param.type}")
return v
status_code: int = 200 # Expected status code for response
response_body: str = "" # Template for response body

View File

@ -1,4 +1,4 @@
from dify_graph.nodes.base.exc import BaseNodeError
from dify_graph.entities.exc import BaseNodeError
class WebhookNodeError(BaseNodeError):

View File

@ -152,7 +152,7 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = raw_data
continue
if param_type == "file":
if param_type == SegmentType.FILE:
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):

View File

@ -1,6 +1,7 @@
from pydantic import BaseModel
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.types import SegmentType
@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData):
Variable Aggregator Node Data.
"""
type: NodeType = NodeType.VARIABLE_AGGREGATOR
output_type: str
variables: list[list[str]]
advanced_settings: AdvancedSettings | None = None

View File

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: VariableAssignerData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerData.model_validate(node_data)
mapping = {}
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
assigned_variable_node_id = node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(typed_node_data.assigned_variable_selector)
selector_key = ".".join(node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.assigned_variable_selector
mapping[key] = node_data.assigned_variable_selector
selector_key = ".".join(typed_node_data.input_variable_selector)
selector_key = ".".join(node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.input_variable_selector
mapping[key] = node_data.input_variable_selector
return mapping
def _run(self) -> NodeRunResult:

View File

@ -1,7 +1,8 @@
from collections.abc import Sequence
from enum import StrEnum
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class WriteMode(StrEnum):
@ -11,6 +12,7 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData):
type: NodeType = NodeType.VARIABLE_ASSIGNER
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

View File

@ -3,7 +3,8 @@ from typing import Any
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from .enums import InputType, Operation
@ -22,5 +23,6 @@ class VariableOperationItem(BaseModel):
class VariableAssignerNodeData(BaseNodeData):
type: NodeType = NodeType.VARIABLE_ASSIGNER
version: str = "2"
items: Sequence[VariableOperationItem] = Field(default_factory=list)

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@ -56,7 +57,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
@ -94,13 +95,10 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: VariableAssignerNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData.model_validate(node_data)
var_mapping: dict[str, Sequence[str]] = {}
for item in typed_node_data.items:
for item in node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping

View File

@ -7,7 +7,7 @@ from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in
from opentelemetry import trace
from opentelemetry.propagate import set_global_textmap
from opentelemetry.propagators.b3 import B3Format
from opentelemetry.propagators.b3 import B3MultiFormat
from opentelemetry.propagators.composite import CompositePropagator
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
@ -24,7 +24,7 @@ def setup_context_propagation() -> None:
CompositePropagator(
[
TraceContextTextMapPropagator(),
B3Format(),
B3MultiFormat(),
]
)
)

View File

@ -233,8 +233,11 @@ class Workflow(Base): # bug
def get_node_config_by_id(self, node_id: str) -> NodeConfigDict:
"""Extract a node configuration from the workflow graph by node ID.
A node configuration is a dictionary containing the node's properties, including
the node's id, title, and its data as a dict.
A node configuration includes the node id and a typed `BaseNodeData` for `data`.
`BaseNodeData` keeps a dict-like `get`/`__getitem__` compatibility layer backed by
model fields plus Pydantic extra storage for legacy consumers, but callers should
prefer attribute access.
"""
workflow_graph = self.graph_dict
@ -252,12 +255,9 @@ class Workflow(Base): # bug
return NodeConfigDictAdapter.validate_python(node_config)
@staticmethod
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
"""Extract type of a node from the node configuration returned by `get_node_config_by_id`."""
node_config_data = node_config.get("data", {})
# Get node class
node_type = NodeType(node_config_data.get("type"))
return node_type
return node_config["data"].type
@staticmethod
def get_enclosing_node_type_and_id(

View File

@ -5,42 +5,42 @@ requires-python = ">=3.11,<3.13"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
"arize-phoenix-otel~=0.9.2",
"azure-identity==1.16.1",
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.2",
"beautifulsoup4==4.12.2",
"boto3==1.35.99",
"boto3==1.42.65",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.5.2",
"charset-normalizer>=3.4.4",
"flask~=3.1.2",
"flask-compress>=1.17,<1.18",
"flask-compress>=1.17,<1.24",
"flask-cors~=6.0.0",
"flask-login~=0.6.3",
"flask-migrate~=4.0.7",
"flask-migrate~=4.1.0",
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gmpy2~=2.3.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.189.0",
"google-api-python-client==2.192.0",
"google-auth>=2.47.0",
"google-auth-httplib2==0.2.0",
"google-auth-httplib2==0.3.0",
"google-cloud-aiplatform>=1.123.0",
"googleapis-common-protos>=1.65.0",
"gunicorn~=23.0.0",
"gunicorn~=25.1.0",
"httpx[socks]~=0.28.0",
"jieba==0.42.1",
"json-repair>=0.55.1",
"jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
"langsmith~=0.7.16",
"markdown~=3.8.1",
"mlflow-skinny>=3.0.0",
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.8.72",
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
"opik~=1.10.37",
"litellm==1.82.1", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.28.0",
"opentelemetry-distro==0.49b0",
"opentelemetry-exporter-otlp==1.28.0",
@ -53,7 +53,7 @@ dependencies = [
"opentelemetry-instrumentation-httpx==0.49b0",
"opentelemetry-instrumentation-redis==0.49b0",
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
"opentelemetry-propagator-b3==1.28.0",
"opentelemetry-propagator-b3==1.40.0",
"opentelemetry-proto==1.28.0",
"opentelemetry-sdk==1.28.0",
"opentelemetry-semantic-conventions==0.49b0",
@ -63,21 +63,21 @@ dependencies = [
"psycopg2-binary~=2.9.6",
"pycryptodome==3.23.0",
"pydantic~=2.12.5",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.12.0",
"pydantic-extra-types~=2.11.0",
"pydantic-settings~=2.13.1",
"pyjwt~=2.11.0",
"pypdfium2==5.2.0",
"python-docx~=1.2.0",
"python-dotenv==1.0.1",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.2.0",
"redis[hiredis]~=7.3.0",
"resend~=2.9.0",
"sentry-sdk[flask]~=2.28.0",
"sqlalchemy~=2.0.29",
"starlette==0.49.1",
"tiktoken~=0.9.0",
"transformers~=4.56.1",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
@ -109,46 +109,46 @@ package = false
# Required for development and running tests
############################################################
dev = [
"coverage~=7.2.4",
"dotenv-linter~=0.5.0",
"faker~=38.2.0",
"coverage~=7.13.4",
"dotenv-linter~=0.7.0",
"faker~=40.8.0",
"lxml-stubs~=0.5.1",
"basedpyright~=1.38.2",
"ruff~=0.14.0",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
"pytest-cov~=4.1.0",
"ruff~=0.15.5",
"pytest~=9.0.2",
"pytest-benchmark~=5.2.3",
"pytest-cov~=7.0.0",
"pytest-env~=1.1.3",
"pytest-mock~=3.14.0",
"pytest-mock~=3.15.1",
"testcontainers~=4.13.2",
"types-aiofiles~=25.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=5.5.0",
"types-cachetools~=6.2.0",
"types-colorama~=0.4.15",
"types-defusedxml~=0.7.0",
"types-deprecated~=1.2.15",
"types-docutils~=0.21.0",
"types-jsonschema~=4.23.0",
"types-flask-cors~=5.0.0",
"types-deprecated~=1.3.1",
"types-docutils~=0.22.3",
"types-jsonschema~=4.26.0",
"types-flask-cors~=6.0.0",
"types-flask-migrate~=4.1.0",
"types-gevent~=25.9.0",
"types-greenlet~=3.3.0",
"types-html5lib~=1.1.11",
"types-markdown~=3.10.2",
"types-oauthlib~=3.2.0",
"types-oauthlib~=3.3.0",
"types-objgraph~=3.6.0",
"types-olefile~=0.47.0",
"types-openpyxl~=3.1.5",
"types-pexpect~=4.9.0",
"types-protobuf~=5.29.1",
"types-protobuf~=6.32.1",
"types-psutil~=7.2.2",
"types-psycopg2~=2.9.21",
"types-pygments~=2.19.0",
"types-pymysql~=1.1.0",
"types-python-dateutil~=2.9.0",
"types-pywin32~=310.0.0",
"types-pywin32~=311.0.0",
"types-pyyaml~=6.0.12",
"types-regex~=2024.11.6",
"types-regex~=2026.2.28",
"types-shapely~=2.1.0",
"types-simplejson>=3.20.0",
"types-six>=1.17.0",
@ -161,7 +161,7 @@ dev = [
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",
"pandas-stubs~=2.2.3",
"pandas-stubs~=3.0.0",
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
"import-linter>=2.3",
@ -180,13 +180,13 @@ dev = [
# Required for storage clients
############################################################
storage = [
"azure-storage-blob==12.26.0",
"azure-storage-blob==12.28.0",
"bce-python-sdk~=0.9.23",
"cos-python-sdk-v5==1.9.38",
"esdk-obs-python==3.25.8",
"cos-python-sdk-v5==1.9.41",
"esdk-obs-python==3.26.2",
"google-cloud-storage>=3.0.0",
"opendal~=0.46.0",
"oss2==2.18.5",
"oss2==2.19.1",
"supabase~=2.18.1",
"tos~=2.9.0",
]

View File

@ -1,14 +1,18 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.nodes import NodeType
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from dify_graph.nodes.trigger_schedule.entities import (
ScheduleConfig,
SchedulePlanUpdate,
TriggerScheduleNodeData,
VisualConfig,
)
from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
@ -176,26 +180,26 @@ class ScheduleService:
return next_run_at
@staticmethod
def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig:
def to_schedule_config(node_config: NodeConfigDict) -> ScheduleConfig:
"""
Converts user-friendly visual schedule settings to cron expression.
Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
"""
node_data = node_config.get("data", {})
mode = node_data.get("mode", "visual")
timezone = node_data.get("timezone", "UTC")
node_id = node_config.get("id", "start")
node_data = TriggerScheduleNodeData.model_validate(node_config["data"], from_attributes=True)
mode = node_data.mode
timezone = node_data.timezone
node_id = node_config["id"]
cron_expression = None
if mode == "cron":
cron_expression = node_data.get("cron_expression")
cron_expression = node_data.cron_expression
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for cron mode")
elif mode == "visual":
frequency = str(node_data.get("frequency"))
frequency = str(node_data.frequency or "")
if not frequency:
raise ScheduleConfigError("Frequency is required for visual mode")
visual_config = VisualConfig(**node_data.get("visual_config", {}))
visual_config = VisualConfig.model_validate(node_data.visual_config or {})
cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config)
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for visual mode")
@ -239,19 +243,21 @@ class ScheduleService:
if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
continue
mode = node_data.get("mode", "visual")
timezone = node_data.get("timezone", "UTC")
node_id = node.get("id", "start")
trigger_data = TriggerScheduleNodeData.model_validate(node_data)
mode = trigger_data.mode
timezone = trigger_data.timezone
cron_expression = None
if mode == "cron":
cron_expression = node_data.get("cron_expression")
cron_expression = trigger_data.cron_expression
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for cron mode")
elif mode == "visual":
frequency = node_data.get("frequency")
visual_config_dict = node_data.get("visual_config", {})
visual_config = VisualConfig(**visual_config_dict)
frequency = trigger_data.frequency
if not frequency:
raise ScheduleConfigError("Frequency is required for visual mode")
visual_config = VisualConfig.model_validate(trigger_data.visual_config or {})
cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
else:
raise ScheduleConfigError(f"Invalid schedule mode: {mode}")

View File

@ -16,6 +16,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
from extensions.ext_database import db
@ -41,7 +42,7 @@ class TriggerService:
@classmethod
def invoke_trigger_event(
cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent
cls, tenant_id: str, user_id: str, node_config: NodeConfigDict, event: PluginTriggerDebugEvent
) -> TriggerInvokeEventResponse:
"""Invoke a trigger event."""
subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
@ -50,7 +51,7 @@ class TriggerService:
)
if not subscription:
raise ValueError("Subscription not found")
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {}))
node_data = TriggerEventNodeData.model_validate(node_config["data"], from_attributes=True)
request = TriggerHttpRequestCachingService.get_request(event.request_id)
payload = TriggerHttpRequestCachingService.get_payload(event.request_id)
# invoke triger

View File

@ -2,7 +2,7 @@ import json
import logging
import mimetypes
import secrets
from collections.abc import Mapping
from collections.abc import Callable, Mapping, Sequence
from typing import Any
import orjson
@ -16,9 +16,16 @@ from werkzeug.exceptions import RequestEntityTooLarge
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType
from dify_graph.file.models import FileTransferMethod
from dify_graph.variables.types import SegmentType
from dify_graph.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from dify_graph.variables.types import ArrayValidation, SegmentType
from enums.quota_type import QuotaType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -57,7 +64,7 @@ class WebhookService:
@classmethod
def get_webhook_trigger_and_workflow(
cls, webhook_id: str, is_debug: bool = False
) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]:
) -> tuple[WorkflowWebhookTrigger, Workflow, NodeConfigDict]:
"""Get webhook trigger, workflow, and node configuration.
Args:
@ -135,7 +142,7 @@ class WebhookService:
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any]
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
) -> dict[str, Any]:
"""Extract and validate webhook data in a single unified process.
@ -153,7 +160,7 @@ class WebhookService:
raw_data = cls.extract_webhook_data(webhook_trigger)
# Validate HTTP metadata (method, content-type)
node_data = node_config.get("data", {})
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
validation_result = cls._validate_http_metadata(raw_data, node_data)
if not validation_result["valid"]:
raise ValueError(validation_result["error"])
@ -192,7 +199,7 @@ class WebhookService:
content_type = cls._extract_content_type(dict(request.headers))
# Route to appropriate extractor based on content type
extractors = {
extractors: dict[str, Callable[[], tuple[dict[str, Any], dict[str, Any]]]] = {
"application/json": cls._extract_json_body,
"application/x-www-form-urlencoded": cls._extract_form_body,
"multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger),
@ -214,7 +221,7 @@ class WebhookService:
return data
@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
"""Process and validate webhook data according to node configuration.
Args:
@ -230,18 +237,13 @@ class WebhookService:
result = raw_data.copy()
# Validate and process headers
cls._validate_required_headers(raw_data["headers"], node_data.get("headers", []))
cls._validate_required_headers(raw_data["headers"], node_data.headers)
# Process query parameters with type conversion and validation
result["query_params"] = cls._process_parameters(
raw_data["query_params"], node_data.get("params", []), is_form_data=True
)
result["query_params"] = cls._process_parameters(raw_data["query_params"], node_data.params, is_form_data=True)
# Process body parameters based on content type
configured_content_type = node_data.get("content_type", "application/json").lower()
result["body"] = cls._process_body_parameters(
raw_data["body"], node_data.get("body", []), configured_content_type
)
result["body"] = cls._process_body_parameters(raw_data["body"], node_data.body, node_data.content_type)
return result
@ -424,7 +426,11 @@ class WebhookService:
@classmethod
def _process_parameters(
cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False
cls,
raw_params: dict[str, str],
param_configs: Sequence[WebhookParameter],
*,
is_form_data: bool = False,
) -> dict[str, Any]:
"""Process parameters with unified validation and type conversion.
@ -440,13 +446,13 @@ class WebhookService:
ValueError: If required parameters are missing or validation fails
"""
processed = {}
configured_params = {config.get("name", ""): config for config in param_configs}
configured_params = {config.name: config for config in param_configs}
# Process configured parameters
for param_config in param_configs:
name = param_config.get("name", "")
param_type = param_config.get("type", SegmentType.STRING)
required = param_config.get("required", False)
name = param_config.name
param_type = param_config.type
required = param_config.required
# Check required parameters
if required and name not in raw_params:
@ -465,7 +471,10 @@ class WebhookService:
@classmethod
def _process_body_parameters(
cls, raw_body: dict[str, Any], body_configs: list, content_type: str
cls,
raw_body: dict[str, Any],
body_configs: Sequence[WebhookBodyParameter],
content_type: ContentType,
) -> dict[str, Any]:
"""Process body parameters based on content type and configuration.
@ -480,25 +489,28 @@ class WebhookService:
Raises:
ValueError: If required body parameters are missing or validation fails
"""
if content_type in ["text/plain", "application/octet-stream"]:
# For text/plain and octet-stream, validate required content exists
if body_configs and any(config.get("required", False) for config in body_configs):
raw_content = raw_body.get("raw")
if not raw_content:
raise ValueError(f"Required body content missing for {content_type} request")
return raw_body
match content_type:
case ContentType.TEXT | ContentType.BINARY:
# For text/plain and octet-stream, validate required content exists
if body_configs and any(config.required for config in body_configs):
raw_content = raw_body.get("raw")
if not raw_content:
raise ValueError(f"Required body content missing for {content_type} request")
return raw_body
case _:
pass
# For structured data (JSON, form-data, etc.)
processed = {}
configured_params = {config.get("name", ""): config for config in body_configs}
configured_params: dict[str, WebhookBodyParameter] = {config.name: config for config in body_configs}
for body_config in body_configs:
name = body_config.get("name", "")
param_type = body_config.get("type", SegmentType.STRING)
required = body_config.get("required", False)
name = body_config.name
param_type = body_config.type
required = body_config.required
# Handle file parameters for multipart data
if param_type == SegmentType.FILE and content_type == "multipart/form-data":
if param_type == SegmentType.FILE and content_type == ContentType.FORM_DATA:
# File validation is handled separately in extract phase
continue
@ -508,7 +520,7 @@ class WebhookService:
if name in raw_body:
raw_value = raw_body[name]
is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"]
is_form_data = content_type in [ContentType.FORM_URLENCODED, ContentType.FORM_DATA]
processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)
# Include unconfigured parameters
@ -519,7 +531,9 @@ class WebhookService:
return processed
@classmethod
def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any:
def _validate_and_convert_value(
cls, param_name: str, value: Any, param_type: SegmentType | str, is_form_data: bool
) -> Any:
"""Unified validation and type conversion for parameter values.
Args:
@ -532,7 +546,8 @@ class WebhookService:
Any: The validated and converted value
Raises:
ValueError: If validation or conversion fails
ValueError: If validation or conversion fails. The original validation
error is preserved as ``__cause__`` for debugging.
"""
try:
if is_form_data:
@ -542,10 +557,10 @@ class WebhookService:
# JSON data should already be in correct types, just validate
return cls._validate_json_value(param_name, value, param_type)
except Exception as e:
raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}")
raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") from e
@classmethod
def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any:
def _convert_form_value(cls, param_name: str, value: str, param_type: SegmentType | str) -> Any:
"""Convert form data string values to specified types.
Args:
@ -576,7 +591,7 @@ class WebhookService:
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
@classmethod
def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any:
def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any:
"""Validate JSON values against expected types.
Args:
@ -590,43 +605,43 @@ class WebhookService:
Raises:
ValueError: If the value type doesn't match the expected type
"""
type_validators = {
SegmentType.STRING: (lambda v: isinstance(v, str), "string"),
SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"),
SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"),
SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"),
SegmentType.ARRAY_STRING: (
lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v),
"array of strings",
),
SegmentType.ARRAY_NUMBER: (
lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v),
"array of numbers",
),
SegmentType.ARRAY_BOOLEAN: (
lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v),
"array of booleans",
),
SegmentType.ARRAY_OBJECT: (
lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v),
"array of objects",
),
}
validator_info = type_validators.get(SegmentType(param_type))
if not validator_info:
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name)
if param_type_enum is None:
return value
validator, expected_type = validator_info
if not validator(value):
if not param_type_enum.is_valid(value, array_validation=ArrayValidation.ALL):
actual_type = type(value).__name__
expected_type = cls._expected_type_label(param_type_enum)
raise ValueError(f"Expected {expected_type}, got {actual_type}")
return value
@classmethod
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None:
def _coerce_segment_type(cls, param_type: SegmentType | str, *, param_name: str) -> SegmentType | None:
if isinstance(param_type, SegmentType):
return param_type
try:
return SegmentType(param_type)
except Exception:
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
return None
@staticmethod
def _expected_type_label(param_type: SegmentType) -> str:
match param_type:
case SegmentType.ARRAY_STRING:
return "array of strings"
case SegmentType.ARRAY_NUMBER:
return "array of numbers"
case SegmentType.ARRAY_BOOLEAN:
return "array of booleans"
case SegmentType.ARRAY_OBJECT:
return "array of objects"
case _:
return param_type.value
@classmethod
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: Sequence[WebhookParameter]) -> None:
"""Validate required headers are present.
Args:
@ -639,14 +654,14 @@ class WebhookService:
headers_lower = {k.lower(): v for k, v in headers.items()}
headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()}
for header_config in header_configs:
if header_config.get("required", False):
header_name = header_config.get("name", "")
if header_config.required:
header_name = header_config.name
sanitized_name = cls._sanitize_key(header_name).lower()
if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized:
raise ValueError(f"Required header missing: {header_name}")
@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
"""Validate HTTP method and content-type.
Args:
@ -657,13 +672,13 @@ class WebhookService:
dict[str, Any]: Validation result with 'valid' key and optional 'error' key
"""
# Validate HTTP method
configured_method = node_data.get("method", "get").upper()
configured_method = node_data.method.value.upper()
request_method = webhook_data["method"].upper()
if configured_method != request_method:
return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}")
# Validate Content-type
configured_content_type = node_data.get("content_type", "application/json").lower()
configured_content_type = node_data.content_type.value.lower()
request_content_type = cls._extract_content_type(webhook_data["headers"])
if configured_content_type != request_content_type:
@ -788,7 +803,7 @@ class WebhookService:
raise
@classmethod
def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]:
def generate_webhook_response(cls, node_config: NodeConfigDict) -> tuple[dict[str, Any], int]:
"""Generate HTTP response based on node configuration.
Args:
@ -797,11 +812,11 @@ class WebhookService:
Returns:
tuple[dict[str, Any], int]: Response data and HTTP status code
"""
node_data = node_config.get("data", {})
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
# Get configured status code and response body
status_code = node_data.get("status_code", 200)
response_body = node_data.get("response_body", "")
status_code = node_data.status_code
response_body = node_data.response_body
# Parse response body as JSON if it's valid JSON, otherwise return as text
try:

View File

@ -16,6 +16,7 @@ from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams, WorkflowNodeExecution
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.errors import WorkflowNodeRunFailedError
@ -693,7 +694,7 @@ class WorkflowService:
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
node_data = node_config.get("data", {})
node_data = node_config["data"]
if node_type.is_start_node:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
@ -703,7 +704,7 @@ class WorkflowService:
workflow=draft_workflow,
)
if node_type is NodeType.START:
start_data = StartNodeData.model_validate(node_data)
start_data = StartNodeData.model_validate(node_data, from_attributes=True)
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
@ -941,7 +942,7 @@ class WorkflowService:
if node_type is not NodeType.HUMAN_INPUT:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(node_config.get("data", {}))
node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True)
delivery_method = self._resolve_human_input_delivery_method(
node_data=node_data,
delivery_method_id=delivery_method_id,
@ -1059,7 +1060,7 @@ class WorkflowService:
*,
workflow: Workflow,
account: Account,
node_config: Mapping[str, Any],
node_config: NodeConfigDict,
variable_pool: VariablePool,
) -> HumanInputNode:
graph_init_params = GraphInitParams(
@ -1079,7 +1080,7 @@ class WorkflowService:
start_at=time.perf_counter(),
)
node = HumanInputNode(
id=node_config.get("id", str(uuid.uuid4())),
id=node_config["id"],
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@ -1092,7 +1093,7 @@ class WorkflowService:
*,
app_model: App,
workflow: Workflow,
node_config: Mapping[str, Any],
node_config: NodeConfigDict,
manual_inputs: Mapping[str, Any],
) -> VariablePool:
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():

View File

@ -189,6 +189,7 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
from dify_graph.enums import NodeType
from dify_graph.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@ -209,6 +210,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
# Create node data with custom auth and empty api_key
node_data = HttpRequestNodeData(
type=NodeType.HTTP_REQUEST,
title="http",
desc="",
url="http://example.com",

View File

@ -173,7 +173,7 @@ class TestWebhookService:
assert workflow.app_id == test_data["app"].id
assert node_config is not None
assert node_config["id"] == "webhook_node"
assert node_config["data"]["title"] == "Test Webhook"
assert node_config["data"].title == "Test Webhook"
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers):
"""Test webhook trigger not found scenario."""

View File

@ -25,7 +25,8 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing
# load dotenv file with pydantic-settings
config = DifyConfig()
# Disable `.env` loading to ensure test stability across environments
config = DifyConfig(_env_file=None)
# constant values
assert config.COMMIT_SHA == ""
@ -59,7 +60,8 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
config = DifyConfig()
# Disable `.env` loading to ensure test stability across environments
config = DifyConfig(_env_file=None)
# Verify default timeout values
assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10
@ -86,7 +88,8 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*")
monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/")
flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore
# Disable `.env` loading to ensure test stability across environments
flask_app.config.from_mapping(DifyConfig(_env_file=None).model_dump()) # pyright: ignore
config = flask_app.config
# configs read from pydantic-settings

View File

@ -51,7 +51,7 @@ def bypass_decorators(mocker):
)
mocker.patch(
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
return_value=lambda *_: (lambda f: f),
return_value=lambda *_: lambda f: f,
)

View File

@ -8,6 +8,8 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import SchedulingPause
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
@ -22,7 +24,7 @@ from dify_graph.graph_events import (
NodeRunSucceededEvent,
)
from dify_graph.node_events import NodeRunResult, PauseRequestedEvent
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig
from dify_graph.nodes.base.entities import OutputVariableEntity
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.end.entities import EndNodeData
from dify_graph.nodes.start.entities import StartNodeData
@ -42,6 +44,7 @@ if "core.ops.ops_trace_manager" not in sys.modules:
class _StubToolNodeData(BaseNodeData):
type: NodeType = NodeType.TOOL
pause_on: bool = False
@ -88,16 +91,17 @@ class _StubToolNode(Node[_StubToolNodeData]):
def _patch_tool_node(mocker):
original_create_node = DifyNodeFactory.create_node
def _patched_create_node(self, node_config: dict[str, object]) -> Node:
node_data = node_config.get("data", {})
if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value:
def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_data = typed_node_config["data"]
if node_data.type == NodeType.TOOL:
return _StubToolNode(
id=str(node_config["id"]),
config=node_config,
id=str(typed_node_config["id"]),
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
return original_create_node(self, node_config)
return original_create_node(self, typed_node_config)
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)

View File

@ -7,7 +7,9 @@ import pytest
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from models.workflow import Workflow
@ -105,3 +107,57 @@ def test_run_uses_single_node_execution_branch(
assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER
assert entry_kwargs["variable_pool"] is variable_pool
assert entry_kwargs["graph_runtime_state"] is graph_runtime_state
def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
runner = WorkflowBasedAppRunner(
queue_manager=MagicMock(spec=AppQueueManager),
variable_loader=MagicMock(),
app_id="app",
)
workflow = MagicMock(spec=Workflow)
workflow.id = "workflow"
workflow.tenant_id = "tenant"
workflow.graph_dict = {
"nodes": [
{
"id": "loop-node",
"data": {
"type": "loop",
"title": "Loop",
"loop_count": 1,
"break_conditions": [],
"logical_operator": "and",
},
}
],
"edges": [],
}
_, _, graph_runtime_state = _make_graph_state()
seen_configs: list[object] = []
original_validate_python = NodeConfigDictAdapter.validate_python
def record_validate_python(value: object):
seen_configs.append(value)
return original_validate_python(value)
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
with (
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
):
runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id="loop-node",
user_inputs={},
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)
assert seen_configs == [workflow.graph_dict["nodes"][0]]

View File

@ -7,10 +7,10 @@ from dataclasses import dataclass
import pytest
from dify_graph.entities import GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType
from dify_graph.graph import Graph
from dify_graph.graph.validation import GraphValidationError
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.nodes.base.node import Node
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@ -183,3 +183,36 @@ def test_graph_validation_blocks_start_and_trigger_coexistence(
Graph.init(graph_config=graph_config, node_factory=node_factory)
assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues)
def test_graph_init_ignores_custom_note_nodes_before_node_data_validation(
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
node_factory, graph_config = graph_init_dependencies
graph_config["nodes"] = [
{
"id": "start",
"data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
},
{"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}},
{
"id": "note",
"type": "custom-note",
"data": {
"type": "",
"title": "",
"desc": "",
"text": "{}",
"theme": "blue",
},
},
]
graph_config["edges"] = [
{"source": "start", "target": "answer", "sourceHandle": "success"},
]
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
assert graph.root_node.id == "start"
assert "answer" in graph.nodes
assert "note" not in graph.nodes

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from dify_graph.entities.base_node_data import RetryConfig
from dify_graph.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
@ -12,7 +13,6 @@ from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.entities import RetryConfig
from dify_graph.runtime import GraphRuntimeState, VariablePool
from libs.datetime_utils import naive_utc_now

View File

@ -10,6 +10,7 @@ import time
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType
from dify_graph.enums import ErrorStrategy
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
@ -18,7 +19,6 @@ from dify_graph.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from dify_graph.nodes.base.entities import DefaultValue, DefaultValueType
# Import the test framework from the new module
from .test_mock_config import MockConfigBuilder

View File

@ -5,10 +5,10 @@ This module provides a MockNodeFactory that automatically detects and mocks node
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
"""
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.enums import NodeType
from dify_graph.nodes.base.node import Node
@ -75,39 +75,27 @@ class MockNodeFactory(DifyNodeFactory):
NodeType.CODE: MockCodeNode,
}
def create_node(self, node_config: Mapping[str, Any]) -> Node:
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
"""
Create a node instance, using mock implementations for third-party service nodes.
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
# Get node type from config
node_data = node_config.get("data", {})
node_type_str = node_data.get("type")
if not node_type_str:
# Fall back to parent implementation for nodes without type
return super().create_node(node_config)
try:
node_type = NodeType(node_type_str)
except ValueError:
# Unknown node type, use parent implementation
return super().create_node(node_config)
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_data = typed_node_config["data"]
node_type = node_data.type
# Check if this node type should be mocked
if node_type in self._mock_node_types:
node_id = node_config.get("id")
if not node_id:
raise ValueError("Node config missing id")
node_id = typed_node_config["id"]
# Create mock node instance
mock_class = self._mock_node_types[node_type]
if node_type == NodeType.CODE:
mock_instance = mock_class(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -117,7 +105,7 @@ class MockNodeFactory(DifyNodeFactory):
elif node_type == NodeType.HTTP_REQUEST:
mock_instance = mock_class(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -129,7 +117,7 @@ class MockNodeFactory(DifyNodeFactory):
elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
mock_instance = mock_class(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -139,7 +127,7 @@ class MockNodeFactory(DifyNodeFactory):
else:
mock_instance = mock_class(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -148,7 +136,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
return super().create_node(node_config)
return super().create_node(typed_node_config)
def should_mock_node(self, node_type: NodeType) -> bool:
"""

View File

@ -1,7 +1,7 @@
import pytest
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.nodes.base.node import Node
# Ensures that all node classes are imported.
@ -126,3 +126,20 @@ def test_init_subclass_sets_node_data_type_from_generic():
return "1"
assert _AutoNode._node_data_type is _TestNodeData
def test_validate_node_data_uses_declared_node_data_type():
"""Public validation should hydrate the subclass-declared node data model."""
class _AutoNode(Node[_TestNodeData]):
node_type = NodeType.CODE
@staticmethod
def version() -> str:
return "1"
base_node_data = BaseNodeData.model_validate({"type": NodeType.CODE, "title": "Test"})
validated = _AutoNode.validate_node_data(base_node_data)
assert isinstance(validated, _TestNodeData)

View File

@ -1,8 +1,8 @@
import types
from collections.abc import Mapping
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.nodes.base.node import Node
# Import concrete nodes we will assert on (numeric version path)

View File

@ -272,7 +272,7 @@ class TestCodeNodeExtractVariableSelector:
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="node_1",
node_data=node_data,
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
)
assert result == {}
@ -292,7 +292,7 @@ class TestCodeNodeExtractVariableSelector:
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="node_1",
node_data=node_data,
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
)
assert "node_1.input_text" in result
@ -315,7 +315,7 @@ class TestCodeNodeExtractVariableSelector:
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="code_node",
node_data=node_data,
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
)
assert len(result) == 3
@ -338,7 +338,7 @@ class TestCodeNodeExtractVariableSelector:
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="node_x",
node_data=node_data,
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
)
assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"]
@ -437,7 +437,7 @@ class TestCodeNodeInitialization:
"outputs": {"x": {"type": "number"}},
}
node._node_data = node._hydrate_node_data(data)
node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True)
assert node._node_data.title == "Test Node"
assert node._node_data.code_language == CodeLanguage.PYTHON3
@ -453,7 +453,7 @@ class TestCodeNodeInitialization:
"outputs": {"x": {"type": "number"}},
}
node._node_data = node._hydrate_node_data(data)
node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True)
assert node._node_data.code_language == CodeLanguage.JAVASCRIPT

View File

@ -1,3 +1,4 @@
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import NodeType
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.nodes.iteration.exc import (
@ -388,3 +389,50 @@ class TestIterationNodeErrorStrategies:
result = node._get_default_value_dict()
assert isinstance(result, dict)
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
seen_configs: list[object] = []
original_validate_python = NodeConfigDictAdapter.validate_python
def record_validate_python(value: object):
seen_configs.append(value)
return original_validate_python(value)
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
child_node_config = {
"id": "answer-node",
"data": {
"type": "answer",
"title": "Answer",
"answer": "",
"iteration_id": "iteration-node",
},
}
IterationNode._extract_variable_selector_to_variable_mapping(
graph_config={
"nodes": [
{
"id": "iteration-node",
"data": {
"type": "iteration",
"title": "Iteration",
"iterator_selector": ["start", "items"],
"output_selector": ["iteration", "result"],
},
},
child_node_config,
],
"edges": [],
},
node_id="iteration-node",
node_data=IterationNodeData(
title="Iteration",
iterator_selector=["start", "items"],
output_selector=["iteration", "result"],
),
)
assert seen_configs == [child_node_config]

View File

@ -410,14 +410,14 @@ class TestKnowledgeRetrievalNode:
"""Test _extract_variable_selector_to_variable_mapping class method."""
# Arrange
node_id = "knowledge_node_1"
node_data = {
"type": "knowledge-retrieval",
"title": "Knowledge Retrieval",
"dataset_ids": [str(uuid.uuid4())],
"retrieval_mode": "multiple",
"query_variable_selector": ["start", "query"],
"query_attachment_selector": ["start", "attachments"],
}
node_data = KnowledgeRetrievalNodeData(
type="knowledge-retrieval",
title="Knowledge Retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="multiple",
query_variable_selector=["start", "query"],
query_attachment_selector=["start", "attachments"],
)
graph_config = {}
# Act

View File

@ -4,8 +4,9 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from dify_graph.entities import GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.nodes.base.node import Node
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@ -40,13 +41,26 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
def _build_node_config() -> NodeConfigDict:
return NodeConfigDictAdapter.validate_python(
{
"id": "node-1",
"data": {
"type": NodeType.ANSWER.value,
"title": "Sample",
"foo": "bar",
},
}
)
def test_node_hydrates_data_during_initialization():
graph_config: dict[str, object] = {}
init_params, runtime_state = _build_context(graph_config)
node = _SampleNode(
id="node-1",
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
config=_build_node_config(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@ -72,7 +86,7 @@ def test_node_accepts_invoke_from_enum():
node = _SampleNode(
id="node-1",
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
config=_build_node_config(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@ -99,3 +113,17 @@ def test_missing_generic_argument_raises_type_error():
def _run(self):
raise NotImplementedError
def test_base_node_data_keeps_dict_style_access_compatibility():
node_data = _SampleNodeData.model_validate(
{
"type": NodeType.ANSWER.value,
"title": "Sample",
"foo": "bar",
}
)
assert node_data["foo"] == "bar"
assert node_data.get("foo") == "bar"
assert node_data.get("missing", "fallback") == "fallback"

View File

@ -0,0 +1,52 @@
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.nodes.loop.entities import LoopNodeData
from dify_graph.nodes.loop.loop_node import LoopNode
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
seen_configs: list[object] = []
original_validate_python = NodeConfigDictAdapter.validate_python
def record_validate_python(value: object):
seen_configs.append(value)
return original_validate_python(value)
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
child_node_config = {
"id": "answer-node",
"data": {
"type": "answer",
"title": "Answer",
"answer": "",
"loop_id": "loop-node",
},
}
LoopNode._extract_variable_selector_to_variable_mapping(
graph_config={
"nodes": [
{
"id": "loop-node",
"data": {
"type": "loop",
"title": "Loop",
"loop_count": 1,
"break_conditions": [],
"logical_operator": "and",
},
},
child_node_config,
],
"edges": [],
},
node_id="loop-node",
node_data=LoopNodeData(
title="Loop",
loop_count=1,
break_conditions=[],
logical_operator="and",
),
)
assert seen_configs == [child_node_config]

View File

@ -210,9 +210,6 @@ def test_webhook_data_model_dump_with_alias():
def test_webhook_data_validation_errors():
"""Test WebhookData validation errors."""
# Title is required (inherited from BaseNodeData)
with pytest.raises(ValidationError):
WebhookData()
# Invalid method
with pytest.raises(ValidationError):
@ -254,6 +251,36 @@ def test_webhook_data_sequence_fields():
assert len(data.headers) == 1 # Should still be 1
def test_webhook_data_rejects_non_string_header_types():
"""Headers should stay string-only because runtime does not coerce header values."""
for param_type in ["number", "boolean", "object", "array[string]", "file"]:
with pytest.raises(ValidationError):
WebhookData(
title="Test",
headers=[WebhookParameter(name="X-Test", type=param_type)],
)
def test_webhook_data_limits_query_param_types_to_scalar_values():
"""Query params only support scalar conversions in the current runtime."""
data = WebhookData(
title="Test",
params=[
WebhookParameter(name="count", type="number"),
WebhookParameter(name="enabled", type="boolean"),
],
)
assert data.params[0].type == "number"
assert data.params[1].type == "boolean"
for param_type in ["object", "array[string]", "array[number]", "array[boolean]", "array[object]", "file"]:
with pytest.raises(ValidationError):
WebhookData(
title="Test",
params=[WebhookParameter(name="test", type=param_type)],
)
def test_webhook_data_sync_mode():
"""Test WebhookData SyncMode nested enum."""
# Test that SyncMode enum exists and has expected value
@ -297,7 +324,7 @@ def test_webhook_body_parameter_edge_cases():
def test_webhook_data_inheritance():
"""Test WebhookData inherits from BaseNodeData correctly."""
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
# Test that WebhookData is a subclass of BaseNodeData
assert issubclass(WebhookData, BaseNodeData)

View File

@ -1,6 +1,6 @@
import pytest
from dify_graph.nodes.base.exc import BaseNodeError
from dify_graph.entities.exc import BaseNodeError
from dify_graph.nodes.trigger_webhook.exc import (
WebhookConfigError,
WebhookNodeError,

View File

@ -0,0 +1,82 @@
from __future__ import annotations
from typing import Any
from core.model_manager import ModelInstance
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params
def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
graph_init_params = build_test_graph_init_params(
workflow_id="workflow",
graph_config=graph_config,
tenant_id="tenant",
app_id="app",
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=[],
),
start_at=0.0,
)
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch):
class _FactoryLLMNodeData(LLMNodeData):
pass
llm_node_config = {
"id": "llm-node",
"data": {
"type": "llm",
"title": "LLM",
"model": {
"provider": "openai",
"name": "gpt-4o-mini",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [],
"context": {
"enabled": False,
},
},
}
graph_config = {"nodes": [llm_node_config], "edges": []}
factory = _build_factory(graph_config)
captured: dict[str, object] = {}
monkeypatch.setattr(LLMNode, "_node_data_type", _FactoryLLMNodeData)
def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance:
captured["node_data"] = node_data
return object() # type: ignore[return-value]
def _capture_memory(
self: DifyNodeFactory,
*,
node_data: object,
model_instance: ModelInstance,
) -> None:
captured["memory_node_data"] = node_data
monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance)
monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory)
node = factory.create_node(llm_node_config)
assert isinstance(captured["node_data"], _FactoryLLMNodeData)
assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData)
assert isinstance(node.node_data, _FactoryLLMNodeData)

View File

@ -9,6 +9,7 @@ from dify_graph.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.file.enums import FileType
from dify_graph.file.models import File, FileTransferMethod
from dify_graph.nodes.code.code_node import CodeNode
@ -124,7 +125,7 @@ class TestWorkflowEntry:
def get_node_config_by_id(self, target_id: str):
assert target_id == node_id
return node_config
return NodeConfigDictAdapter.validate_python(node_config)
workflow = StubWorkflow()
variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={})

View File

@ -258,38 +258,38 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
return q
msg_session_1 = MagicMock()
msg_session_1.query.side_effect = (
lambda model: make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
msg_session_1.query.side_effect = lambda model: (
make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
)
msg_session_1.commit.return_value = None
msg_session_2 = MagicMock()
msg_session_2.query.side_effect = (
lambda model: make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
msg_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
)
msg_session_2.commit.return_value = None
conv_session_1 = MagicMock()
conv_session_1.query.side_effect = (
lambda model: make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
conv_session_1.query.side_effect = lambda model: (
make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
)
conv_session_1.commit.return_value = None
conv_session_2 = MagicMock()
conv_session_2.query.side_effect = (
lambda model: make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
conv_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
)
conv_session_2.commit.return_value = None
wal_session_1 = MagicMock()
wal_session_1.query.side_effect = (
lambda model: make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
wal_session_1.query.side_effect = lambda model: (
make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_1.commit.return_value = None
wal_session_2 = MagicMock()
wal_session_2.query.side_effect = (
lambda model: make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
wal_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_2.commit.return_value = None

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import sessionmaker
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.enums import NodeType
from dify_graph.nodes.human_input.entities import (
EmailDeliveryConfig,
@ -22,7 +23,7 @@ def _make_service() -> WorkflowService:
return WorkflowService(session_maker=sessionmaker())
def _build_node_config(delivery_methods):
def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict:
node_data = HumanInputNodeData(
title="Human Input",
delivery_methods=delivery_methods,
@ -31,7 +32,7 @@ def _build_node_config(delivery_methods):
user_actions=[],
).model_dump(mode="json")
node_data["type"] = NodeType.HUMAN_INPUT.value
return {"id": "node-1", "data": node_data}
return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data})
def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod:

View File

@ -4,6 +4,7 @@ from unittest.mock import MagicMock
import pytest
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import NodeType
from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
from dify_graph.nodes.human_input.enums import FormInputType
@ -187,7 +188,10 @@ class TestWorkflowService:
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
node_config = NodeConfigDictAdapter.validate_python(
{"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
)
workflow.get_node_config_by_id.return_value = node_config
workflow.get_enclosing_node_type_and_id.return_value = None
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
@ -232,7 +236,7 @@ class TestWorkflowService:
service._build_human_input_variable_pool.assert_called_once_with(
app_model=app_model,
workflow=workflow,
node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}},
node_config=node_config,
manual_inputs={"#node-0.result#": "LLM output"},
)
@ -267,7 +271,9 @@ class TestWorkflowService:
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
{"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
)
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")

659
api/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -47,7 +47,7 @@ const SettingRow = memo(({
: 'bg-workflow-block-parma-bg',
)}
>
<div className={cn('mr-2 shrink-0 system-xs-medium-uppercase', warning ? 'text-text-warning' : 'text-text-tertiary')}>
<div className="mr-2 shrink-0 text-text-tertiary system-xs-medium-uppercase">
{label}
</div>
<div
@ -141,13 +141,16 @@ const Node: FC<NodeProps<KnowledgeBaseNodeType>> = ({ data }) => {
if (data.indexing_technique !== IndexMethodEnum.QUALIFIED)
return '-'
if (isKnowledgeBaseEmbeddingIssue(validationIssue))
if (isKnowledgeBaseEmbeddingIssue(validationIssue)) {
if (validationIssue?.code === KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)
return t('nodes.knowledgeBase.notConfigured', { ns: 'workflow' })
return validationIssueMessage
}
const currentEmbeddingModelProvider = embeddingModelList.find(provider => provider.provider === data.embedding_model_provider)
const currentEmbeddingModel = currentEmbeddingModelProvider?.models.find(model => model.model === data.embedding_model)
return currentEmbeddingModel?.label[language] || currentEmbeddingModel?.label.en_US || data.embedding_model || '-'
}, [data.embedding_model, data.embedding_model_provider, data.indexing_technique, embeddingModelList, language, validationIssue, validationIssueMessage])
}, [data.embedding_model, data.embedding_model_provider, data.indexing_technique, embeddingModelList, language, validationIssue, validationIssueMessage, t])
const indexMethodDisplay = settingsDisplay[data.indexing_technique as keyof typeof settingsDisplay] || '-'
const retrievalMethodDisplay = settingsDisplay[data.retrieval_model?.search_method as keyof typeof settingsDisplay] || '-'

View File

@ -687,6 +687,7 @@
"nodes.knowledgeBase.embeddingModelIsRequired": "Embedding model is required",
"nodes.knowledgeBase.embeddingModelNotConfigured": "Embedding model not configured",
"nodes.knowledgeBase.indexMethodIsRequired": "Index method is required",
"nodes.knowledgeBase.notConfigured": "Not configured",
"nodes.knowledgeBase.rerankingModelIsInvalid": "Reranking model is invalid",
"nodes.knowledgeBase.rerankingModelIsRequired": "Reranking model is required",
"nodes.knowledgeBase.retrievalSettingIsRequired": "Retrieval setting is required",

Some files were not shown because too many files have changed in this diff Show More