From fb41b215c8a57ab2a60c312cfe141fdb9f004c76 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 15 Mar 2026 15:24:59 +0800 Subject: [PATCH] refactor(api): move workflow knowledge nodes and trigger nodes (#33445) --- api/.importlinter | 2 - api/controllers/console/app/app.py | 9 +- api/controllers/console/app/workflow.py | 3 +- .../advanced_chat/generate_task_pipeline.py | 4 +- .../common/workflow_response_converter.py | 27 +-- api/core/app/apps/pipeline/pipeline_runner.py | 4 +- api/core/app/apps/workflow_app_runner.py | 11 +- api/core/app/entities/queue_entities.py | 3 +- .../conversation_variable_persist_layer.py | 4 +- api/core/app/workflow/layers/llm_quota.py | 8 +- api/core/app/workflow/layers/observability.py | 13 +- api/core/datasource/datasource_manager.py | 2 +- api/core/ops/aliyun_trace/aliyun_trace.py | 8 +- .../arize_phoenix_trace.py | 6 +- api/core/ops/langfuse_trace/langfuse_trace.py | 4 +- .../ops/langsmith_trace/langsmith_trace.py | 6 +- api/core/ops/mlflow_trace/mlflow_trace.py | 24 +-- api/core/ops/opik_trace/opik_trace.py | 4 +- api/core/ops/tencent_trace/tencent_trace.py | 10 +- api/core/ops/weave_trace/weave_trace.py | 4 +- api/core/plugin/backwards_invocation/node.py | 4 +- .../rag/index_processor/index_processor.py | 4 +- api/core/rag/retrieval/dataset_retrieval.py | 14 +- ...hemy_workflow_node_execution_repository.py | 4 +- .../utils/workflow_configuration_sync.py | 4 +- api/core/trigger/constants.py | 18 ++ api/core/trigger/debug/event_selectors.py | 38 ++-- api/core/workflow/__init__.py | 5 +- api/core/workflow/node_factory.py | 173 +++++++++++++++--- api/core/workflow/node_resolution.py | 42 ----- api/core/workflow/nodes/__init__.py | 1 + api/core/workflow/nodes/agent/agent_node.py | 4 +- api/core/workflow/nodes/agent/entities.py | 4 +- .../nodes/agent/message_transformer.py | 4 +- .../workflow/nodes/datasource/__init__.py | 1 + .../nodes/datasource/datasource_node.py | 16 +- .../workflow}/nodes/datasource/entities.py | 15 +- .../workflow}/nodes/datasource/exc.py | 0 .../workflow/nodes/datasource/protocols.py} | 17 +- .../nodes/knowledge_index/__init__.py | 5 + .../nodes/knowledge_index/entities.py | 3 +- .../workflow}/nodes/knowledge_index/exc.py | 0 .../knowledge_index/knowledge_index_node.py | 15 +- .../nodes/knowledge_index/protocols.py} | 22 ++- .../nodes/knowledge_retrieval/__init__.py | 1 + .../nodes/knowledge_retrieval/entities.py | 4 +- .../nodes/knowledge_retrieval/exc.py | 0 .../knowledge_retrieval_node.py | 16 +- .../nodes/knowledge_retrieval/retrieval.py} | 30 +-- .../knowledge_retrieval/template_prompts.py | 0 .../nodes/trigger_plugin/__init__.py | 0 .../nodes/trigger_plugin/entities.py | 6 +- .../workflow}/nodes/trigger_plugin/exc.py | 0 .../trigger_plugin/trigger_event_node.py | 12 +- .../nodes/trigger_schedule/__init__.py | 3 + .../nodes/trigger_schedule/entities.py | 3 +- .../workflow}/nodes/trigger_schedule/exc.py | 0 .../trigger_schedule/trigger_schedule_node.py | 10 +- .../nodes/trigger_webhook/__init__.py | 0 .../nodes/trigger_webhook/entities.py | 3 +- .../workflow}/nodes/trigger_webhook/exc.py | 0 .../workflow}/nodes/trigger_webhook/node.py | 5 +- api/core/workflow/workflow_entry.py | 13 +- api/dify_graph/README.md | 2 +- api/dify_graph/entities/base_node_data.py | 2 + .../entities/workflow_node_execution.py | 2 +- api/dify_graph/enums.py | 113 +++++++----- api/dify_graph/graph/graph.py | 54 +----- api/dify_graph/graph/validation.py | 42 +---- .../response_coordinator/__init__.py | 3 +- .../response_coordinator/session.py | 45 ++++- api/dify_graph/node_events/node.py | 6 +- api/dify_graph/nodes/__init__.py | 4 +- api/dify_graph/nodes/answer/answer_node.py | 4 +- api/dify_graph/nodes/answer/entities.py | 4 +- api/dify_graph/nodes/base/node.py | 50 ++--- api/dify_graph/nodes/code/code_node.py | 4 +- api/dify_graph/nodes/code/entities.py | 4 +- api/dify_graph/nodes/datasource/__init__.py | 3 - .../nodes/document_extractor/entities.py | 4 +- .../nodes/document_extractor/node.py | 4 +- api/dify_graph/nodes/end/end_node.py | 4 +- api/dify_graph/nodes/end/entities.py | 4 +- api/dify_graph/nodes/http_request/entities.py | 4 +- api/dify_graph/nodes/http_request/node.py | 4 +- api/dify_graph/nodes/human_input/entities.py | 4 +- .../nodes/human_input/human_input_node.py | 4 +- api/dify_graph/nodes/if_else/entities.py | 4 +- api/dify_graph/nodes/if_else/if_else_node.py | 4 +- api/dify_graph/nodes/iteration/entities.py | 6 +- .../nodes/iteration/iteration_node.py | 11 +- .../nodes/iteration/iteration_start_node.py | 4 +- .../nodes/knowledge_index/__init__.py | 3 - .../nodes/knowledge_retrieval/__init__.py | 3 - .../nodes/list_operator/entities.py | 4 +- api/dify_graph/nodes/list_operator/node.py | 4 +- api/dify_graph/nodes/llm/entities.py | 4 +- api/dify_graph/nodes/llm/node.py | 55 +++--- api/dify_graph/nodes/loop/entities.py | 8 +- api/dify_graph/nodes/loop/loop_end_node.py | 4 +- api/dify_graph/nodes/loop/loop_node.py | 13 +- api/dify_graph/nodes/loop/loop_start_node.py | 4 +- api/dify_graph/nodes/node_mapping.py | 28 --- .../nodes/parameter_extractor/entities.py | 4 +- .../parameter_extractor_node.py | 4 +- .../nodes/question_classifier/entities.py | 4 +- .../question_classifier_node.py | 4 +- api/dify_graph/nodes/start/entities.py | 4 +- api/dify_graph/nodes/start/start_node.py | 4 +- .../nodes/template_transform/entities.py | 4 +- .../template_transform_node.py | 4 +- api/dify_graph/nodes/tool/entities.py | 4 +- api/dify_graph/nodes/tool/tool_node.py | 4 +- .../nodes/trigger_schedule/__init__.py | 3 - .../nodes/variable_aggregator/entities.py | 4 +- .../variable_aggregator_node.py | 4 +- .../nodes/variable_assigner/v1/node.py | 4 +- .../nodes/variable_assigner/v1/node_data.py | 4 +- .../nodes/variable_assigner/v2/entities.py | 4 +- .../nodes/variable_assigner/v2/node.py | 4 +- .../summary_index_service_protocol.py | 7 - ...rameters_cache_when_sync_draft_workflow.py | 4 +- ...nc_workflow_schedule_when_app_published.py | 2 +- ...oin_when_app_published_workflow_updated.py | 6 +- ...ers_when_app_published_workflow_updated.py | 4 +- ...tore_workflow_node_execution_repository.py | 5 +- api/extensions/otel/parser/base.py | 22 +-- api/models/enums.py | 12 +- api/models/workflow.py | 19 +- api/pyrefly-local-excludes.txt | 2 +- api/services/app_dsl_service.py | 35 ++-- api/services/app_service.py | 5 +- api/services/rag_pipeline/rag_pipeline.py | 16 +- .../rag_pipeline/rag_pipeline_dsl_service.py | 33 ++-- api/services/trigger/schedule_service.py | 10 +- api/services/trigger/trigger_service.py | 6 +- api/services/trigger/webhook_service.py | 10 +- api/services/workflow/workflow_converter.py | 20 +- .../workflow_draft_variable_service.py | 19 +- api/services/workflow_service.py | 46 +++-- api/tasks/trigger_processing_tasks.py | 7 +- api/tasks/workflow_node_execution_tasks.py | 2 +- api/tasks/workflow_schedule_tasks.py | 2 +- .../test_datasource_node_integration.py | 20 +- .../test_workflow_draft_variable_service.py | 6 +- .../workflow/nodes/test_code.py | 2 +- .../workflow/nodes/test_http.py | 8 +- .../workflow/nodes/test_template_transform.py | 2 +- .../workflow/nodes/test_tool.py | 2 +- .../test_dataset_retrieval_integration.py | 2 +- .../test_human_input_delivery_test.py | 4 +- .../services/test_workflow_service.py | 26 +-- .../trigger/test_trigger_e2e.py | 39 ++-- .../test_generate_task_pipeline_core.py | 26 +-- ..._workflow_response_converter_truncation.py | 34 ++-- .../core/app/apps/test_base_app_generator.py | 4 +- .../core/app/apps/test_pause_resume.py | 22 +-- .../app/apps/test_workflow_app_runner_core.py | 12 +- .../test_generate_task_pipeline_core.py | 30 +-- ...est_conversation_variable_persist_layer.py | 10 +- .../ops/aliyun_trace/test_aliyun_trace.py | 12 +- .../ops/langfuse_trace/test_langfuse_trace.py | 8 +- .../langsmith_trace/test_langsmith_trace.py | 10 +- .../ops/mlflow_trace/test_mlflow_trace.py | 26 +-- .../core/ops/opik_trace/test_opik_trace.py | 8 +- .../ops/tencent_trace/test_tencent_trace.py | 16 +- .../core/ops/test_arize_phoenix_trace.py | 18 +- .../core/ops/weave_trace/test_weave_trace.py | 12 +- .../rag/retrieval/test_dataset_retrieval.py | 4 +- .../test_dataset_retrieval_methods.py | 4 +- ...lery_workflow_node_execution_repository.py | 12 +- ...rkflow_node_execution_conflict_handling.py | 10 +- ...test_workflow_node_execution_truncation.py | 9 +- .../test_trigger_debug_event_selectors.py | 2 +- .../debug/test_debug_event_selectors.py | 19 +- .../entities/test_workflow_node_execution.py | 6 +- .../core/workflow/graph/test_graph.py | 4 +- .../core/workflow/graph/test_graph_builder.py | 10 +- .../graph/test_graph_skip_validation.py | 6 +- .../workflow/graph/test_graph_validation.py | 79 ++++---- .../event_management/test_event_handlers.py | 4 +- .../workflow/graph_engine/layers/conftest.py | 12 +- .../graph_engine/layers/test_llm_quota.py | 18 +- .../graph_engine/layers/test_observability.py | 8 +- .../orchestration/test_dispatcher.py | 14 +- .../graph_engine/test_auto_mock_system.py | 48 ++--- ...ditional_streaming_vs_template_workflow.py | 26 ++- .../test_dispatcher_pause_drain.py | 4 +- .../test_graph_execution_serialization.py | 8 +- .../workflow/graph_engine/test_mock_config.py | 2 - .../graph_engine/test_mock_factory.py | 36 ++-- .../test_mock_iteration_simple.py | 10 +- .../workflow/graph_engine/test_mock_nodes.py | 2 +- .../test_mock_nodes_template_code.py | 14 +- .../workflow/graph_engine/test_mock_simple.py | 42 ++--- .../test_parallel_streaming_workflow.py | 20 +- .../graph_engine/test_response_session.py | 71 +++++++ .../graph_engine/test_table_runner.py | 8 +- .../core/workflow/nodes/answer/test_answer.py | 2 +- .../workflow/nodes/base/test_base_node.py | 22 +-- .../test_get_node_type_classes_mapping.py | 17 +- .../nodes/datasource/test_datasource_node.py | 5 +- .../test_human_input_form_filled_event.py | 6 +- .../nodes/iteration/iteration_node_spec.py | 4 +- .../test_knowledge_index_node.py | 50 ++--- .../test_knowledge_retrieval_node.py | 31 ++-- .../workflow/nodes/list_operator/node_spec.py | 4 +- .../template_transform_node_spec.py | 4 +- .../core/workflow/nodes/test_base_node.py | 10 +- .../nodes/test_document_extractor_node.py | 4 +- .../core/workflow/nodes/test_if_else.py | 10 +- .../v1/test_variable_assigner_v1.py | 6 +- .../v2/test_variable_assigner_v2.py | 8 +- .../workflow/nodes/webhook/test_entities.py | 2 +- .../workflow/nodes/webhook/test_exceptions.py | 6 +- .../webhook/test_webhook_file_conversion.py | 20 +- .../nodes/webhook/test_webhook_node.py | 13 +- .../core/workflow/test_node_factory.py | 85 ++++----- .../workflow/test_node_mapping_bootstrap.py | 43 +++++ .../workflow/test_workflow_entry_helpers.py | 21 ++- .../libs/test_cron_compatibility.py | 2 +- .../unit_tests/models/test_workflow_models.py | 42 ++--- .../test_sqlalchemy_repository.py | 8 +- ...hemy_workflow_node_execution_repository.py | 4 +- .../services/test_app_dsl_service.py | 35 ++-- .../services/test_schedule_service.py | 8 +- .../services/test_workflow_service.py | 48 ++--- .../test_workflow_draft_variable_service.py | 14 +- .../test_workflow_human_input_delivery.py | 4 +- .../workflow/test_workflow_service.py | 10 +- .../test_workflow_node_execution_tasks.py | 6 +- dev/pyrefly-check-local | 2 + 232 files changed, 1575 insertions(+), 1421 deletions(-) create mode 100644 api/core/trigger/constants.py delete mode 100644 api/core/workflow/node_resolution.py create mode 100644 api/core/workflow/nodes/datasource/__init__.py rename api/{dify_graph => core/workflow}/nodes/datasource/datasource_node.py (94%) rename api/{dify_graph => core/workflow}/nodes/datasource/entities.py (85%) rename api/{dify_graph => core/workflow}/nodes/datasource/exc.py (100%) rename api/{dify_graph/repositories/datasource_manager_protocol.py => core/workflow/nodes/datasource/protocols.py} (79%) create mode 100644 api/core/workflow/nodes/knowledge_index/__init__.py rename api/{dify_graph => core/workflow}/nodes/knowledge_index/entities.py (96%) rename api/{dify_graph => core/workflow}/nodes/knowledge_index/exc.py (100%) rename api/{dify_graph => core/workflow}/nodes/knowledge_index/knowledge_index_node.py (91%) rename api/{dify_graph/repositories/index_processor_protocol.py => core/workflow/nodes/knowledge_index/protocols.py} (55%) create mode 100644 api/core/workflow/nodes/knowledge_retrieval/__init__.py rename api/{dify_graph => core/workflow}/nodes/knowledge_retrieval/entities.py (96%) rename api/{dify_graph => core/workflow}/nodes/knowledge_retrieval/exc.py (100%) rename api/{dify_graph => core/workflow}/nodes/knowledge_retrieval/knowledge_retrieval_node.py (97%) rename api/{dify_graph/repositories/rag_retrieval_protocol.py => core/workflow/nodes/knowledge_retrieval/retrieval.py} (83%) rename api/{dify_graph => core/workflow}/nodes/knowledge_retrieval/template_prompts.py (100%) rename api/{dify_graph => core/workflow}/nodes/trigger_plugin/__init__.py (100%) rename api/{dify_graph => core/workflow}/nodes/trigger_plugin/entities.py (95%) rename api/{dify_graph => core/workflow}/nodes/trigger_plugin/exc.py (100%) rename api/{dify_graph => core/workflow}/nodes/trigger_plugin/trigger_event_node.py (84%) create mode 100644 api/core/workflow/nodes/trigger_schedule/__init__.py rename api/{dify_graph => core/workflow}/nodes/trigger_schedule/entities.py (94%) rename api/{dify_graph => core/workflow}/nodes/trigger_schedule/exc.py (100%) rename api/{dify_graph => core/workflow}/nodes/trigger_schedule/trigger_schedule_node.py (85%) rename api/{dify_graph => core/workflow}/nodes/trigger_webhook/__init__.py (100%) rename api/{dify_graph => core/workflow}/nodes/trigger_webhook/entities.py (97%) rename api/{dify_graph => core/workflow}/nodes/trigger_webhook/exc.py (100%) rename api/{dify_graph => core/workflow}/nodes/trigger_webhook/node.py (97%) delete mode 100644 api/dify_graph/nodes/datasource/__init__.py delete mode 100644 api/dify_graph/nodes/knowledge_index/__init__.py delete mode 100644 api/dify_graph/nodes/knowledge_retrieval/__init__.py delete mode 100644 api/dify_graph/nodes/node_mapping.py delete mode 100644 api/dify_graph/nodes/trigger_schedule/__init__.py delete mode 100644 api/dify_graph/repositories/summary_index_service_protocol.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py create mode 100644 api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py diff --git a/api/.importlinter b/api/.importlinter index 8dffc3506b..4109c007d9 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -96,7 +96,6 @@ ignore_imports = dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler dify_graph.nodes.tool.tool_node -> core.tools.tool_engine dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model @@ -116,7 +115,6 @@ ignore_imports = dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods dify_graph.nodes.llm.node -> models.dataset dify_graph.nodes.llm.file_saver -> core.tools.signature dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 33b3c9ec36..5ac0e342e6 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -25,7 +25,8 @@ from controllers.console.wraps import ( ) from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.enums import NodeType, WorkflowExecutionStatus +from core.trigger.constants import TRIGGER_NODE_TYPES +from dify_graph.enums import WorkflowExecutionStatus from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required @@ -508,11 +509,7 @@ class AppListApi(Resource): .scalars() .all() ) - trigger_node_types = { - NodeType.TRIGGER_WEBHOOK, - NodeType.TRIGGER_SCHEDULE, - NodeType.TRIGGER_PLUGIN, - } + trigger_node_types = TRIGGER_NODE_TYPES for workflow in draft_workflows: node_id = None try: diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 9759e0815a..837245ecb1 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -22,6 +22,7 @@ from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.trace_id_helper import get_external_trace_id from core.plugin.impl.exc import PluginInvokeError +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.trigger.debug.event_selectors import ( TriggerDebugEvent, TriggerDebugEventPoller, @@ -1209,7 +1210,7 @@ class DraftWorkflowTriggerNodeApi(Resource): node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config) event: TriggerDebugEvent | None = None # for schedule trigger, when run single node, just execute directly - if node_type == NodeType.TRIGGER_SCHEDULE: + if node_type == TRIGGER_SCHEDULE_NODE_TYPE: event = TriggerDebugEvent( workflow_args={}, node_id=node_id, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index a1cb375e24..6583ba51e9 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -69,7 +69,7 @@ from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import WorkflowExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory from dify_graph.runtime import GraphRuntimeState from dify_graph.system_variable import SystemVariable @@ -357,7 +357,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) -> Generator[StreamResponse, None, None]: """Handle node succeeded events.""" # Record files if it's an answer node or end node - if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]: + if event.node_type in [BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END, BuiltinNodeTypes.LLM]: self._recorded_files.extend( self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) ) diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 67dc9909a1..5665a2b76c 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -48,12 +48,13 @@ from core.app.entities.task_entities import ( from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import ( - NodeType, + BuiltinNodeTypes, SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, @@ -442,7 +443,7 @@ class WorkflowResponseConverter: event: QueueNodeStartedEvent, task_id: str, ) -> NodeStartStreamResponse | None: - if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: + if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}: return None run_id = self._ensure_workflow_run_id() snapshot = self._store_snapshot(event) @@ -464,13 +465,13 @@ class WorkflowResponseConverter: ) try: - if event.node_type == NodeType.TOOL: + if event.node_type == BuiltinNodeTypes.TOOL: response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=ToolProviderType(event.provider_type), provider_id=event.provider_id, ) - elif event.node_type == NodeType.DATASOURCE: + elif event.node_type == BuiltinNodeTypes.DATASOURCE: manager = PluginDatasourceManager() provider_entity = manager.fetch_datasource_provider( self._application_generate_entity.app_config.tenant_id, @@ -479,7 +480,7 @@ class WorkflowResponseConverter: response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url( self._application_generate_entity.app_config.tenant_id ) - elif event.node_type == NodeType.TRIGGER_PLUGIN: + elif event.node_type == TRIGGER_PLUGIN_NODE_TYPE: response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon( self._application_generate_entity.app_config.tenant_id, event.provider_id, @@ -496,7 +497,7 @@ class WorkflowResponseConverter: event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent, task_id: str, ) -> NodeFinishStreamResponse | None: - if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: + if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}: return None run_id = self._ensure_workflow_run_id() snapshot = self._pop_snapshot(event.node_execution_id) @@ -554,7 +555,7 @@ class WorkflowResponseConverter: event: QueueNodeRetryEvent, task_id: str, ) -> NodeRetryStreamResponse | None: - if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: + if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}: return None run_id = self._ensure_workflow_run_id() @@ -612,7 +613,7 @@ class WorkflowResponseConverter: data=IterationNodeStartStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, created_at=int(time.time()), extras={}, @@ -635,7 +636,7 @@ class WorkflowResponseConverter: data=IterationNodeNextStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, index=event.index, created_at=int(time.time()), @@ -662,7 +663,7 @@ class WorkflowResponseConverter: data=IterationNodeCompletedStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, outputs=new_outputs, outputs_truncated=outputs_truncated, @@ -692,7 +693,7 @@ class WorkflowResponseConverter: data=LoopNodeStartStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, created_at=int(time.time()), extras={}, @@ -715,7 +716,7 @@ class WorkflowResponseConverter: data=LoopNodeNextStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, index=event.index, # The `pre_loop_output` field is not utilized by the frontend. @@ -744,7 +745,7 @@ class WorkflowResponseConverter: data=LoopNodeCompletedStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, outputs=new_outputs, outputs_truncated=outputs_truncated, diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 4222aae809..e767766bdb 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.enums import WorkflowType @@ -274,6 +274,8 @@ class PipelineRunner(WorkflowBasedAppRunner): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) + if start_node_id is None: + start_node_id = get_default_root_node_id(graph_config) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id) if not graph: diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 8986164fe7..25d3c8bd2a 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -32,8 +32,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.node_resolution import resolve_workflow_node_class +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter @@ -140,6 +140,9 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, ) + if root_node_id is None: + root_node_id = get_default_root_node_id(graph_config) + # init graph graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) @@ -505,7 +508,9 @@ class WorkflowBasedAppRunner: elif isinstance(event, NodeRunRetrieverResourceEvent): self._publish_event( QueueRetrieverResourcesEvent( - retriever_resources=event.retriever_resources, + retriever_resources=[ + RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources + ], in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, ) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 2d1508f0cb..8899d80db8 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -9,9 +9,8 @@ from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata from dify_graph.entities.pause_reason import PauseReason from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowNodeExecutionMetadataKey +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.nodes import NodeType class QueueEvent(StrEnum): diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index e495abf855..d227e4e904 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -2,7 +2,7 @@ import logging from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.conversation_variable_updater import ConversationVariableUpdater -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent from dify_graph.nodes.variable_assigner.common import helpers as common_helpers @@ -22,7 +22,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): def on_event(self, event: GraphEngineEvent) -> None: if not isinstance(event, NodeRunSucceededEvent): return - if event.node_type != NodeType.VARIABLE_ASSIGNER: + if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: return if self.graph_runtime_state is None: return diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 2e930a1f58..faf1516c40 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -12,7 +12,7 @@ from typing_extensions import override from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase @@ -113,11 +113,11 @@ class LLMQuotaLayer(GraphEngineLayer): def _extract_model_instance(node: Node) -> ModelInstance | None: try: match node.node_type: - case NodeType.LLM: + case BuiltinNodeTypes.LLM: return cast("LLMNode", node).model_instance - case NodeType.PARAMETER_EXTRACTOR: + case BuiltinNodeTypes.PARAMETER_EXTRACTOR: return cast("ParameterExtractorNode", node).model_instance - case NodeType.QUESTION_CLASSIFIER: + case BuiltinNodeTypes.QUESTION_CLASSIFIER: return cast("QuestionClassifierNode", node).model_instance case _: return None diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index ab73db59f1..4b20477a7f 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -16,7 +16,7 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_ from typing_extensions import override from configs import dify_config -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import GraphNodeEventBase from dify_graph.nodes.base.node import Node @@ -74,16 +74,13 @@ class ObservabilityLayer(GraphEngineLayer): def _build_parser_registry(self) -> None: """Initialize parser registry for node types.""" self._parsers = { - NodeType.TOOL: ToolNodeOTelParser(), - NodeType.LLM: LLMNodeOTelParser(), - NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(), + BuiltinNodeTypes.TOOL: ToolNodeOTelParser(), + BuiltinNodeTypes.LLM: LLMNodeOTelParser(), + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(), } def _get_parser(self, node: Node) -> NodeOTelParser: - node_type = getattr(node, "node_type", None) - if isinstance(node_type, NodeType): - return self._parsers.get(node_type, self._default_parser) - return self._default_parser + return self._parsers.get(node.node_type, self._default_parser) @override def on_graph_start(self) -> None: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 15cd319750..4fa941ae16 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import WorkflowNodeExecutionMetadataKey from dify_graph.file import File from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory from models.model import UploadFile from models.tools import ToolFile diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 19111cc917..18f35b5b9c 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -58,7 +58,7 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import WorkflowNodeExecutionTriggeredFrom @@ -302,11 +302,11 @@ class AliyunDataTrace(BaseTraceInstance): self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata ): try: - if node_execution.node_type == NodeType.LLM: + if node_execution.node_type == BuiltinNodeTypes.LLM: node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata) - elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL: + elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata) - elif node_execution.node_type == NodeType.TOOL: + elif node_execution.node_type == BuiltinNodeTypes.TOOL: node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata) else: node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata) diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 452255f69e..7cb54b2c88 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -155,8 +155,8 @@ def wrap_span_metadata(metadata, **kwargs): return metadata -# Mapping from NodeType string values to OpenInference span kinds. -# NodeType values not listed here default to CHAIN. +# Mapping from built-in node type strings to OpenInference span kinds. +# Node types not listed here default to CHAIN. _NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = { "llm": OpenInferenceSpanKindValues.LLM, "knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER, @@ -168,7 +168,7 @@ _NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = { def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues: """Return the OpenInference span kind for a given workflow node type. - Covers every ``NodeType`` enum value. Nodes that do not have a + Covers every built-in node type string. Nodes that do not have a specialised span kind (e.g. ``start``, ``end``, ``if-else``, ``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``. """ diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 28e800e6c7..6e62387a1f 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -141,7 +141,7 @@ class LangFuseDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index b40bc89b71..32a0c77fe2 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -163,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} @@ -197,7 +197,7 @@ class LangSmithDataTrace(BaseTraceInstance): "ls_model_name": process_data.get("model_name", ""), } ) - elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: + elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: run_type = LangSmithRunType.retriever else: run_type = LangSmithRunType.tool diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index ba2cb9e0c3..ab4a7650ec 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -145,10 +145,10 @@ class MLflowDataTrace(BaseTraceInstance): "app_name": node.title, } - if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER): + if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER): inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node) attributes.update(llm_attributes) - elif node.node_type == NodeType.HTTP_REQUEST: + elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST: inputs = node.process_data # contains request URL if not inputs: @@ -180,9 +180,9 @@ class MLflowDataTrace(BaseTraceInstance): # End node span finished_at = node.created_at + timedelta(seconds=node.elapsed_time) outputs = json.loads(node.outputs) if node.outputs else {} - if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL: + if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: outputs = self._parse_knowledge_retrieval_outputs(outputs) - elif node.node_type == NodeType.LLM: + elif node.node_type == BuiltinNodeTypes.LLM: outputs = outputs.get("text", outputs) node_span.end( outputs=outputs, @@ -471,13 +471,13 @@ class MLflowDataTrace(BaseTraceInstance): def _get_node_span_type(self, node_type: str) -> str: """Map Dify node types to MLflow span types""" node_type_mapping = { - NodeType.LLM: SpanType.LLM, - NodeType.QUESTION_CLASSIFIER: SpanType.LLM, - NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER, - NodeType.TOOL: SpanType.TOOL, - NodeType.CODE: SpanType.TOOL, - NodeType.HTTP_REQUEST: SpanType.TOOL, - NodeType.AGENT: SpanType.AGENT, + BuiltinNodeTypes.LLM: SpanType.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER, + BuiltinNodeTypes.TOOL: SpanType.TOOL, + BuiltinNodeTypes.CODE: SpanType.TOOL, + BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL, + BuiltinNodeTypes.AGENT: SpanType.AGENT, } return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload] diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index eab51fd9f8..fb72bc2381 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -187,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index cbff1c9e1c..7e56b1effa 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -27,7 +27,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from extensions.ext_database import db from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom @@ -179,7 +179,7 @@ class TencentDataTrace(BaseTraceInstance): if node_span: self.trace_client.add_span(node_span) - if node_execution.node_type == NodeType.LLM: + if node_execution.node_type == BuiltinNodeTypes.LLM: self._record_llm_metrics(node_execution) except Exception: logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id) @@ -192,15 +192,15 @@ class TencentDataTrace(BaseTraceInstance): ) -> SpanData | None: """Build span for different node types""" try: - if node_execution.node_type == NodeType.LLM: + if node_execution.node_type == BuiltinNodeTypes.LLM: return TencentSpanBuilder.build_workflow_llm_span( trace_id, workflow_span_id, trace_info, node_execution ) - elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL: + elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: return TencentSpanBuilder.build_workflow_retrieval_span( trace_id, workflow_span_id, trace_info, node_execution ) - elif node_execution.node_type == NodeType.TOOL: + elif node_execution.node_type == BuiltinNodeTypes.TOOL: return TencentSpanBuilder.build_workflow_tool_span( trace_id, workflow_span_id, trace_info, node_execution ) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 7b62207366..2a657b672c 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -175,7 +175,7 @@ class WeaveDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 33c45c0007..d6aef93fc4 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,5 +1,5 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) @@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): instruction=instruction, # instruct with variables are not supported ) node_data_dict = node_data.model_dump() - node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR + node_data_dict["type"] = BuiltinNodeTypes.PARAMETER_EXTRACTOR execution = workflow_service.run_free_workflow_node( node_data_dict, tenant_id=tenant_id, diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index c8f9d29012..a7c42c5a4e 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,8 +9,8 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory -from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError -from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview +from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError +from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview from models.dataset import Dataset, Document, DocumentSegment from .index_processor_factory import IndexProcessorFactory diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index fcd3cceb59..4c96b63f25 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -56,18 +56,18 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.knowledge_retrieval import exc -from dify_graph.repositories.rag_retrieval_protocol import ( +from core.workflow.nodes.knowledge_retrieval import exc +from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, Source, SourceChildChunk, SourceMetadata, ) +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.json_in_md_parser import parse_and_check_json_markdown diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 3fc333038d..7373ebc7cc 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -18,7 +18,7 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att from configs import dify_config from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter @@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) index=db_model.index, predecessor_node_id=db_model.predecessor_node_id, node_id=db_model.node_id, - node_type=NodeType(db_model.node_type), + node_type=db_model.node_type, title=db_model.title, inputs=inputs, process_data=process_data, diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index d8ce53083b..28f1376655 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,7 +3,7 @@ from typing import Any from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.base.entities import OutputVariableEntity from dify_graph.variables.input_entities import VariableEntity @@ -51,7 +51,7 @@ class WorkflowToolConfigurationUtils: def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None: nodes = graph.get("nodes", []) for node in nodes: - if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT: + if node.get("data", {}).get("type") == BuiltinNodeTypes.HUMAN_INPUT: raise WorkflowToolHumanInputNotSupportedError() @classmethod diff --git a/api/core/trigger/constants.py b/api/core/trigger/constants.py new file mode 100644 index 0000000000..bfa45c3f2b --- /dev/null +++ b/api/core/trigger/constants.py @@ -0,0 +1,18 @@ +from typing import Final + +TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook" +TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule" +TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin" +TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info" + +TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset( + { + TRIGGER_WEBHOOK_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_PLUGIN_NODE_TYPE, + } +) + + +def is_trigger_node_type(node_type: str) -> bool: + return node_type in TRIGGER_NODE_TYPES diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 442a2434d5..2a133b2b94 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -11,6 +11,11 @@ from typing import Any from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.events import ( PluginTriggerDebugEvent, @@ -19,10 +24,9 @@ from core.trigger.debug.events import ( build_plugin_pool_key, build_webhook_pool_key, ) +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData +from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType -from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at @@ -206,21 +210,19 @@ def create_event_poller( if not node_config: raise ValueError("Node data not found for node %s", node_id) node_type = draft_workflow.get_node_type_from_node_config(node_config) - match node_type: - case NodeType.TRIGGER_PLUGIN: - return PluginTriggerDebugEventPoller( - tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id - ) - case NodeType.TRIGGER_WEBHOOK: - return WebhookTriggerDebugEventPoller( - tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id - ) - case NodeType.TRIGGER_SCHEDULE: - return ScheduleTriggerDebugEventPoller( - tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id - ) - case _: - raise ValueError("unable to create event poller for node type %s", node_type) + if node_type == TRIGGER_PLUGIN_NODE_TYPE: + return PluginTriggerDebugEventPoller( + tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id + ) + if node_type == TRIGGER_WEBHOOK_NODE_TYPE: + return WebhookTriggerDebugEventPoller( + tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id + ) + if node_type == TRIGGER_SCHEDULE_NODE_TYPE: + return ScheduleTriggerDebugEventPoller( + tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id + ) + raise ValueError("unable to create event poller for node type %s", node_type) def select_trigger_debug_events( diff --git a/api/core/workflow/__init__.py b/api/core/workflow/__init__.py index 57c2ef3d10..937012dcee 100644 --- a/api/core/workflow/__init__.py +++ b/api/core/workflow/__init__.py @@ -1,4 +1 @@ -from .node_factory import DifyNodeFactory -from .workflow_entry import WorkflowEntry - -__all__ = ["DifyNodeFactory", "WorkflowEntry"] +"""Core workflow package.""" diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index bc4e0eda71..ee3b322636 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -1,4 +1,7 @@ -from collections.abc import Callable, Mapping +import importlib +import pkgutil +from collections.abc import Callable, Iterator, Mapping, MutableMapping +from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeAlias, cast, final from sqlalchemy import select @@ -8,7 +11,6 @@ from typing_extensions import override from configs import dify_config from core.app.entities.app_invoke_entities import DifyRunContext from core.app.llm.model_access import build_dify_model_access -from core.datasource.datasource_manager import DatasourceManager from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -17,12 +19,9 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.rag.index_processor.index_processor import IndexProcessor -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.rag.summary_index.summary_index import SummaryIndex from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.tools.tool_file_manager import ToolFileManager -from core.workflow.node_resolution import resolve_workflow_node_class +from core.trigger.constants import TRIGGER_NODE_TYPES from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyPresentationProvider, @@ -32,7 +31,7 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey from dify_graph.file.file_manager import file_manager from dify_graph.graph.graph import NodeFactory from dify_graph.model_runtime.entities.model_entities import ModelType @@ -59,6 +58,135 @@ if TYPE_CHECKING: from dify_graph.entities import GraphInitParams from dify_graph.runtime import GraphRuntimeState +LATEST_VERSION = "latest" +_START_NODE_TYPES: frozenset[NodeType] = frozenset( + (BuiltinNodeTypes.START, BuiltinNodeTypes.DATASOURCE, *TRIGGER_NODE_TYPES) +) + + +def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None: + package = importlib.import_module(package_name) + for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."): + if module_name in excluded_modules: + continue + importlib.import_module(module_name) + + +@lru_cache(maxsize=1) +def register_nodes() -> None: + """Import production node modules so they self-register with ``Node``.""" + _import_node_package("dify_graph.nodes") + _import_node_package("core.workflow.nodes") + + +def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]: + """Return a read-only snapshot of the current production node registry. + + The workflow layer owns node bootstrap because it must compose built-in + `dify_graph.nodes.*` implementations with workflow-local nodes under + `core.workflow.nodes.*`. Keeping this import side effect here avoids + reintroducing registry bootstrapping into lower-level graph primitives. + """ + register_nodes() + return Node.get_node_type_classes_mapping() + + +def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + node_mapping = get_node_type_classes_mapping().get(node_type) + if not node_mapping: + raise ValueError(f"No class mapping found for node type: {node_type}") + + latest_node_class = node_mapping.get(LATEST_VERSION) + matched_node_class = node_mapping.get(node_version) + node_class = matched_node_class or latest_node_class + if not node_class: + raise ValueError(f"No latest version class found for node type: {node_type}") + return node_class + + +def is_start_node_type(node_type: NodeType) -> bool: + """Return True when the node type can serve as a workflow entry point.""" + return node_type in _START_NODE_TYPES + + +def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str: + """Resolve the default entry node for a persisted top-level workflow graph. + + This workflow-layer helper depends on start-node semantics defined by + `is_start_node_type`, so it intentionally lives next to the node registry + instead of in the raw `dify_graph.entities.graph_config` schema module. + """ + nodes = graph_config.get("nodes") + if not isinstance(nodes, list): + raise ValueError("nodes in workflow graph must be a list") + + for node in nodes: + if not isinstance(node, Mapping): + continue + + if node.get("type") == "custom-note": + continue + + node_id = node.get("id") + data = node.get("data") + if not isinstance(node_id, str) or not isinstance(data, Mapping): + continue + + node_type = data.get("type") + if isinstance(node_type, str) and is_start_node_type(node_type): + return node_id + + raise ValueError("Unable to determine default root node ID from workflow graph") + + +class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Node]]]): + """Mutable dict-like view over the current node registry.""" + + def __init__(self) -> None: + self._cached_snapshot: dict[NodeType, Mapping[str, type[Node]]] = {} + self._cached_version = -1 + self._deleted: set[NodeType] = set() + self._overrides: dict[NodeType, Mapping[str, type[Node]]] = {} + + def _snapshot(self) -> dict[NodeType, Mapping[str, type[Node]]]: + current_version = Node.get_registry_version() + if self._cached_version != current_version: + self._cached_snapshot = dict(get_node_type_classes_mapping()) + self._cached_version = current_version + if not self._deleted and not self._overrides: + return self._cached_snapshot + + snapshot = {key: value for key, value in self._cached_snapshot.items() if key not in self._deleted} + snapshot.update(self._overrides) + return snapshot + + def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]: + return self._snapshot()[key] + + def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None: + self._deleted.discard(key) + self._overrides[key] = value + + def __delitem__(self, key: NodeType) -> None: + if key in self._overrides: + del self._overrides[key] + return + if key in self._cached_snapshot: + self._deleted.add(key) + return + raise KeyError(key) + + def __iter__(self) -> Iterator[NodeType]: + return iter(self._snapshot()) + + def __len__(self) -> int: + return len(self._snapshot()) + + +# Keep the canonical node-class mapping in the workflow layer that also bootstraps +# legacy `core.workflow.nodes.*` registrations. +NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping() + LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData @@ -130,7 +258,6 @@ class DifyNodeFactory(NodeFactory): self._http_request_http_client = ssrf_proxy self._http_request_tool_file_manager_factory = ToolFileManager self._http_request_file_manager = file_manager - self._rag_retrieval = DatasetRetrieval() self._document_extractor_unstructured_api_config = UnstructuredApiConfig( api_url=dify_config.UNSTRUCTURED_API_URL, api_key=dify_config.UNSTRUCTURED_API_KEY or "", @@ -177,56 +304,46 @@ class DifyNodeFactory(NodeFactory): node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { - NodeType.CODE: lambda: { + BuiltinNodeTypes.CODE: lambda: { "code_executor": self._code_executor, "code_limits": self._code_limits, }, - NodeType.TEMPLATE_TRANSFORM: lambda: { + BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { "template_renderer": self._template_renderer, "max_output_length": self._template_transform_max_output_length, }, - NodeType.HTTP_REQUEST: lambda: { + BuiltinNodeTypes.HTTP_REQUEST: lambda: { "http_request_config": self._http_request_config, "http_client": self._http_request_http_client, "tool_file_manager_factory": self._http_request_tool_file_manager_factory, "file_manager": self._http_request_file_manager, }, - NodeType.HUMAN_INPUT: lambda: { + BuiltinNodeTypes.HUMAN_INPUT: lambda: { "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), }, - NodeType.KNOWLEDGE_INDEX: lambda: { - "index_processor": IndexProcessor(), - "summary_index_service": SummaryIndex(), - }, - NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs( + BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, include_http_client=True, ), - NodeType.DATASOURCE: lambda: { - "datasource_manager": DatasourceManager, - }, - NodeType.KNOWLEDGE_RETRIEVAL: lambda: { - "rag_retrieval": self._rag_retrieval, - }, - NodeType.DOCUMENT_EXTRACTOR: lambda: { + BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, "http_client": self._http_request_http_client, }, - NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( + BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, include_http_client=True, ), - NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( + BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, include_http_client=False, ), - NodeType.TOOL: lambda: { + BuiltinNodeTypes.TOOL: lambda: { "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), }, - NodeType.AGENT: lambda: { + BuiltinNodeTypes.AGENT: lambda: { "strategy_resolver": self._agent_strategy_resolver, "presentation_provider": self._agent_strategy_presentation_provider, "runtime_support": self._agent_runtime_support, diff --git a/api/core/workflow/node_resolution.py b/api/core/workflow/node_resolution.py deleted file mode 100644 index b922c28165..0000000000 --- a/api/core/workflow/node_resolution.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from importlib import import_module - -from dify_graph.enums import NodeType -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.node_mapping import LATEST_VERSION, get_node_type_classes_mapping - -_WORKFLOW_NODE_MODULES = ("core.workflow.nodes.agent",) -_workflow_nodes_registered = False - - -def ensure_workflow_nodes_registered() -> None: - """Import workflow-local node modules so they can register with `Node.__init_subclass__`.""" - global _workflow_nodes_registered - - if _workflow_nodes_registered: - return - - for module_name in _WORKFLOW_NODE_MODULES: - import_module(module_name) - - _workflow_nodes_registered = True - - -def get_workflow_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]: - ensure_workflow_nodes_registered() - return get_node_type_classes_mapping() - - -def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: - node_mapping = get_workflow_node_type_classes_mapping().get(node_type) - if not node_mapping: - raise ValueError(f"No class mapping found for node type: {node_type}") - - latest_node_class = node_mapping.get(LATEST_VERSION) - matched_node_class = node_mapping.get(node_version) - node_class = matched_node_class or latest_node_class - if not node_class: - raise ValueError(f"No latest version class found for node type: {node_type}") - return node_class diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py index e69de29bb2..d23f80be59 100644 --- a/api/core/workflow/nodes/__init__.py +++ b/api/core/workflow/nodes/__init__.py @@ -0,0 +1 @@ +"""Workflow node implementations that remain under the legacy core.workflow namespace.""" diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index c1b423d69d..5699ccf404 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType, SystemVariableKey, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser @@ -24,7 +24,7 @@ if TYPE_CHECKING: class AgentNode(Node[AgentNodeData]): - node_type = NodeType.AGENT + node_type = BuiltinNodeTypes.AGENT _strategy_resolver: AgentStrategyResolver _presentation_provider: AgentStrategyPresentationProvider diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 59842862ef..91fed39795 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -6,11 +6,11 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): - type: NodeType = NodeType.AGENT + type: NodeType = BuiltinNodeTypes.AGENT agent_strategy_provider_name: str agent_strategy_name: str agent_strategy_label: str diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index 317db14d3f..f58a5665f4 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from dify_graph.model_runtime.utils.encoders import jsonable_encoder @@ -123,7 +123,7 @@ class AgentMessageTransformer: ) elif message.type == ToolInvokeMessage.MessageType.JSON: assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - if node_type == NodeType.AGENT: + if node_type == BuiltinNodeTypes.AGENT: if isinstance(message.message.json_object, dict): msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/core/workflow/nodes/datasource/__init__.py new file mode 100644 index 0000000000..2e9bed5e00 --- /dev/null +++ b/api/core/workflow/nodes/datasource/__init__.py @@ -0,0 +1 @@ +"""Datasource workflow node package.""" diff --git a/api/dify_graph/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py similarity index 94% rename from api/dify_graph/nodes/datasource/datasource_node.py rename to api/core/workflow/nodes/datasource/datasource_node.py index 62dcb2924f..44f4a23a5a 100644 --- a/api/dify_graph/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,22 +1,17 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey from dify_graph.node_events import NodeRunResult, StreamCompletedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.repositories.datasource_manager_protocol import ( - DatasourceManagerProtocol, - DatasourceParameter, - OnlineDriveDownloadFileParam, -) -from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from .entities import DatasourceNodeData +from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError if TYPE_CHECKING: @@ -29,7 +24,7 @@ class DatasourceNode(Node[DatasourceNodeData]): Datasource Node """ - node_type = NodeType.DATASOURCE + node_type = BuiltinNodeTypes.DATASOURCE execution_type = NodeExecutionType.ROOT def __init__( @@ -38,7 +33,6 @@ class DatasourceNode(Node[DatasourceNodeData]): config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - datasource_manager: DatasourceManagerProtocol, ): super().__init__( id=id, @@ -46,7 +40,7 @@ class DatasourceNode(Node[DatasourceNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self.datasource_manager = datasource_manager + self.datasource_manager = DatasourceManager def populate_start_event(self, event) -> None: event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}" diff --git a/api/dify_graph/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py similarity index 85% rename from api/dify_graph/nodes/datasource/entities.py rename to api/core/workflow/nodes/datasource/entities.py index 38275ac158..65864474b0 100644 --- a/api/dify_graph/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType class DatasourceEntity(BaseModel): @@ -17,7 +17,7 @@ class DatasourceEntity(BaseModel): class DatasourceNodeData(BaseNodeData, DatasourceEntity): - type: NodeType = NodeType.DATASOURCE + type: NodeType = BuiltinNodeTypes.DATASOURCE class DatasourceInput(BaseModel): # TODO: check this type @@ -42,3 +42,14 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity): return typ datasource_parameters: dict[str, DatasourceInput] | None = None + + +class DatasourceParameter(BaseModel): + workspace_id: str + page_id: str + type: str + + +class OnlineDriveDownloadFileParam(BaseModel): + id: str + bucket: str diff --git a/api/dify_graph/nodes/datasource/exc.py b/api/core/workflow/nodes/datasource/exc.py similarity index 100% rename from api/dify_graph/nodes/datasource/exc.py rename to api/core/workflow/nodes/datasource/exc.py diff --git a/api/dify_graph/repositories/datasource_manager_protocol.py b/api/core/workflow/nodes/datasource/protocols.py similarity index 79% rename from api/dify_graph/repositories/datasource_manager_protocol.py rename to api/core/workflow/nodes/datasource/protocols.py index fbe2016d3c..c006e0885c 100644 --- a/api/dify_graph/repositories/datasource_manager_protocol.py +++ b/api/core/workflow/nodes/datasource/protocols.py @@ -1,25 +1,10 @@ from collections.abc import Generator from typing import Any, Protocol -from pydantic import BaseModel - from dify_graph.file import File from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent - -class DatasourceParameter(BaseModel): - workspace_id: str - page_id: str - type: str - - -class OnlineDriveDownloadFileParam(BaseModel): - id: str - bucket: str - - -class DatasourceFinal(BaseModel): - data: dict[str, Any] | None = None +from .entities import DatasourceParameter, OnlineDriveDownloadFileParam class DatasourceManagerProtocol(Protocol): diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py new file mode 100644 index 0000000000..efc6a57b3d --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/__init__.py @@ -0,0 +1,5 @@ +"""Knowledge index workflow node package.""" + +KNOWLEDGE_INDEX_NODE_TYPE = "knowledge-index" + +__all__ = ["KNOWLEDGE_INDEX_NODE_TYPE"] diff --git a/api/dify_graph/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py similarity index 96% rename from api/dify_graph/nodes/knowledge_index/entities.py rename to api/core/workflow/nodes/knowledge_index/entities.py index d88ee8e3af..8b00746268 100644 --- a/api/dify_graph/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -3,6 +3,7 @@ from typing import Literal, Union from pydantic import BaseModel from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType @@ -156,7 +157,7 @@ class KnowledgeIndexNodeData(BaseNodeData): Knowledge index Node Data. """ - type: NodeType = NodeType.KNOWLEDGE_INDEX + type: NodeType = KNOWLEDGE_INDEX_NODE_TYPE chunk_structure: str index_chunk_variable_selector: list[str] indexing_technique: str | None = None diff --git a/api/dify_graph/nodes/knowledge_index/exc.py b/api/core/workflow/nodes/knowledge_index/exc.py similarity index 100% rename from api/dify_graph/nodes/knowledge_index/exc.py rename to api/core/workflow/nodes/knowledge_index/exc.py diff --git a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py similarity index 91% rename from api/dify_graph/nodes/knowledge_index/knowledge_index_node.py rename to api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 3c4fe2344c..0a74847bc1 100644 --- a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,14 +2,15 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from core.rag.index_processor.index_processor import IndexProcessor +from core.rag.summary_index.summary_index import SummaryIndex +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey +from dify_graph.enums import NodeExecutionType, SystemVariableKey from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.template import Template -from dify_graph.repositories.index_processor_protocol import IndexProcessorProtocol -from dify_graph.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol from .entities import KnowledgeIndexNodeData from .exc import ( @@ -25,7 +26,7 @@ _INVOKE_FROM_DEBUGGER = "debugger" class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): - node_type = NodeType.KNOWLEDGE_INDEX + node_type = KNOWLEDGE_INDEX_NODE_TYPE execution_type = NodeExecutionType.RESPONSE def __init__( @@ -34,12 +35,10 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - index_processor: IndexProcessorProtocol, - summary_index_service: SummaryIndexServiceProtocol, ) -> None: super().__init__(id, config, graph_init_params, graph_runtime_state) - self.index_processor = index_processor - self.summary_index_service = summary_index_service + self.index_processor = IndexProcessor() + self.summary_index_service = SummaryIndex() def _run(self) -> NodeRunResult: # type: ignore node_data = self.node_data diff --git a/api/dify_graph/repositories/index_processor_protocol.py b/api/core/workflow/nodes/knowledge_index/protocols.py similarity index 55% rename from api/dify_graph/repositories/index_processor_protocol.py rename to api/core/workflow/nodes/knowledge_index/protocols.py index feaa4ab5de..bb52123082 100644 --- a/api/dify_graph/repositories/index_processor_protocol.py +++ b/api/core/workflow/nodes/knowledge_index/protocols.py @@ -5,21 +5,21 @@ from pydantic import BaseModel, Field class PreviewItem(BaseModel): - content: str | None = Field(None) - child_chunks: list[str] | None = Field(None) - summary: str | None = Field(None) + content: str | None = Field(default=None) + child_chunks: list[str] | None = Field(default=None) + summary: str | None = Field(default=None) class QaPreview(BaseModel): - answer: str | None = Field(None) - question: str | None = Field(None) + answer: str | None = Field(default=None) + question: str | None = Field(default=None) class Preview(BaseModel): chunk_structure: str - parent_mode: str | None = Field(None) - preview: list[PreviewItem] = Field([]) - qa_preview: list[QaPreview] = Field([]) + parent_mode: str | None = Field(default=None) + preview: list[PreviewItem] = Field(default_factory=list) + qa_preview: list[QaPreview] = Field(default_factory=list) total_segments: int @@ -39,3 +39,9 @@ class IndexProcessorProtocol(Protocol): def get_preview_output( self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None ) -> Preview: ... + + +class SummaryIndexServiceProtocol(Protocol): + def generate_and_vectorize_summary( + self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None + ) -> None: ... diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py new file mode 100644 index 0000000000..33ea4277b4 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/__init__.py @@ -0,0 +1 @@ +"""Knowledge retrieval workflow node package.""" diff --git a/api/dify_graph/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py similarity index 96% rename from api/dify_graph/nodes/knowledge_retrieval/entities.py rename to api/core/workflow/nodes/knowledge_retrieval/entities.py index 8f226b9785..bc5618685a 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -4,7 +4,7 @@ from typing import Literal from pydantic import BaseModel, Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig @@ -114,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): Knowledge retrieval Node Data. """ - type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL + type: NodeType = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL query_variable_selector: list[str] | None | str = None query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] diff --git a/api/dify_graph/nodes/knowledge_retrieval/exc.py b/api/core/workflow/nodes/knowledge_retrieval/exc.py similarity index 100% rename from api/dify_graph/nodes/knowledge_retrieval/exc.py rename to api/core/workflow/nodes/knowledge_retrieval/exc.py diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py similarity index 97% rename from api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py rename to api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 61c9614340..9c3b9aacbf 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,12 +1,19 @@ +"""Knowledge retrieval workflow node implementation. + +This node now lives under ``core.workflow.nodes`` and is discovered directly by +the workflow node registry. +""" + import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( - NodeType, + BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -15,7 +22,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node -from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source from dify_graph.variables import ( ArrayFileSegment, FileSegment, @@ -32,6 +38,7 @@ from .exc import ( KnowledgeRetrievalNodeError, RateLimitExceededError, ) +from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: from dify_graph.file.models import File @@ -41,7 +48,7 @@ logger = logging.getLogger(__name__) class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): - node_type = NodeType.KNOWLEDGE_RETRIEVAL + node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL # Instance attributes specific to LLMNode. # Output variable for file @@ -53,7 +60,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - rag_retrieval: RAGRetrievalProtocol, ): super().__init__( id=id, @@ -63,7 +69,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._rag_retrieval = rag_retrieval + self._rag_retrieval = DatasetRetrieval() @classmethod def version(cls): diff --git a/api/dify_graph/repositories/rag_retrieval_protocol.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py similarity index 83% rename from api/dify_graph/repositories/rag_retrieval_protocol.py rename to api/core/workflow/nodes/knowledge_retrieval/retrieval.py index 5f3d38167e..f964f79582 100644 --- a/api/dify_graph/repositories/rag_retrieval_protocol.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -3,9 +3,10 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field from dify_graph.model_runtime.entities import LLMUsage -from dify_graph.nodes.knowledge_retrieval.entities import MetadataFilteringCondition from dify_graph.nodes.llm.entities import ModelConfig +from .entities import MetadataFilteringCondition + class SourceChildChunk(BaseModel): id: str = Field(default="", description="Child chunk ID") @@ -28,7 +29,7 @@ class SourceMetadata(BaseModel): segment_id: str | None = Field(default=None, description="Segment unique identifier") retriever_from: str = Field(default="workflow", description="Retriever source context") score: float = Field(default=0.0, description="Retrieval relevance score") - child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks") + child_chunks: list[SourceChildChunk] = Field(default_factory=list, description="List of child chunks") segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved") segment_word_count: int | None = Field(default=0, description="Word count of the segment") segment_position: int | None = Field(default=0, description="Position of segment in document") @@ -81,28 +82,7 @@ class KnowledgeRetrievalRequest(BaseModel): class RAGRetrievalProtocol(Protocol): - """Protocol for RAG-based knowledge retrieval implementations. - - Implementations of this protocol handle knowledge retrieval from datasets - including rate limiting, dataset filtering, and document retrieval. - """ - @property - def llm_usage(self) -> LLMUsage: - """Return accumulated LLM usage for retrieval operations.""" - ... + def llm_usage(self) -> LLMUsage: ... - def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: - """Retrieve knowledge from datasets based on the provided request. - - Args: - request: Knowledge retrieval request with search parameters - - Returns: - List of sources matching the search criteria - - Raises: - RateLimitExceededError: If rate limit is exceeded - ModelNotExistError: If specified model doesn't exist - """ - ... + def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: ... diff --git a/api/dify_graph/nodes/knowledge_retrieval/template_prompts.py b/api/core/workflow/nodes/knowledge_retrieval/template_prompts.py similarity index 100% rename from api/dify_graph/nodes/knowledge_retrieval/template_prompts.py rename to api/core/workflow/nodes/knowledge_retrieval/template_prompts.py diff --git a/api/dify_graph/nodes/trigger_plugin/__init__.py b/api/core/workflow/nodes/trigger_plugin/__init__.py similarity index 100% rename from api/dify_graph/nodes/trigger_plugin/__init__.py rename to api/core/workflow/nodes/trigger_plugin/__init__.py diff --git a/api/dify_graph/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py similarity index 95% rename from api/dify_graph/nodes/trigger_plugin/entities.py rename to api/core/workflow/nodes/trigger_plugin/entities.py index 33a61c9bc8..ea7d20befe 100644 --- a/api/dify_graph/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -3,16 +3,18 @@ from typing import Any, Literal, Union from pydantic import BaseModel, Field, ValidationInfo, field_validator +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType -from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError + +from .exc import TriggerEventParameterError class TriggerEventNodeData(BaseNodeData): """Plugin trigger node data""" - type: NodeType = NodeType.TRIGGER_PLUGIN + type: NodeType = TRIGGER_PLUGIN_NODE_TYPE class TriggerEventInput(BaseModel): value: Union[Any, list[str]] diff --git a/api/dify_graph/nodes/trigger_plugin/exc.py b/api/core/workflow/nodes/trigger_plugin/exc.py similarity index 100% rename from api/dify_graph/nodes/trigger_plugin/exc.py rename to api/core/workflow/nodes/trigger_plugin/exc.py diff --git a/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py similarity index 84% rename from api/dify_graph/nodes/trigger_plugin/trigger_event_node.py rename to api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index 536ba96dec..2048a53064 100644 --- a/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,8 +1,10 @@ from collections.abc import Mapping +from typing import Any, cast +from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, NodeType +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -10,7 +12,7 @@ from .entities import TriggerEventNodeData class TriggerEventNode(Node[TriggerEventNodeData]): - node_type = NodeType.TRIGGER_PLUGIN + node_type = TRIGGER_PLUGIN_NODE_TYPE execution_type = NodeExecutionType.ROOT @classmethod @@ -44,8 +46,8 @@ class TriggerEventNode(Node[TriggerEventNodeData]): """ # Get trigger data passed when workflow was triggered - metadata = { - WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { + metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { + cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): { "provider_id": self.node_data.provider_id, "event_name": self.node_data.event_name, "plugin_unique_identifier": self.node_data.plugin_unique_identifier, diff --git a/api/core/workflow/nodes/trigger_schedule/__init__.py b/api/core/workflow/nodes/trigger_schedule/__init__.py new file mode 100644 index 0000000000..07b711a0fd --- /dev/null +++ b/api/core/workflow/nodes/trigger_schedule/__init__.py @@ -0,0 +1,3 @@ +from .trigger_schedule_node import TriggerScheduleNode + +__all__ = ["TriggerScheduleNode"] diff --git a/api/dify_graph/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py similarity index 94% rename from api/dify_graph/nodes/trigger_schedule/entities.py rename to api/core/workflow/nodes/trigger_schedule/entities.py index 2b0edcabba..95a2548678 100644 --- a/api/dify_graph/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -2,6 +2,7 @@ from typing import Literal, Union from pydantic import BaseModel, Field +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType @@ -11,7 +12,7 @@ class TriggerScheduleNodeData(BaseNodeData): Trigger Schedule Node Data """ - type: NodeType = NodeType.TRIGGER_SCHEDULE + type: NodeType = TRIGGER_SCHEDULE_NODE_TYPE mode: str = Field(default="visual", description="Schedule mode: visual or cron") frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly") cron_expression: str | None = Field(default=None, description="Cron expression for cron mode") diff --git a/api/dify_graph/nodes/trigger_schedule/exc.py b/api/core/workflow/nodes/trigger_schedule/exc.py similarity index 100% rename from api/dify_graph/nodes/trigger_schedule/exc.py rename to api/core/workflow/nodes/trigger_schedule/exc.py diff --git a/api/dify_graph/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py similarity index 85% rename from api/dify_graph/nodes/trigger_schedule/trigger_schedule_node.py rename to api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index 7e92eb3f4f..b9580e6ab1 100644 --- a/api/dify_graph/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,15 +1,17 @@ from collections.abc import Mapping +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, NodeType +from dify_graph.enums import NodeExecutionType from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node -from dify_graph.nodes.trigger_schedule.entities import TriggerScheduleNodeData + +from .entities import TriggerScheduleNodeData class TriggerScheduleNode(Node[TriggerScheduleNodeData]): - node_type = NodeType.TRIGGER_SCHEDULE + node_type = TRIGGER_SCHEDULE_NODE_TYPE execution_type = NodeExecutionType.ROOT @classmethod @@ -19,7 +21,7 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { - "type": "trigger-schedule", + "type": TRIGGER_SCHEDULE_NODE_TYPE, "config": { "mode": "visual", "frequency": "daily", diff --git a/api/dify_graph/nodes/trigger_webhook/__init__.py b/api/core/workflow/nodes/trigger_webhook/__init__.py similarity index 100% rename from api/dify_graph/nodes/trigger_webhook/__init__.py rename to api/core/workflow/nodes/trigger_webhook/__init__.py diff --git a/api/dify_graph/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py similarity index 97% rename from api/dify_graph/nodes/trigger_webhook/entities.py rename to api/core/workflow/nodes/trigger_webhook/entities.py index a4f8745e71..242bf5ef6a 100644 --- a/api/dify_graph/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -3,6 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, Field, field_validator +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType from dify_graph.variables.types import SegmentType @@ -93,7 +94,7 @@ class WebhookData(BaseNodeData): class SyncMode(StrEnum): SYNC = "async" # only support - type: NodeType = NodeType.TRIGGER_WEBHOOK + type: NodeType = TRIGGER_WEBHOOK_NODE_TYPE method: Method = Method.GET content_type: ContentType = Field(default=ContentType.JSON) headers: Sequence[WebhookParameter] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/trigger_webhook/exc.py b/api/core/workflow/nodes/trigger_webhook/exc.py similarity index 100% rename from api/dify_graph/nodes/trigger_webhook/exc.py rename to api/core/workflow/nodes/trigger_webhook/exc.py diff --git a/api/dify_graph/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py similarity index 97% rename from api/dify_graph/nodes/trigger_webhook/node.py rename to api/core/workflow/nodes/trigger_webhook/node.py index 413eda5272..317844cbda 100644 --- a/api/dify_graph/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,9 +2,10 @@ import logging from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, NodeType +from dify_graph.enums import NodeExecutionType from dify_graph.file import FileTransferMethod from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -19,7 +20,7 @@ logger = logging.getLogger(__name__) class TriggerWebhookNode(Node[WebhookData]): - node_type = NodeType.TRIGGER_WEBHOOK + node_type = TRIGGER_WEBHOOK_NODE_TYPE execution_type = NodeExecutionType.ROOT @classmethod diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 01b309bf54..2e51a06bab 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -8,8 +8,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.node_resolution import resolve_workflow_node_class +from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter @@ -22,7 +21,7 @@ from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLay from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -253,7 +252,7 @@ class WorkflowEntry: variable_mapping=variable_mapping, user_inputs=user_inputs, ) - if node_type != NodeType.DATASOURCE: + if node_type != BuiltinNodeTypes.DATASOURCE: cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -303,7 +302,7 @@ class WorkflowEntry: "height": node_height, "type": "custom", "data": { - "type": NodeType.START, + "type": BuiltinNodeTypes.START, "title": "Start", "desc": "Start", }, @@ -339,8 +338,8 @@ class WorkflowEntry: # Create a minimal graph for single node execution graph_dict = cls._create_single_node_graph(node_id, node_data) - node_type = NodeType(node_data.get("type", "")) - if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: + node_type = node_data.get("type", "") + if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}: raise ValueError(f"Node type {node_type} not supported") node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1") diff --git a/api/dify_graph/README.md b/api/dify_graph/README.md index 09c4f5afdc..2fc5b8b890 100644 --- a/api/dify_graph/README.md +++ b/api/dify_graph/README.md @@ -113,7 +113,7 @@ The codebase enforces strict layering via import-linter: 1. Create node class in `nodes//` 1. Inherit from `BaseNode` or appropriate base class 1. Implement `_run()` method -1. Register in `nodes/node_mapping.py` +1. Ensure the node module is importable under `nodes//` 1. Add tests in `tests/unit_tests/dify_graph/nodes/` ### Implementing a Custom Layer diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py index 58869a94c2..47b37c9daf 100644 --- a/api/dify_graph/entities/base_node_data.py +++ b/api/dify_graph/entities/base_node_data.py @@ -121,6 +121,8 @@ class DefaultValue(BaseModel): class BaseNodeData(ABC, BaseModel): # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. + # `type` therefore accepts downstream string node kinds; unknown node implementations + # are rejected later when the node factory resolves the node registry. # At that boundary, node-specific fields are still "extra" relative to this shared DTO, # and persisted templates/workflows also carry undeclared compatibility keys such as # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive diff --git a/api/dify_graph/entities/workflow_node_execution.py b/api/dify_graph/entities/workflow_node_execution.py index 9dd04e331b..bc7e0d02e5 100644 --- a/api/dify_graph/entities/workflow_node_execution.py +++ b/api/dify_graph/entities/workflow_node_execution.py @@ -48,7 +48,7 @@ class WorkflowNodeExecution(BaseModel): index: int # Sequence number for ordering in trace visualization predecessor_node_id: str | None = None # ID of the node that executed before this one node_id: str # ID of the node being executed - node_type: NodeType # Type of node (e.g., start, llm, knowledge) + node_type: NodeType # Type of node (e.g., start, llm, downstream response node) title: str # Display title of the node # Execution data diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py index bb3b13e8c6..06653bebb6 100644 --- a/api/dify_graph/enums.py +++ b/api/dify_graph/enums.py @@ -1,4 +1,5 @@ from enum import StrEnum +from typing import ClassVar, TypeAlias class NodeState(StrEnum): @@ -33,56 +34,71 @@ class SystemVariableKey(StrEnum): INVOKE_FROM = "invoke_from" -class NodeType(StrEnum): - START = "start" - END = "end" - ANSWER = "answer" - LLM = "llm" - KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" - KNOWLEDGE_INDEX = "knowledge-index" - IF_ELSE = "if-else" - CODE = "code" - TEMPLATE_TRANSFORM = "template-transform" - QUESTION_CLASSIFIER = "question-classifier" - HTTP_REQUEST = "http-request" - TOOL = "tool" - DATASOURCE = "datasource" - VARIABLE_AGGREGATOR = "variable-aggregator" - LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. - LOOP = "loop" - LOOP_START = "loop-start" - LOOP_END = "loop-end" - ITERATION = "iteration" - ITERATION_START = "iteration-start" # Fake start node for iteration. - PARAMETER_EXTRACTOR = "parameter-extractor" - VARIABLE_ASSIGNER = "assigner" - DOCUMENT_EXTRACTOR = "document-extractor" - LIST_OPERATOR = "list-operator" - AGENT = "agent" - TRIGGER_WEBHOOK = "trigger-webhook" - TRIGGER_SCHEDULE = "trigger-schedule" - TRIGGER_PLUGIN = "trigger-plugin" - HUMAN_INPUT = "human-input" +NodeType: TypeAlias = str - @property - def is_trigger_node(self) -> bool: - """Check if this node type is a trigger node.""" - return self in [ - NodeType.TRIGGER_WEBHOOK, - NodeType.TRIGGER_SCHEDULE, - NodeType.TRIGGER_PLUGIN, - ] - @property - def is_start_node(self) -> bool: - """Check if this node type can serve as a workflow entry point.""" - return self in [ - NodeType.START, - NodeType.DATASOURCE, - NodeType.TRIGGER_WEBHOOK, - NodeType.TRIGGER_SCHEDULE, - NodeType.TRIGGER_PLUGIN, - ] +class BuiltinNodeTypes: + """Built-in node type string constants. + + `node_type` values are plain strings throughout the graph runtime. This namespace + only exposes the built-in values shipped by `dify_graph`; downstream packages can + use additional strings without extending this class. + """ + + START: ClassVar[NodeType] = "start" + END: ClassVar[NodeType] = "end" + ANSWER: ClassVar[NodeType] = "answer" + LLM: ClassVar[NodeType] = "llm" + KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval" + IF_ELSE: ClassVar[NodeType] = "if-else" + CODE: ClassVar[NodeType] = "code" + TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform" + QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier" + HTTP_REQUEST: ClassVar[NodeType] = "http-request" + TOOL: ClassVar[NodeType] = "tool" + DATASOURCE: ClassVar[NodeType] = "datasource" + VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator" + LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner" + LOOP: ClassVar[NodeType] = "loop" + LOOP_START: ClassVar[NodeType] = "loop-start" + LOOP_END: ClassVar[NodeType] = "loop-end" + ITERATION: ClassVar[NodeType] = "iteration" + ITERATION_START: ClassVar[NodeType] = "iteration-start" + PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor" + VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner" + DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor" + LIST_OPERATOR: ClassVar[NodeType] = "list-operator" + AGENT: ClassVar[NodeType] = "agent" + HUMAN_INPUT: ClassVar[NodeType] = "human-input" + + +BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = ( + BuiltinNodeTypes.START, + BuiltinNodeTypes.END, + BuiltinNodeTypes.ANSWER, + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + BuiltinNodeTypes.IF_ELSE, + BuiltinNodeTypes.CODE, + BuiltinNodeTypes.TEMPLATE_TRANSFORM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.HTTP_REQUEST, + BuiltinNodeTypes.TOOL, + BuiltinNodeTypes.DATASOURCE, + BuiltinNodeTypes.VARIABLE_AGGREGATOR, + BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR, + BuiltinNodeTypes.LOOP, + BuiltinNodeTypes.LOOP_START, + BuiltinNodeTypes.LOOP_END, + BuiltinNodeTypes.ITERATION, + BuiltinNodeTypes.ITERATION_START, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + BuiltinNodeTypes.VARIABLE_ASSIGNER, + BuiltinNodeTypes.DOCUMENT_EXTRACTOR, + BuiltinNodeTypes.LIST_OPERATOR, + BuiltinNodeTypes.AGENT, + BuiltinNodeTypes.HUMAN_INPUT, +) class NodeExecutionType(StrEnum): @@ -236,7 +252,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): CURRENCY = "currency" TOOL_INFO = "tool_info" AGENT_LOG = "agent_log" - TRIGGER_INFO = "trigger_info" ITERATION_ID = "iteration_id" ITERATION_INDEX = "iteration_index" LOOP_ID = "loop_id" diff --git a/api/dify_graph/graph/graph.py b/api/dify_graph/graph/graph.py index 3eb6bfc359..85117583e0 100644 --- a/api/dify_graph/graph/graph.py +++ b/api/dify_graph/graph/graph.py @@ -83,50 +83,6 @@ class Graph: return node_configs_map - @classmethod - def _find_root_node_id( - cls, - node_configs_map: Mapping[str, NodeConfigDict], - edge_configs: Sequence[Mapping[str, object]], - root_node_id: str | None = None, - ) -> str: - """ - Find the root node ID if not specified. - - :param node_configs_map: mapping of node ID to node config - :param edge_configs: list of edge configurations - :param root_node_id: explicitly specified root node ID - :return: determined root node ID - """ - if root_node_id: - if root_node_id not in node_configs_map: - raise ValueError(f"Root node id {root_node_id} not found in the graph") - return root_node_id - - # Find nodes with no incoming edges - nodes_with_incoming: set[str] = set() - for edge_config in edge_configs: - target = edge_config.get("target") - if isinstance(target, str): - nodes_with_incoming.add(target) - - root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming] - - # Prefer START node if available - start_node_id = None - for nid in root_candidates: - node_data = node_configs_map[nid]["data"] - if node_data.type.is_start_node: - start_node_id = nid - break - - root_node_id = start_node_id or (root_candidates[0] if root_candidates else None) - - if not root_node_id: - raise ValueError("Unable to determine root node ID") - - return root_node_id - @classmethod def _build_edges( cls, edge_configs: list[dict[str, object]] @@ -301,15 +257,15 @@ class Graph: *, graph_config: Mapping[str, object], node_factory: NodeFactory, - root_node_id: str | None = None, + root_node_id: str, skip_validation: bool = False, ) -> Graph: """ - Initialize graph + Initialize a graph with an explicit execution entry point. :param graph_config: graph config containing nodes and edges :param node_factory: factory for creating node instances from config data - :param root_node_id: root node id + :param root_node_id: active root node id :return: graph instance """ # Parse configs @@ -327,8 +283,8 @@ class Graph: # Parse node configurations node_configs_map = cls._parse_node_configs(node_configs) - # Find root node - root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id) + if root_node_id not in node_configs_map: + raise ValueError(f"Root node id {root_node_id} not found in the graph") # Build edges edges, in_edges, out_edges = cls._build_edges(edge_configs) diff --git a/api/dify_graph/graph/validation.py b/api/dify_graph/graph/validation.py index 6840bcfed2..50d1440b04 100644 --- a/api/dify_graph/graph/validation.py +++ b/api/dify_graph/graph/validation.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol -from dify_graph.enums import NodeExecutionType, NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType if TYPE_CHECKING: from .graph import Graph @@ -71,7 +71,7 @@ class _RootNodeValidator: """Validates root node invariants.""" invalid_root_code: str = "INVALID_ROOT" - container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START) + container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START) def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: root_node = graph.root_node @@ -86,7 +86,7 @@ class _RootNodeValidator: ) return issues - node_type = getattr(root_node, "node_type", None) + node_type = root_node.node_type if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: issues.append( GraphValidationIssue( @@ -114,45 +114,9 @@ class GraphValidator: raise GraphValidationError(issues) -@dataclass(frozen=True, slots=True) -class _TriggerStartExclusivityValidator: - """Ensures trigger nodes do not coexist with UserInput (start) nodes.""" - - conflict_code: str = "TRIGGER_START_NODE_CONFLICT" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - start_node_id: str | None = None - trigger_node_ids: list[str] = [] - - for node in graph.nodes.values(): - node_type = getattr(node, "node_type", None) - if not isinstance(node_type, NodeType): - continue - - if node_type == NodeType.START: - start_node_id = node.id - elif node_type.is_trigger_node: - trigger_node_ids.append(node.id) - - if start_node_id and trigger_node_ids: - trigger_list = ", ".join(trigger_node_ids) - return [ - GraphValidationIssue( - code=self.conflict_code, - message=( - f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}." - ), - node_id=start_node_id, - ) - ] - - return [] - - _DEFAULT_RULES: tuple[GraphValidationRule, ...] = ( _EdgeEndpointValidator(), _RootNodeValidator(), - _TriggerStartExclusivityValidator(), ) diff --git a/api/dify_graph/graph_engine/response_coordinator/__init__.py b/api/dify_graph/graph_engine/response_coordinator/__init__.py index e11d31199c..2a80d316e8 100644 --- a/api/dify_graph/graph_engine/response_coordinator/__init__.py +++ b/api/dify_graph/graph_engine/response_coordinator/__init__.py @@ -6,5 +6,6 @@ of responses based on upstream node outputs and constants. """ from .coordinator import ResponseStreamCoordinator +from .session import RESPONSE_SESSION_NODE_TYPES -__all__ = ["ResponseStreamCoordinator"] +__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"] diff --git a/api/dify_graph/graph_engine/response_coordinator/session.py b/api/dify_graph/graph_engine/response_coordinator/session.py index 0548e88d93..99ac1b5edf 100644 --- a/api/dify_graph/graph_engine/response_coordinator/session.py +++ b/api/dify_graph/graph_engine/response_coordinator/session.py @@ -3,19 +3,34 @@ Internal response session management for response coordinator. This module contains the private ResponseSession class used internally by ResponseStreamCoordinator to manage streaming sessions. + +`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications +can opt additional response-capable node types into session creation without +patching the coordinator. """ from __future__ import annotations from dataclasses import dataclass +from typing import Protocol, cast -from dify_graph.nodes.answer.answer_node import AnswerNode +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.template import Template -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.knowledge_index import KnowledgeIndexNode from dify_graph.runtime.graph_runtime_state import NodeProtocol +class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): + """Structural contract required from nodes that can open a response session.""" + + def get_streaming_template(self) -> Template: ... + + +RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [ + BuiltinNodeTypes.ANSWER, + BuiltinNodeTypes.END, +] + + @dataclass class ResponseSession: """ @@ -33,10 +48,9 @@ class ResponseSession: """ Create a ResponseSession from a response-capable node. - The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer, - but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides: - - `id: str` - - `get_streaming_template() -> Template` + The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer. + At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES` + and which implements `get_streaming_template()`. Args: node: Node from the materialized workflow graph. @@ -47,11 +61,22 @@ class ResponseSession: Raises: TypeError: If node is not a supported response node type. """ - if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode): - raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode") + if node.node_type not in RESPONSE_SESSION_NODE_TYPES: + supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES) + raise TypeError( + "ResponseSession.from_node only supports node types in " + f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}" + ) + + response_node = cast(_ResponseSessionNodeProtocol, node) + try: + template = response_node.get_streaming_template() + except AttributeError as exc: + raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc + return cls( node_id=node.id, - template=node.get_streaming_template(), + template=template, ) def is_complete(self) -> bool: diff --git a/api/dify_graph/node_events/node.py b/api/dify_graph/node_events/node.py index 481e793267..2e3973b8fa 100644 --- a/api/dify_graph/node_events/node.py +++ b/api/dify_graph/node_events/node.py @@ -1,9 +1,9 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from datetime import datetime +from typing import Any from pydantic import Field -from core.rag.entities.citation_metadata import RetrievalSourceMetadata from dify_graph.entities.pause_reason import PauseReason from dify_graph.file import File from dify_graph.model_runtime.entities.llm_entities import LLMUsage @@ -13,7 +13,7 @@ from .base import NodeEventBase class RunRetrieverResourceEvent(NodeEventBase): - retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") + retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") context: str = Field(..., description="context") context_files: list[File] | None = Field(default=None, description="context files") diff --git a/api/dify_graph/nodes/__init__.py b/api/dify_graph/nodes/__init__.py index d113ad5e70..0223149bb8 100644 --- a/api/dify_graph/nodes/__init__.py +++ b/api/dify_graph/nodes/__init__.py @@ -1,3 +1,3 @@ -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes -__all__ = ["NodeType"] +__all__ = ["BuiltinNodeTypes"] diff --git a/api/dify_graph/nodes/answer/answer_node.py b/api/dify_graph/nodes/answer/answer_node.py index c829b892cc..4286e1a492 100644 --- a/api/dify_graph/nodes/answer/answer_node.py +++ b/api/dify_graph/nodes/answer/answer_node.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from typing import Any -from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.answer.entities import AnswerNodeData from dify_graph.nodes.base.node import Node @@ -11,7 +11,7 @@ from dify_graph.variables import ArrayFileSegment, FileSegment, Segment class AnswerNode(Node[AnswerNodeData]): - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER execution_type = NodeExecutionType.RESPONSE @classmethod diff --git a/api/dify_graph/nodes/answer/entities.py b/api/dify_graph/nodes/answer/entities.py index 3cc1d6572e..cd82df1ac4 100644 --- a/api/dify_graph/nodes/answer/entities.py +++ b/api/dify_graph/nodes/answer/entities.py @@ -4,7 +4,7 @@ from enum import StrEnum, auto from pydantic import BaseModel, Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType class AnswerNodeData(BaseNodeData): @@ -12,7 +12,7 @@ class AnswerNodeData(BaseNodeData): Answer Node Data. """ - type: NodeType = NodeType.ANSWER + type: NodeType = BuiltinNodeTypes.ANSWER answer: str = Field(..., description="answer template string") diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index 2044b09333..c6f54ce672 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -1,9 +1,7 @@ from __future__ import annotations -import importlib import logging import operator -import pkgutil from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod @@ -161,7 +159,7 @@ class Node(Generic[NodeDataT]): Example: class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE # No need to implement _get_title, _get_error_strategy, etc. """ super().__init_subclass__(**kwargs) @@ -179,7 +177,8 @@ class Node(Generic[NodeDataT]): # Skip base class itself if cls is Node: return - # Only register production node implementations defined under dify_graph.nodes.* + # Only register production node implementations defined under the + # canonical workflow namespaces. # This prevents test helper subclasses from polluting the global registry and # accidentally overriding real node types (e.g., a test Answer node). module_name = getattr(cls, "__module__", "") @@ -187,7 +186,7 @@ class Node(Generic[NodeDataT]): node_type = cls.node_type version = cls.version() bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith("dify_graph.nodes."): + if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): # Production node definitions take precedence and may override bucket[version] = cls # type: ignore[index] else: @@ -203,6 +202,7 @@ class Node(Generic[NodeDataT]): else: latest_key = max(version_keys) if version_keys else version bucket["latest"] = bucket[latest_key] + Node._registry_version += 1 @classmethod def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: @@ -237,6 +237,11 @@ class Node(Generic[NodeDataT]): # Global registry populated via __init_subclass__ _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} + _registry_version: ClassVar[int] = 0 + + @classmethod + def get_registry_version(cls) -> int: + return cls._registry_version def __init__( self, @@ -269,6 +274,10 @@ class Node(Generic[NodeDataT]): """Validate shared graph node payloads against the subclass-declared NodeData model.""" return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) + def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: + """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" + self._node_data = self.validate_node_data(cast(BaseNodeData, data)) + def post_init(self) -> None: """Optional hook for subclasses requiring extra initialization.""" return @@ -489,29 +498,19 @@ class Node(Generic[NodeDataT]): def version(cls) -> str: """`node_version` returns the version of current node type.""" # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so - # `Node.get_node_type_classes_mapping()` can resolve numeric versions and `latest`. + # registry lookups can resolve numeric versions and `latest`. raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. + """Return a read-only view of the currently registered node classes. - Import all modules under dify_graph.nodes so subclasses register themselves on import. - Callers that rely on workflow-local nodes defined outside `dify_graph.nodes` must import - those modules before invoking this method so they can register through `__init_subclass__`. - We then return a readonly view of the registry to avoid accidental mutation. + This accessor intentionally performs no imports. The embedding layer that + owns bootstrap (for example `core.workflow.node_factory`) must import any + extension node packages before calling it so their subclasses register via + `__init_subclass__`. """ - # Import all node modules to ensure they are loaded (thus registered) - import dify_graph.nodes as _nodes_pkg - - for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): - # Avoid importing modules that depend on the registry to prevent circular imports. - if _modname == "dify_graph.nodes.node_mapping": - continue - importlib.import_module(_modname) - - # Return a readonly view so callers can't mutate the registry by accident - return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()} + return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()} @property def retry(self) -> bool: @@ -786,11 +785,16 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: + from core.rag.entities.citation_metadata import RetrievalSourceMetadata + + retriever_resources = [ + RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources + ] return NodeRunRetrieverResourceEvent( id=self.execution_id, node_id=self._node_id, node_type=self.node_type, - retriever_resources=event.retriever_resources, + retriever_resources=retriever_resources, context=event.context, node_version=self.version(), ) diff --git a/api/dify_graph/nodes/code/code_node.py b/api/dify_graph/nodes/code/code_node.py index ac8d6463b9..82d5fced62 100644 --- a/api/dify_graph/nodes/code/code_node.py +++ b/api/dify_graph/nodes/code/code_node.py @@ -4,7 +4,7 @@ from textwrap import dedent from typing import TYPE_CHECKING, Any, Protocol, cast from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData @@ -72,7 +72,7 @@ _DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { class CodeNode(Node[CodeNodeData]): - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE _limits: CodeNodeLimits def __init__( diff --git a/api/dify_graph/nodes/code/entities.py b/api/dify_graph/nodes/code/entities.py index 25e46226e1..55b4ee4862 100644 --- a/api/dify_graph/nodes/code/entities.py +++ b/api/dify_graph/nodes/code/entities.py @@ -4,7 +4,7 @@ from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.entities import VariableSelector from dify_graph.variables.types import SegmentType @@ -40,7 +40,7 @@ class CodeNodeData(BaseNodeData): Code Node Data. """ - type: NodeType = NodeType.CODE + type: NodeType = BuiltinNodeTypes.CODE class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] diff --git a/api/dify_graph/nodes/datasource/__init__.py b/api/dify_graph/nodes/datasource/__init__.py deleted file mode 100644 index f6ec44cb77..0000000000 --- a/api/dify_graph/nodes/datasource/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .datasource_node import DatasourceNode - -__all__ = ["DatasourceNode"] diff --git a/api/dify_graph/nodes/document_extractor/entities.py b/api/dify_graph/nodes/document_extractor/entities.py index 9f42d2e605..1110cc2710 100644 --- a/api/dify_graph/nodes/document_extractor/entities.py +++ b/api/dify_graph/nodes/document_extractor/entities.py @@ -2,11 +2,11 @@ from collections.abc import Sequence from dataclasses import dataclass from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType class DocumentExtractorNodeData(BaseNodeData): - type: NodeType = NodeType.DOCUMENT_EXTRACTOR + type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py index fe51b1963e..27196f1aca 100644 --- a/api/dify_graph/nodes/document_extractor/node.py +++ b/api/dify_graph/nodes/document_extractor/node.py @@ -22,7 +22,7 @@ from docx.table import Table from docx.text.paragraph import Paragraph from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, file_manager from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -46,7 +46,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): Supports plain text, PDF, and DOC/DOCX files. """ - node_type = NodeType.DOCUMENT_EXTRACTOR + node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR @classmethod def version(cls) -> str: diff --git a/api/dify_graph/nodes/end/end_node.py b/api/dify_graph/nodes/end/end_node.py index 7aa526b85b..1f5cfab22b 100644 --- a/api/dify_graph/nodes/end/end_node.py +++ b/api/dify_graph/nodes/end/end_node.py @@ -1,4 +1,4 @@ -from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.template import Template @@ -6,7 +6,7 @@ from dify_graph.nodes.end.entities import EndNodeData class EndNode(Node[EndNodeData]): - node_type = NodeType.END + node_type = BuiltinNodeTypes.END execution_type = NodeExecutionType.RESPONSE @classmethod diff --git a/api/dify_graph/nodes/end/entities.py b/api/dify_graph/nodes/end/entities.py index 69cd1dd8f5..be7f0c8de8 100644 --- a/api/dify_graph/nodes/end/entities.py +++ b/api/dify_graph/nodes/end/entities.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.entities import OutputVariableEntity @@ -10,7 +10,7 @@ class EndNodeData(BaseNodeData): END Node Data. """ - type: NodeType = NodeType.END + type: NodeType = BuiltinNodeTypes.END outputs: list[OutputVariableEntity] diff --git a/api/dify_graph/nodes/http_request/entities.py b/api/dify_graph/nodes/http_request/entities.py index 46e08ea1a0..f594d58ae6 100644 --- a/api/dify_graph/nodes/http_request/entities.py +++ b/api/dify_graph/nodes/http_request/entities.py @@ -9,7 +9,7 @@ import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" @@ -90,7 +90,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. """ - type: NodeType = NodeType.HTTP_REQUEST + type: NodeType = BuiltinNodeTypes.HTTP_REQUEST method: Literal[ "get", "post", diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index 3895ae92c0..b17c820a80 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base import variable_template_parser @@ -33,7 +33,7 @@ if TYPE_CHECKING: class HttpRequestNode(Node[HttpRequestNodeData]): - node_type = NodeType.HTTP_REQUEST + node_type = BuiltinNodeTypes.HTTP_REQUEST def __init__( self, diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py index 642c2143e5..7936e47213 100644 --- a/api/dify_graph/nodes/human_input/entities.py +++ b/api/dify_graph/nodes/human_input/entities.py @@ -11,7 +11,7 @@ from typing import Annotated, Any, ClassVar, Literal, Self from pydantic import BaseModel, Field, field_validator, model_validator from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.runtime import VariablePool from dify_graph.variables.consts import SELECTORS_LENGTH @@ -215,7 +215,7 @@ class UserAction(BaseModel): class HumanInputNodeData(BaseNodeData): """Human Input node data.""" - type: NodeType = NodeType.HUMAN_INPUT + type: NodeType = BuiltinNodeTypes.HUMAN_INPUT delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) form_content: str = "" inputs: list[FormInput] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index 3a167d122b..794e33d92e 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus from dify_graph.node_events import ( HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) class HumanInputNode(Node[HumanInputNodeData]): - node_type = NodeType.HUMAN_INPUT + node_type = BuiltinNodeTypes.HUMAN_INPUT execution_type = NodeExecutionType.BRANCH _BRANCH_SELECTION_KEYS: tuple[str, ...] = ( diff --git a/api/dify_graph/nodes/if_else/entities.py b/api/dify_graph/nodes/if_else/entities.py index c9bb1cdc7f..ff09f3c023 100644 --- a/api/dify_graph/nodes/if_else/entities.py +++ b/api/dify_graph/nodes/if_else/entities.py @@ -3,7 +3,7 @@ from typing import Literal from pydantic import BaseModel, Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.utils.condition.entities import Condition @@ -12,7 +12,7 @@ class IfElseNodeData(BaseNodeData): If Else Node Data. """ - type: NodeType = NodeType.IF_ELSE + type: NodeType = BuiltinNodeTypes.IF_ELSE class Case(BaseModel): """ diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py index 4b6d30c279..7c0370e48c 100644 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ b/api/dify_graph/nodes/if_else/if_else_node.py @@ -3,7 +3,7 @@ from typing import Any, Literal from typing_extensions import deprecated -from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.if_else.entities import IfElseNodeData @@ -13,7 +13,7 @@ from dify_graph.utils.condition.processor import ConditionProcessor class IfElseNode(Node[IfElseNodeData]): - node_type = NodeType.IF_ELSE + node_type = BuiltinNodeTypes.IF_ELSE execution_type = NodeExecutionType.BRANCH @classmethod diff --git a/api/dify_graph/nodes/iteration/entities.py b/api/dify_graph/nodes/iteration/entities.py index 6d61c12352..58fd112b12 100644 --- a/api/dify_graph/nodes/iteration/entities.py +++ b/api/dify_graph/nodes/iteration/entities.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState @@ -19,7 +19,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. """ - type: NodeType = NodeType.ITERATION + type: NodeType = BuiltinNodeTypes.ITERATION parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector @@ -34,7 +34,7 @@ class IterationStartNodeData(BaseNodeData): Iteration Start Node Data. """ - type: NodeType = NodeType.ITERATION_START + type: NodeType = BuiltinNodeTypes.ITERATION_START class IterationState(BaseIterationState): diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 1d626f4bd6..f63ba0bc48 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -9,8 +9,8 @@ from typing_extensions import TypeIs from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( + BuiltinNodeTypes, NodeExecutionType, - NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -62,7 +62,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): Iteration Node. """ - node_type = NodeType.ITERATION + node_type = BuiltinNodeTypes.ITERATION execution_type = NodeExecutionType.CONTAINER @classmethod @@ -485,12 +485,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # variable selector to variable mapping try: - # Get node class - from dify_graph.nodes.node_mapping import get_node_type_classes_mapping - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) node_type = typed_sub_node_config["data"].type - node_mapping = get_node_type_classes_mapping() + node_mapping = Node.get_node_type_classes_mapping() if node_type not in node_mapping: continue node_version = str(typed_sub_node_config["data"].version) @@ -563,7 +560,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found") current_index = index_variable.value for event in rst: - if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START: + if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START: continue if isinstance(event, GraphNodeEventBase): diff --git a/api/dify_graph/nodes/iteration/iteration_start_node.py b/api/dify_graph/nodes/iteration/iteration_start_node.py index 2e1f555ed2..a8ecf3d83b 100644 --- a/api/dify_graph/nodes/iteration/iteration_start_node.py +++ b/api/dify_graph/nodes/iteration/iteration_start_node.py @@ -1,4 +1,4 @@ -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.iteration.entities import IterationStartNodeData @@ -9,7 +9,7 @@ class IterationStartNode(Node[IterationStartNodeData]): Iteration Start Node. """ - node_type = NodeType.ITERATION_START + node_type = BuiltinNodeTypes.ITERATION_START @classmethod def version(cls) -> str: diff --git a/api/dify_graph/nodes/knowledge_index/__init__.py b/api/dify_graph/nodes/knowledge_index/__init__.py deleted file mode 100644 index 23897a1e42..0000000000 --- a/api/dify_graph/nodes/knowledge_index/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .knowledge_index_node import KnowledgeIndexNode - -__all__ = ["KnowledgeIndexNode"] diff --git a/api/dify_graph/nodes/knowledge_retrieval/__init__.py b/api/dify_graph/nodes/knowledge_retrieval/__init__.py deleted file mode 100644 index 4d4a4cbd9f..0000000000 --- a/api/dify_graph/nodes/knowledge_retrieval/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .knowledge_retrieval_node import KnowledgeRetrievalNode - -__all__ = ["KnowledgeRetrievalNode"] diff --git a/api/dify_graph/nodes/list_operator/entities.py b/api/dify_graph/nodes/list_operator/entities.py index a91cfab8de..41b3a40b78 100644 --- a/api/dify_graph/nodes/list_operator/entities.py +++ b/api/dify_graph/nodes/list_operator/entities.py @@ -4,7 +4,7 @@ from enum import StrEnum from pydantic import BaseModel, Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType class FilterOperator(StrEnum): @@ -63,7 +63,7 @@ class ExtractConfig(BaseModel): class ListOperatorNodeData(BaseNodeData): - type: NodeType = NodeType.LIST_OPERATOR + type: NodeType = BuiltinNodeTypes.LIST_OPERATOR variable: Sequence[str] = Field(default_factory=list) filter_by: FilterBy order_by: OrderByConfig diff --git a/api/dify_graph/nodes/list_operator/node.py b/api/dify_graph/nodes/list_operator/node.py index d2fdadc29c..dc8b8904f7 100644 --- a/api/dify_graph/nodes/list_operator/node.py +++ b/api/dify_graph/nodes/list_operator/node.py @@ -1,7 +1,7 @@ from collections.abc import Callable, Sequence from typing import Any, TypeAlias, TypeVar -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.file import File from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -35,7 +35,7 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: class ListOperatorNode(Node[ListOperatorNodeData]): - node_type = NodeType.LIST_OPERATOR + node_type = BuiltinNodeTypes.LIST_OPERATOR @classmethod def version(cls) -> str: diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py index 71728aa227..6ca01a21da 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode from dify_graph.nodes.base.entities import VariableSelector @@ -60,7 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): - type: NodeType = NodeType.LLM + type: NodeType = BuiltinNodeTypes.LLM model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index b88ff404c0..c3529867b7 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -17,12 +17,12 @@ from core.llm_generator.output_parser.structured_output import invoke_llm_with_s from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.signature import sign_upload_file from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( + BuiltinNodeTypes, NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, @@ -104,7 +104,7 @@ logger = logging.getLogger(__name__) class LLMNode(Node[LLMNodeData]): - node_type = NodeType.LLM + node_type = BuiltinNodeTypes.LLM # Compiled regex for extracting blocks (with compatibility for attributes) _THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) @@ -677,7 +677,7 @@ class LLMNode(Node[LLMNodeData]): ) elif isinstance(context_value_variable, ArraySegment): context_str = "" - original_retriever_resource: list[RetrievalSourceMetadata] = [] + original_retriever_resource: list[dict[str, Any]] = [] context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): @@ -693,11 +693,14 @@ class LLMNode(Node[LLMNodeData]): retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) + segment_id = retriever_resource.get("segment_id") + if not segment_id: + continue attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) .where( - SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, + SegmentAttachmentBinding.segment_id == segment_id, ) ).all() if attachments_with_bindings: @@ -723,7 +726,7 @@ class LLMNode(Node[LLMNodeData]): context_files=context_files, ) - def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: + def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None: if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -731,28 +734,26 @@ class LLMNode(Node[LLMNodeData]): ): metadata = context_dict.get("metadata", {}) - source = RetrievalSourceMetadata( - position=metadata.get("position"), - dataset_id=metadata.get("dataset_id"), - dataset_name=metadata.get("dataset_name"), - document_id=metadata.get("document_id"), - document_name=metadata.get("document_name"), - data_source_type=metadata.get("data_source_type"), - segment_id=metadata.get("segment_id"), - retriever_from=metadata.get("retriever_from"), - score=metadata.get("score"), - hit_count=metadata.get("segment_hit_count"), - word_count=metadata.get("segment_word_count"), - segment_position=metadata.get("segment_position"), - index_node_hash=metadata.get("segment_index_node_hash"), - content=context_dict.get("content"), - page=metadata.get("page"), - doc_metadata=metadata.get("doc_metadata"), - files=context_dict.get("files"), - summary=context_dict.get("summary"), - ) - - return source + return { + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name": metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), + "page": metadata.get("page"), + "doc_metadata": metadata.get("doc_metadata"), + "files": context_dict.get("files"), + "summary": context_dict.get("summary"), + } return None diff --git a/api/dify_graph/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py index 8a3df5c234..f0bfad5a0f 100644 --- a/api/dify_graph/nodes/loop/entities.py +++ b/api/dify_graph/nodes/loop/entities.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState from dify_graph.utils.condition.entities import Condition from dify_graph.variables.types import SegmentType @@ -41,7 +41,7 @@ class LoopVariableData(BaseModel): class LoopNodeData(BaseLoopNodeData): - type: NodeType = NodeType.LOOP + type: NodeType = BuiltinNodeTypes.LOOP loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] @@ -61,7 +61,7 @@ class LoopStartNodeData(BaseNodeData): Loop Start Node Data. """ - type: NodeType = NodeType.LOOP_START + type: NodeType = BuiltinNodeTypes.LOOP_START class LoopEndNodeData(BaseNodeData): @@ -69,7 +69,7 @@ class LoopEndNodeData(BaseNodeData): Loop End Node Data. """ - type: NodeType = NodeType.LOOP_END + type: NodeType = BuiltinNodeTypes.LOOP_END class LoopState(BaseLoopState): diff --git a/api/dify_graph/nodes/loop/loop_end_node.py b/api/dify_graph/nodes/loop/loop_end_node.py index 73ac5da927..0287708fb3 100644 --- a/api/dify_graph/nodes/loop/loop_end_node.py +++ b/api/dify_graph/nodes/loop/loop_end_node.py @@ -1,4 +1,4 @@ -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.loop.entities import LoopEndNodeData @@ -9,7 +9,7 @@ class LoopEndNode(Node[LoopEndNodeData]): Loop End Node. """ - node_type = NodeType.LOOP_END + node_type = BuiltinNodeTypes.LOOP_END @classmethod def version(cls) -> str: diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 1a8774f445..3c546ffa23 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Literal, cast from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( + BuiltinNodeTypes, NodeExecutionType, - NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -46,7 +46,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): Loop Node. """ - node_type = NodeType.LOOP + node_type = BuiltinNodeTypes.LOOP execution_type = NodeExecutionType.CONTAINER @classmethod @@ -250,11 +250,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): if isinstance(event, GraphNodeEventBase): self._append_loop_info_to_event(event=event, loop_run_index=current_index) - if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START: + if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START: continue if isinstance(event, GraphNodeEventBase): yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END: + if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: reach_break_node = True if isinstance(event, GraphRunFailedEvent): raise Exception(event.error) @@ -315,12 +315,9 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # variable selector to variable mapping try: - # Get node class - from dify_graph.nodes.node_mapping import get_node_type_classes_mapping - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) node_type = typed_sub_node_config["data"].type - node_mapping = get_node_type_classes_mapping() + node_mapping = Node.get_node_type_classes_mapping() if node_type not in node_mapping: continue node_version = str(typed_sub_node_config["data"].version) diff --git a/api/dify_graph/nodes/loop/loop_start_node.py b/api/dify_graph/nodes/loop/loop_start_node.py index f469c8286e..e171b4df2f 100644 --- a/api/dify_graph/nodes/loop/loop_start_node.py +++ b/api/dify_graph/nodes/loop/loop_start_node.py @@ -1,4 +1,4 @@ -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.loop.entities import LoopStartNodeData @@ -9,7 +9,7 @@ class LoopStartNode(Node[LoopStartNodeData]): Loop Start Node. """ - node_type = NodeType.LOOP_START + node_type = BuiltinNodeTypes.LOOP_START @classmethod def version(cls) -> str: diff --git a/api/dify_graph/nodes/node_mapping.py b/api/dify_graph/nodes/node_mapping.py deleted file mode 100644 index e0f5524a04..0000000000 --- a/api/dify_graph/nodes/node_mapping.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Mapping - -from dify_graph.enums import NodeType -from dify_graph.nodes.base.node import Node - -LATEST_VERSION = "latest" - - -def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return the live node registry after importing all `dify_graph.nodes` modules.""" - return Node.get_node_type_classes_mapping() - - -def resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: - node_mapping = get_node_type_classes_mapping().get(node_type) - if not node_mapping: - raise ValueError(f"No class mapping found for node type: {node_type}") - - latest_node_class = node_mapping.get(LATEST_VERSION) - matched_node_class = node_mapping.get(node_version) - node_class = matched_node_class or latest_node_class - if not node_class: - raise ValueError(f"No latest version class found for node type: {node_type}") - return node_class - - -# Snapshot kept for compatibility with older tests; production paths should use the live helpers. -NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping() diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py index 8f8a278d5b..2fb042c16c 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -9,7 +9,7 @@ from pydantic import ( from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig from dify_graph.variables.types import SegmentType @@ -84,7 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData): Parameter Extractor Node Data. """ - type: NodeType = NodeType.PARAMETER_EXTRACTOR + type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR model: ModelConfig query: list[str] parameters: list[ParameterConfig] diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 68bd15db30..3913a27697 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -12,7 +12,7 @@ from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( - NodeType, + BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -97,7 +97,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): Parameter Extractor Node. """ - node_type = NodeType.PARAMETER_EXTRACTOR + node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR _model_instance: ModelInstance _credentials_provider: "CredentialsProvider" diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py index 77a6c70c28..0c1601d439 100644 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ b/api/dify_graph/nodes/question_classifier/entities.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.llm import ModelConfig, VisionConfig @@ -12,7 +12,7 @@ class ClassConfig(BaseModel): class QuestionClassifierNodeData(BaseNodeData): - type: NodeType = NodeType.QUESTION_CLASSIFIER + type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER query_variable_selector: list[str] model: ModelConfig classes: list[ClassConfig] diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index a61bca4ea9..84e77a460c 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -9,8 +9,8 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( + BuiltinNodeTypes, NodeExecutionType, - NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -50,7 +50,7 @@ if TYPE_CHECKING: class QuestionClassifierNode(Node[QuestionClassifierNodeData]): - node_type = NodeType.QUESTION_CLASSIFIER + node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER execution_type = NodeExecutionType.BRANCH _file_outputs: list["File"] diff --git a/api/dify_graph/nodes/start/entities.py b/api/dify_graph/nodes/start/entities.py index cbf7348360..92ebd1a2ec 100644 --- a/api/dify_graph/nodes/start/entities.py +++ b/api/dify_graph/nodes/start/entities.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from pydantic import Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.variables.input_entities import VariableEntity @@ -12,5 +12,5 @@ class StartNodeData(BaseNodeData): Start Node Data """ - type: NodeType = NodeType.START + type: NodeType = BuiltinNodeTypes.START variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/start/start_node.py b/api/dify_graph/nodes/start/start_node.py index c09ead0124..5e6055ea34 100644 --- a/api/dify_graph/nodes/start/start_node.py +++ b/api/dify_graph/nodes/start/start_node.py @@ -3,7 +3,7 @@ from typing import Any from jsonschema import Draft7Validator, ValidationError from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.start.entities import StartNodeData @@ -11,7 +11,7 @@ from dify_graph.variables.input_entities import VariableEntityType class StartNode(Node[StartNodeData]): - node_type = NodeType.START + node_type = BuiltinNodeTypes.START execution_type = NodeExecutionType.ROOT @classmethod diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/dify_graph/nodes/template_transform/entities.py index 2a79a82870..ac29239958 100644 --- a/api/dify_graph/nodes/template_transform/entities.py +++ b/api/dify_graph/nodes/template_transform/entities.py @@ -1,5 +1,5 @@ from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.entities import VariableSelector @@ -8,6 +8,6 @@ class TemplateTransformNodeData(BaseNodeData): Template Transform Node Data. """ - type: NodeType = NodeType.TEMPLATE_TRANSFORM + type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM variables: list[VariableSelector] template: str diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py index 9dfb535342..dc6fce2b0a 100644 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ b/api/dify_graph/nodes/template_transform/template_transform_node.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData @@ -19,7 +19,7 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 class TemplateTransformNode(Node[TemplateTransformNodeData]): - node_type = NodeType.TEMPLATE_TRANSFORM + node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM _template_renderer: Jinja2TemplateRenderer _max_output_length: int diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py index 4ba8c16e85..b041ee66fd 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -5,7 +5,7 @@ from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType class ToolEntity(BaseModel): @@ -33,7 +33,7 @@ class ToolEntity(BaseModel): class ToolNodeData(BaseNodeData, ToolEntity): - type: NodeType = NodeType.TOOL + type: NodeType = BuiltinNodeTypes.TOOL class ToolInput(BaseModel): # TODO: check this type diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index ec7386981e..598f0da92e 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -9,7 +9,7 @@ from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( - NodeType, + BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, @@ -42,7 +42,7 @@ class ToolNode(Node[ToolNodeData]): Tool Node """ - node_type = NodeType.TOOL + node_type = BuiltinNodeTypes.TOOL def __init__( self, diff --git a/api/dify_graph/nodes/trigger_schedule/__init__.py b/api/dify_graph/nodes/trigger_schedule/__init__.py deleted file mode 100644 index c9b3ae6a0d..0000000000 --- a/api/dify_graph/nodes/trigger_schedule/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dify_graph.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode - -__all__ = ["TriggerScheduleNode"] diff --git a/api/dify_graph/nodes/variable_aggregator/entities.py b/api/dify_graph/nodes/variable_aggregator/entities.py index fec4c4474c..4779ebd9a9 100644 --- a/api/dify_graph/nodes/variable_aggregator/entities.py +++ b/api/dify_graph/nodes/variable_aggregator/entities.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.variables.types import SegmentType @@ -29,7 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData): Variable Aggregator Node Data. """ - type: NodeType = NodeType.VARIABLE_AGGREGATOR + type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR output_type: str variables: list[list[str]] advanced_settings: AdvancedSettings | None = None diff --git a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py index 98ab8105fe..7d26de6232 100644 --- a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,6 +1,6 @@ from collections.abc import Mapping -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData @@ -8,7 +8,7 @@ from dify_graph.variables.segments import Segment class VariableAggregatorNode(Node[VariableAggregatorNodeData]): - node_type = NodeType.VARIABLE_AGGREGATOR + node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR @classmethod def version(cls) -> str: diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py index 1d17b981ba..f9b261b191 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers @@ -18,7 +18,7 @@ if TYPE_CHECKING: class VariableAssignerNode(Node[VariableAssignerData]): - node_type = NodeType.VARIABLE_ASSIGNER + node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER def __init__( self, diff --git a/api/dify_graph/nodes/variable_assigner/v1/node_data.py b/api/dify_graph/nodes/variable_assigner/v1/node_data.py index a75a2397ba..57acb29535 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node_data.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node_data.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from enum import StrEnum from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType class WriteMode(StrEnum): @@ -12,7 +12,7 @@ class WriteMode(StrEnum): class VariableAssignerData(BaseNodeData): - type: NodeType = NodeType.VARIABLE_ASSIGNER + type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/variable_assigner/v2/entities.py b/api/dify_graph/nodes/variable_assigner/v2/entities.py index ca3a94b777..2b2bbe85de 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/entities.py +++ b/api/dify_graph/nodes/variable_assigner/v2/entities.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, Field from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from .enums import InputType, Operation @@ -23,6 +23,6 @@ class VariableOperationItem(BaseModel): class VariableAssignerNodeData(BaseNodeData): - type: NodeType = NodeType.VARIABLE_ASSIGNER + type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER version: str = "2" items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py index 771609ceb6..f04a6b3b80 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers @@ -52,7 +52,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ class VariableAssignerNode(Node[VariableAssignerNodeData]): - node_type = NodeType.VARIABLE_ASSIGNER + node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER def __init__( self, diff --git a/api/dify_graph/repositories/summary_index_service_protocol.py b/api/dify_graph/repositories/summary_index_service_protocol.py deleted file mode 100644 index cbcfdd2a77..0000000000 --- a/api/dify_graph/repositories/summary_index_service_protocol.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Protocol - - -class SummaryIndexServiceProtocol(Protocol): - def generate_and_vectorize_summary( - self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None - ): ... diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 5c02a16a7d..c43e99f0f4 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -2,7 +2,7 @@ import logging from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced @@ -16,7 +16,7 @@ def handle(sender, **kwargs): if synced_draft_workflow is None: return for node_data in synced_draft_workflow.graph_dict.get("nodes", []): - if node_data.get("data", {}).get("type") == NodeType.TOOL: + if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( diff --git a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py index 90f562d167..168513fc04 100644 --- a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py +++ b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @@ -4,7 +4,7 @@ from typing import cast from sqlalchemy import select from sqlalchemy.orm import Session -from dify_graph.nodes.trigger_schedule.entities import SchedulePlanUpdate +from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models import AppMode, Workflow, WorkflowSchedulePlan diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 8da33d03b9..92bc9db075 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -2,8 +2,8 @@ from typing import cast from sqlalchemy import select -from dify_graph.nodes import NodeType -from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from dify_graph.nodes import BuiltinNodeTypes from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models.dataset import AppDatasetJoin @@ -53,7 +53,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: # fetch all knowledge retrieval nodes knowledge_retrieval_nodes = [ - node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL + node for node in nodes if node.get("data", {}).get("type") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL ] if not knowledge_retrieval_nodes: diff --git a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py index fd211a3e55..b3917d5622 100644 --- a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @@ -3,7 +3,7 @@ from typing import cast from sqlalchemy import select from sqlalchemy.orm import Session -from dify_graph.nodes import NodeType +from core.trigger.constants import TRIGGER_NODE_TYPES from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models import AppMode @@ -98,7 +98,7 @@ def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]: return [] nodes = graph.get("nodes", []) - trigger_types = {NodeType.TRIGGER_WEBHOOK.value, NodeType.TRIGGER_SCHEDULE.value, NodeType.TRIGGER_PLUGIN.value} + trigger_types = TRIGGER_NODE_TYPES trigger_infos = [ { diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index bd1c08d96e..d84c0bc432 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -19,7 +19,6 @@ from sqlalchemy.orm import sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from dify_graph.entities import WorkflowNodeExecution from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.enums import NodeType from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter @@ -78,7 +77,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut index=safe_int(data.get("index", 0)), predecessor_node_id=data.get("predecessor_node_id"), node_id=data.get("node_id", ""), - node_type=NodeType(data.get("node_type", "start")), + node_type=data.get("node_type", "start"), title=data.get("title", ""), inputs=inputs, process_data=process_data, @@ -185,7 +184,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): ("predecessor_node_id", domain_model.predecessor_node_id or ""), ("node_execution_id", domain_model.node_execution_id or ""), ("node_id", domain_model.node_id), - ("node_type", domain_model.node_type.value), + ("node_type", domain_model.node_type), ("title", domain_model.title), ( "inputs", diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index fc84147e01..544ef3fe18 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -9,7 +9,7 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.file.models import File from dify_graph.graph_events import GraphNodeEventBase from dify_graph.nodes.base.node import Node @@ -84,21 +84,17 @@ class DefaultNodeOTelParser: span.set_attribute("node.id", node.id) if node.execution_id: span.set_attribute("node.execution_id", node.execution_id) - if hasattr(node, "node_type") and node.node_type: - span.set_attribute("node.type", node.node_type.value) + span.set_attribute("node.type", node.node_type) span.set_attribute(GenAIAttributes.FRAMEWORK, "dify") - node_type = getattr(node, "node_type", None) - if isinstance(node_type, NodeType): - if node_type == NodeType.LLM: - span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM") - elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: - span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER") - elif node_type == NodeType.TOOL: - span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL") - else: - span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK") + node_type = node.node_type + if node_type == BuiltinNodeTypes.LLM: + span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM") + elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: + span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER") + elif node_type == BuiltinNodeTypes.TOOL: + span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL") else: span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK") diff --git a/api/models/enums.py b/api/models/enums.py index 66e3e4b332..eb478fe02c 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,6 +1,10 @@ from enum import StrEnum -from dify_graph.enums import NodeType +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) class CreatorUserRole(StrEnum): @@ -66,9 +70,9 @@ class AppTriggerStatus(StrEnum): class AppTriggerType(StrEnum): """App Trigger Type Enum""" - TRIGGER_WEBHOOK = NodeType.TRIGGER_WEBHOOK.value - TRIGGER_SCHEDULE = NodeType.TRIGGER_SCHEDULE.value - TRIGGER_PLUGIN = NodeType.TRIGGER_PLUGIN.value + TRIGGER_WEBHOOK = TRIGGER_WEBHOOK_NODE_TYPE + TRIGGER_SCHEDULE = TRIGGER_SCHEDULE_NODE_TYPE + TRIGGER_PLUGIN = TRIGGER_PLUGIN_NODE_TYPE # for backward compatibility UNKNOWN = "unknown" diff --git a/api/models/workflow.py b/api/models/workflow.py index 8c62292079..fdb8de0653 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -22,13 +22,14 @@ from sqlalchemy import ( from sqlalchemy.orm import Mapped, declared_attr, mapped_column from typing_extensions import deprecated +from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import NodeType, WorkflowExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus from dify_graph.file.constants import maybe_file_object from dify_graph.file.models import File from dify_graph.variables import utils as variable_utils @@ -269,12 +270,12 @@ class Workflow(Base): # bug loop_id = node_config.get("loop_id") if loop_id is None: raise _InvalidGraphDefinitionError("invalid graph") - return NodeType.LOOP, loop_id + return BuiltinNodeTypes.LOOP, loop_id elif in_iteration: iteration_id = node_config.get("iteration_id") if iteration_id is None: raise _InvalidGraphDefinitionError("invalid graph") - return NodeType.ITERATION, iteration_id + return BuiltinNodeTypes.ITERATION, iteration_id else: return None @@ -353,9 +354,7 @@ class Workflow(Base): # bug if specific_node_type: yield from ( - (node["id"], node["data"]) - for node in graph_dict["nodes"] - if node["data"]["type"] == specific_node_type.value + (node["id"], node["data"]) for node in graph_dict["nodes"] if node["data"]["type"] == specific_node_type ) else: yield from ((node["id"], node["data"]) for node in graph_dict["nodes"]) @@ -923,18 +922,18 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo extras: dict[str, Any] = {} execution_metadata = self.execution_metadata_dict if execution_metadata: - if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata: + if self.node_type == BuiltinNodeTypes.TOOL and "tool_info" in execution_metadata: tool_info: dict[str, Any] = execution_metadata["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata: + elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata: datasource_info = execution_metadata["datasource_info"] extras["icon"] = datasource_info.get("icon") - elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata: - trigger_info = execution_metadata["trigger_info"] or {} + elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata: + trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {} provider_id = trigger_info.get("provider_id") if provider_id: extras["icon"] = TriggerManager.get_trigger_plugin_icon( diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index d3b2ede745..c044824a82 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -123,7 +123,7 @@ dify_graph/nodes/human_input/human_input_node.py dify_graph/nodes/if_else/if_else_node.py dify_graph/nodes/iteration/iteration_node.py dify_graph/nodes/knowledge_index/knowledge_index_node.py -dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py dify_graph/nodes/list_operator/node.py dify_graph/nodes/llm/node.py dify_graph/nodes/loop/loop_node.py diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 49ca273442..49e8b3cd60 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -20,14 +20,19 @@ from sqlalchemy.orm import Session from configs import dify_config from core.helper import ssrf_proxy from core.plugin.entities.plugin import PluginDependency -from dify_graph.enums import NodeType +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode +from dify_graph.enums import BuiltinNodeTypes from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from dify_graph.nodes.llm.entities import LLMNodeData from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData from dify_graph.nodes.tool.entities import ToolNodeData -from dify_graph.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory @@ -500,7 +505,7 @@ class AppDslService: unique_hash = None graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -588,27 +593,27 @@ class AppDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL: + if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node_data["dataset_ids"] = [ cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not include_secret and data_type == NodeType.TOOL: + if not include_secret and data_type == BuiltinNodeTypes.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT: + if not include_secret and data_type == BuiltinNodeTypes.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) - if data_type == NodeType.TRIGGER_SCHEDULE.value: + if data_type == TRIGGER_SCHEDULE_NODE_TYPE: # override the config with the default config node_data["config"] = TriggerScheduleNode.get_default_config()["config"] - if data_type == NodeType.TRIGGER_WEBHOOK.value: + if data_type == TRIGGER_WEBHOOK_NODE_TYPE: # clear the webhook_url node_data["webhook_url"] = "" node_data["webhook_debug_url"] = "" - if data_type == NodeType.TRIGGER_PLUGIN.value: + if data_type == TRIGGER_PLUGIN_NODE_TYPE: # clear the subscription_id node_data["subscription_id"] = "" @@ -672,31 +677,31 @@ class AppDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL: + case BuiltinNodeTypes.TOOL: tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.LLM: + case BuiltinNodeTypes.LLM: llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER: + case BuiltinNodeTypes.QUESTION_CLASSIFIER: question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR: + case BuiltinNodeTypes.PARAMETER_EXTRACTOR: parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) - case NodeType.KNOWLEDGE_RETRIEVAL: + case BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: diff --git a/api/services/app_service.py b/api/services/app_service.py index b5e893c5b5..c5d1479a20 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -187,7 +187,10 @@ class AppService: for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool)) + typed_tool = {key: value for key, value in tool.items() if isinstance(key, str)} + if len(typed_tool) != len(tool): + continue + agent_tool_entity = AgentToolEntity.model_validate(typed_tool) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index b9a565ec17..899a6ba378 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,13 +36,13 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import ErrorStrategy, NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent from dify_graph.graph_events.base import GraphNodeEventBase @@ -381,10 +381,10 @@ class RagPipelineService: """ # return default block config default_block_configs: list[dict[str, Any]] = [] - for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items(): + for node_type, node_class_mapping in get_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None - if node_type is NodeType.HTTP_REQUEST: + if node_type == BuiltinNodeTypes.HTTP_REQUEST: filters = { HTTP_REQUEST_CONFIG_FILTER_KEY: build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -410,7 +410,7 @@ class RagPipelineService: :return: """ node_type_enum = NodeType(node_type) - node_mapping = get_workflow_node_type_classes_mapping() + node_mapping = get_node_type_classes_mapping() # return default block config if node_type_enum not in node_mapping: @@ -418,7 +418,7 @@ class RagPipelineService: node_class = node_mapping[node_type_enum][LATEST_VERSION] final_filters = dict(filters) if filters else {} - if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in final_filters: + if node_type_enum == BuiltinNodeTypes.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in final_filters: final_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, @@ -500,7 +500,7 @@ class RagPipelineService: session=session, app_id=pipeline.id, node_id=workflow_node_execution.node_id, - node_type=NodeType(workflow_node_execution.node_type), + node_type=workflow_node_execution.node_type, enclosing_node_id=enclosing_node_id, node_execution_id=workflow_node_execution.id, user=account, @@ -1262,7 +1262,7 @@ class RagPipelineService: session=session, app_id=pipeline.id, node_id=workflow_node_execution_db_model.node_id, - node_type=NodeType(workflow_node_execution_db_model.node_type), + node_type=workflow_node_execution_db_model.node_type, enclosing_node_id=enclosing_node_id, node_execution_id=workflow_node_execution.id, user=current_user, diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 58bb4b7c90..c7da1afe1b 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -22,10 +22,11 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency -from dify_graph.enums import NodeType +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from dify_graph.enums import BuiltinNodeTypes from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.datasource.entities import DatasourceNodeData -from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from dify_graph.nodes.llm.entities import LLMNodeData from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData @@ -287,7 +288,7 @@ class RagPipelineDslService: nodes = graph.get("nodes", []) dataset_id = None for node in nodes: - if node.get("data", {}).get("type") == "knowledge-index": + if node.get("data", {}).get("type") == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if ( dataset @@ -428,7 +429,7 @@ class RagPipelineDslService: nodes = graph.get("nodes", []) dataset_id = None for node in nodes: - if node.get("data", {}).get("type") == "knowledge-index": + if node.get("data", {}).get("type") == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if not dataset: dataset = Dataset( @@ -562,7 +563,7 @@ class RagPipelineDslService: graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -696,17 +697,17 @@ class RagPipelineDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL: + if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node["data"]["dataset_ids"] = [ self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not include_secret and data_type == NodeType.TOOL: + if not include_secret and data_type == BuiltinNodeTypes.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT: + if not include_secret and data_type == BuiltinNodeTypes.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) @@ -740,35 +741,35 @@ class RagPipelineDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL: + case BuiltinNodeTypes.TOOL: tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.DATASOURCE: + case BuiltinNodeTypes.DATASOURCE: datasource_entity = DatasourceNodeData.model_validate(node["data"]) if datasource_entity.provider_type != "local_file": dependencies.append(datasource_entity.plugin_id) - case NodeType.LLM: + case BuiltinNodeTypes.LLM: llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER: + case BuiltinNodeTypes.QUESTION_CLASSIFIER: question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR: + case BuiltinNodeTypes.PARAMETER_EXTRACTOR: parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) - case NodeType.KNOWLEDGE_INDEX: + case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) if knowledge_index_entity.indexing_technique == "high_quality": if knowledge_index_entity.embedding_model_provider: @@ -789,7 +790,7 @@ class RagPipelineDslService: knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name ), ) - case NodeType.KNOWLEDGE_RETRIEVAL: + case BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 88b640305d..7e9d010d2f 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -5,15 +5,15 @@ from datetime import datetime from sqlalchemy import select from sqlalchemy.orm import Session -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.nodes import NodeType -from dify_graph.nodes.trigger_schedule.entities import ( +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from core.workflow.nodes.trigger_schedule.entities import ( ScheduleConfig, SchedulePlanUpdate, TriggerScheduleNodeData, VisualConfig, ) -from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError +from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError +from dify_graph.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan @@ -240,7 +240,7 @@ class ScheduleService: for node in nodes: node_data = node.get("data", {}) - if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value: + if node_data.get("type") != TRIGGER_SCHEDULE_NODE_TYPE: continue node_id = node.get("id", "start") diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 2343bbbd3d..24bbeda329 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -12,13 +12,13 @@ from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse from core.plugin.impl.exc import PluginNotFoundError +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.debug.events import PluginTriggerDebugEvent from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType -from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client from models.model import App @@ -179,7 +179,7 @@ class TriggerService: # Walk nodes to find plugin triggers nodes_in_graph: list[Mapping[str, Any]] = [] - for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN): + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): # Extract plugin trigger configuration from node plugin_id = node_config.get("plugin_id", "") provider_id = node_config.get("provider_id", "") diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 02977b934c..3c1a4cc747 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -16,15 +16,15 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.tool_file_manager import ToolFileManager -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import NodeType -from dify_graph.file.models import FileTransferMethod -from dify_graph.nodes.trigger_webhook.entities import ( +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.nodes.trigger_webhook.entities import ( ContentType, WebhookBodyParameter, WebhookData, WebhookParameter, ) +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.file.models import FileTransferMethod from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db @@ -862,7 +862,7 @@ class WebhookService: node_id: str webhook_id: str - nodes_id_in_graph = [node_id for node_id, _ in workflow.walk_nodes(NodeType.TRIGGER_WEBHOOK)] + nodes_id_in_graph = [node_id for node_id, _ in workflow.walk_nodes(TRIGGER_WEBHOOK_NODE_TYPE)] # Check webhook node limit if len(nodes_id_in_graph) > cls.MAX_WEBHOOK_NODES_PER_WORKFLOW: diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 3acbc93678..006483fe97 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -18,7 +18,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from dify_graph.file.models import FileUploadConfig from dify_graph.model_runtime.entities.llm_entities import LLMMode from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.variables.input_entities import VariableEntity from events.app_event import app_was_created from extensions.ext_database import db @@ -234,7 +234,7 @@ class WorkflowConverter: "position": None, "data": { "title": "START", - "type": NodeType.START, + "type": BuiltinNodeTypes.START, "variables": [jsonable_encoder(v) for v in variables], }, } @@ -296,7 +296,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"HTTP REQUEST {api_based_extension.name}", - "type": NodeType.HTTP_REQUEST, + "type": BuiltinNodeTypes.HTTP_REQUEST, "method": "post", "url": api_based_extension.api_endpoint, "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, @@ -314,7 +314,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"Parse {api_based_extension.name} Response", - "type": NodeType.CODE, + "type": BuiltinNodeTypes.CODE, "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" @@ -354,7 +354,7 @@ class WorkflowConverter: "position": None, "data": { "title": "KNOWLEDGE RETRIEVAL", - "type": NodeType.KNOWLEDGE_RETRIEVAL, + "type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "query_variable_selector": query_variable_selector, "dataset_ids": dataset_config.dataset_ids, "retrieval_mode": retrieve_config.retrieve_strategy.value, @@ -402,9 +402,9 @@ class WorkflowConverter: :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START, graph["nodes"])) + start_node = next(filter(lambda n: n["data"]["type"] == BuiltinNodeTypes.START, graph["nodes"])) knowledge_retrieval_node = next( - filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None + filter(lambda n: n["data"]["type"] == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None ) role_prefix = None @@ -523,7 +523,7 @@ class WorkflowConverter: "position": None, "data": { "title": "LLM", - "type": NodeType.LLM, + "type": BuiltinNodeTypes.LLM, "model": { "provider": model_config.provider, "name": model_config.model, @@ -578,7 +578,7 @@ class WorkflowConverter: "position": None, "data": { "title": "END", - "type": NodeType.END, + "type": BuiltinNodeTypes.END, "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], }, } @@ -592,7 +592,7 @@ class WorkflowConverter: return { "id": "answer", "position": None, - "data": {"title": "ANSWER", "type": NodeType.ANSWER, "answer": "{{#llm.text#}}"}, + "data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"}, } def _create_edge(self, source: str, target: str): diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index b6f6fc5490..804bf28b66 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,10 +14,11 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.trigger.constants import is_trigger_node_type from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import SystemVariableKey +from dify_graph.enums import NodeType, SystemVariableKey from dify_graph.file.models import File -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables from dify_graph.variable_loader import VariableLoader from dify_graph.variables import Segment, StringSegment, VariableBase @@ -386,7 +387,7 @@ class WorkflowDraftVariableService: # # This implementation must remain synchronized with the `_build_from_variable_assigner_mapping` # and `save` methods. - if node_type == NodeType.VARIABLE_ASSIGNER: + if node_type == BuiltinNodeTypes.VARIABLE_ASSIGNER: return variable output_value = outputs_dict.get(variable.name, absent) else: @@ -753,8 +754,8 @@ class DraftVariableSaver: # technical variables from being exposed in the draft environment, particularly those # that aren't meant to be directly edited or viewed by users. _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = { - NodeType.LLM: frozenset(["finish_reason"]), - NodeType.LOOP: frozenset(["loop_round"]), + BuiltinNodeTypes.LLM: frozenset(["finish_reason"]), + BuiltinNodeTypes.LOOP: frozenset(["loop_round"]), } # Database session used for persisting draft variables. @@ -815,7 +816,7 @@ class DraftVariableSaver: ) def _should_save_output_variables_for_draft(self) -> bool: - if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER: + if self._enclosing_node_id is not None and self._node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: # Currently we do not save output variables for nodes inside loop or iteration. return False return True @@ -1053,9 +1054,9 @@ class DraftVariableSaver: process_data = {} if not self._should_save_output_variables_for_draft(): return - if self._node_type == NodeType.VARIABLE_ASSIGNER: + if self._node_type == BuiltinNodeTypes.VARIABLE_ASSIGNER: draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data) - elif self._node_type == NodeType.START or self._node_type.is_trigger_node: + elif self._node_type == BuiltinNodeTypes.START or is_trigger_node_type(self._node_type): draft_vars = self._build_variables_from_start_mapping(outputs) else: draft_vars = self._build_variables_from_mapping(outputs) @@ -1071,7 +1072,7 @@ class DraftVariableSaver: @staticmethod def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool: - if node_type in NodeType.IF_ELSE: + if node_type == BuiltinNodeTypes.IF_ELSE: return False if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): return False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 5b24c356c2..319107d3fb 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -14,17 +14,23 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping +from core.trigger.constants import is_trigger_node_type +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file import File from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent from dify_graph.node_events import NodeRunResult -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config from dify_graph.nodes.human_input.entities import ( @@ -310,7 +316,7 @@ class WorkflowService: for _, node_data in draft_workflow.walk_nodes() if (node_type_str := node_data.get("type")) and isinstance(node_type_str, str) - and NodeType(node_type_str).is_trigger_node + and is_trigger_node_type(node_type_str) ) if trigger_node_count > 2: raise TriggerNodeLimitExceededError(count=trigger_node_count, limit=2) @@ -619,10 +625,10 @@ class WorkflowService: """ # return default block config default_block_configs: list[Mapping[str, object]] = [] - for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items(): + for node_type, node_class_mapping in get_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None - if node_type is NodeType.HTTP_REQUEST: + if node_type == BuiltinNodeTypes.HTTP_REQUEST: filters = { HTTP_REQUEST_CONFIG_FILTER_KEY: build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -650,7 +656,7 @@ class WorkflowService: :return: """ node_type_enum = NodeType(node_type) - node_mapping = get_workflow_node_type_classes_mapping() + node_mapping = get_node_type_classes_mapping() # return default block config if node_type_enum not in node_mapping: @@ -658,7 +664,7 @@ class WorkflowService: node_class = node_mapping[node_type_enum][LATEST_VERSION] resolved_filters = dict(filters) if filters else {} - if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in resolved_filters: + if node_type_enum == BuiltinNodeTypes.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in resolved_filters: resolved_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, @@ -696,7 +702,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) node_data = node_config["data"] - if node_type.is_start_node: + if is_start_node_type(node_type): with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) conversation_id = draft_var_srv.get_or_create_conversation( @@ -704,7 +710,7 @@ class WorkflowService: app=app_model, workflow=draft_workflow, ) - if node_type is NodeType.START: + if node_type == BuiltinNodeTypes.START: start_data = StartNodeData.model_validate(node_data, from_attributes=True) user_inputs = _rebuild_file_for_user_inputs_in_start_node( tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs @@ -783,7 +789,7 @@ class WorkflowService: session=session, app_id=app_model.id, node_id=workflow_node_execution.node_id, - node_type=NodeType(workflow_node_execution.node_type), + node_type=workflow_node_execution.node_type, enclosing_node_id=enclosing_node_id, node_execution_id=node_execution.id, user=account, @@ -816,7 +822,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: + if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") # inputs: values used to fill missing upstream variables referenced in form_content. @@ -875,7 +881,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: + if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") # inputs: values used to fill missing upstream variables referenced in form_content. @@ -915,7 +921,7 @@ class WorkflowService: session=session, app_id=app_model.id, node_id=node_id, - node_type=NodeType.HUMAN_INPUT, + node_type=BuiltinNodeTypes.HUMAN_INPUT, node_execution_id=str(uuid.uuid4()), user=account, enclosing_node_id=enclosing_node_id, @@ -940,7 +946,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: + if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) @@ -1328,18 +1334,18 @@ class WorkflowService: for node in node_configs: node_type = node.get("data", {}).get("type") if node_type: - node_types.add(NodeType(node_type)) + node_types.add(node_type) # start node and trigger node cannot coexist - if NodeType.START in node_types: - if any(nt.is_trigger_node for nt in node_types): + if BuiltinNodeTypes.START in node_types: + if any(is_trigger_node_type(nt) for nt in node_types): raise ValueError("Start node and trigger nodes cannot coexist in the same workflow") for node in node_configs: node_data = node.get("data", {}) node_type = node_data.get("type") - if node_type == NodeType.HUMAN_INPUT: + if node_type == BuiltinNodeTypes.HUMAN_INPUT: self._validate_human_input_node_data(node_data) def validate_features_structure(self, app_model: App, features: dict): @@ -1461,7 +1467,7 @@ def _setup_variable_pool( conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. - if node_type == NodeType.START or node_type.is_trigger_node: + if is_start_node_type(node_type): system_variable = SystemVariable( user_id=user_id, app_id=workflow.app_id, diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index e7f4e37c75..75ae1f6316 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -20,13 +20,14 @@ from core.db.session_factory import session_factory from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerInvokeEventResponse from core.plugin.impl.exc import PluginInvokeError +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager -from dify_graph.enums import NodeType, WorkflowExecutionStatus -from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData +from dify_graph.enums import WorkflowExecutionStatus from enums.quota_type import QuotaType, unlimited from models.enums import ( AppTriggerType, @@ -278,7 +279,7 @@ def dispatch_triggered_workflow( # Find the trigger node in the workflow event_node = None - for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN): + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): if node_id == plugin_trigger.node_id: event_node = node_config break diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index eaafbf99e3..466ef6c858 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -103,7 +103,7 @@ def _create_node_execution_from_domain( node_execution.index = execution.index node_execution.predecessor_node_id = execution.predecessor_node_id node_execution.node_id = execution.node_id - node_execution.node_type = execution.node_type.value + node_execution.node_type = execution.node_type node_execution.title = execution.title node_execution.node_execution_id = execution.node_execution_id diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index ced7ef973b..8c64d3ab27 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -3,7 +3,7 @@ import logging from celery import shared_task from core.db.session_factory import session_factory -from dify_graph.nodes.trigger_schedule.exc import ( +from core.workflow.nodes.trigger_schedule.exc import ( ScheduleExecutionError, ScheduleNotFoundError, TenantOwnerNotFoundError, diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index c043c7dc10..3e79792b5b 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,6 +1,7 @@ +from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.datasource.datasource_node import DatasourceNode class _Seg: @@ -28,13 +29,17 @@ class _GS: class _GP: - tenant_id = "t1" - app_id = "app-1" workflow_id = "wf-1" graph_config = {} - user_id = "u1" - user_from = "account" - invoke_from = "debugger" + run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "t1", + "app_id": "app-1", + "user_id": "u1", + "user_from": "account", + "invoke_from": "debugger", + } + } call_depth = 0 @@ -61,6 +66,8 @@ def test_node_integration_minimal_stream(mocker): def get_upload_file_by_id(cls, **_): raise AssertionError + mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) + node = DatasourceNode( id="n", config={ @@ -77,7 +84,6 @@ def test_node_integration_minimal_stream(mocker): }, graph_init_params=_GP(), graph_runtime_state=_GS(vp), - datasource_manager=_Mgr, ) out = list(node._run()) diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 7c4dcda2dc..b19b4ebdad 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -7,7 +7,7 @@ from sqlalchemy import delete from sqlalchemy.orm import Session from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import SegmentType from dify_graph.variables.variables import StringVariable @@ -286,7 +286,7 @@ class TestDraftVariableLoader(unittest.TestCase): session=session, app_id=self._test_app_id, node_id="test_offload_node", - node_type=NodeType.LLM, # Use a real node type + node_type=BuiltinNodeTypes.LLM, # Use a real node type node_execution_id=node_execution_id, user=setup_account, ) @@ -542,7 +542,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): index=1, node_execution_id=str(uuid.uuid4()), node_id=self._node_id, - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", inputs='{"input": "test input"}', process_data='{"test_var": "process_value", "other_var": "other_process"}', diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index f8b7f95493..e3a2b6b866 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -60,7 +60,7 @@ def init_code_node(code_config: dict): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( id=str(uuid.uuid4()), diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 347fa9c9ed..f885f69e55 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -70,7 +70,7 @@ def init_http_node(config: dict): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( id=str(uuid.uuid4()), @@ -189,7 +189,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from dify_graph.enums import NodeType + from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, @@ -210,7 +210,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): # Create node data with custom auth and empty api_key node_data = HttpRequestNodeData( - type=NodeType.HTTP_REQUEST, + type=BuiltinNodeTypes.HTTP_REQUEST, title="http", desc="", url="http://example.com", @@ -717,7 +717,7 @@ def test_nested_object_variable_selector(setup_http_mock): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( id=str(uuid.uuid4()), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 970e2cae00..7bb4f905c3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -82,7 +82,7 @@ def test_execute_template_transform(): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert graph is not None node = TemplateTransformNode( diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 8a4fb8eda4..a6717ada31 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -54,7 +54,7 @@ def init_tool_node(config: dict): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index d783a08233..75471afef8 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -5,7 +5,7 @@ import pytest from faker import Faker from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 08f99cf55a..70d05792ce 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, @@ -68,7 +68,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - inputs=[], user_actions=[], ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value + node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []}) workflow = Workflow.new( diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index d8b43efeba..056db41750 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -860,8 +860,8 @@ class TestWorkflowService: # Act try: result = workflow_service.get_default_block_config(node_type=invalid_node_type) - # If we get here, the service should return None for invalid types - assert result is None + # If we get here, the service should return an empty config for invalid types. + assert result == {} except ValueError: # It's also acceptable for the service to raise a ValueError for invalid types pass @@ -1428,14 +1428,14 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus + from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_events import NodeRunSucceededEvent from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) - mock_node.node_type = NodeType.START + mock_node.node_type = BuiltinNodeTypes.START mock_node.title = "Test Node" mock_node.error_strategy = None @@ -1452,7 +1452,7 @@ class TestWorkflowService: mock_event = NodeRunSucceededEvent( id=str(uuid.uuid4()), node_id=node_id, - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_run_result=mock_result, start_at=datetime.now(), ) @@ -1473,9 +1473,9 @@ class TestWorkflowService: # Assert assert result is not None assert result.node_id == node_id - from dify_graph.enums import NodeType + from dify_graph.enums import BuiltinNodeTypes - assert result.node_type == NodeType.START # Should match the mock node type + assert result.node_type == BuiltinNodeTypes.START # Should match the mock node type assert result.title == "Test Node" # Import the enum for comparison from dify_graph.enums import WorkflowNodeExecutionStatus @@ -1503,14 +1503,14 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus + from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_events import NodeRunFailedEvent from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) - mock_node.node_type = NodeType.LLM + mock_node.node_type = BuiltinNodeTypes.LLM mock_node.title = "Test Node" mock_node.error_strategy = None @@ -1525,7 +1525,7 @@ class TestWorkflowService: mock_event = NodeRunFailedEvent( id=str(uuid.uuid4()), node_id=node_id, - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_run_result=mock_result, error="Test error message", start_at=datetime.now(), @@ -1572,14 +1572,14 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus + from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus from dify_graph.graph_events import NodeRunFailedEvent from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node # Create mock node with continue_on_error mock_node = MagicMock(spec=Node) - mock_node.node_type = NodeType.TOOL + mock_node.node_type = BuiltinNodeTypes.TOOL mock_node.title = "Test Node" mock_node.error_strategy = ErrorStrategy.DEFAULT_VALUE mock_node.default_value_dict = {"default_output": "default_value"} @@ -1595,7 +1595,7 @@ class TestWorkflowService: mock_event = NodeRunFailedEvent( id=str(uuid.uuid4()), node_id=node_id, - node_type=NodeType.TOOL, + node_type=BuiltinNodeTypes.TOOL, node_run_result=mock_result, error="Test error message", start_at=datetime.now(), diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 7bfc6c9e13..4ea8d8c1c7 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -14,11 +14,16 @@ from sqlalchemy.orm import Session from configs import dify_config from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus @@ -48,10 +53,10 @@ WEBHOOK_ID_DEBUG = "whdebug1234567890123456" TEST_TRIGGER_URL = "https://trigger.example.com/base" -def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str: +def _build_workflow_graph(root_node_id: str, trigger_type: str) -> str: """Build a minimal workflow graph JSON for testing.""" - node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"} - if trigger_type == NodeType.TRIGGER_WEBHOOK: + node_data: dict[str, Any] = {"type": trigger_type, "title": "trigger"} + if trigger_type == TRIGGER_WEBHOOK_NODE_TYPE: node_data.update( { "method": "POST", @@ -64,7 +69,7 @@ def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str: graph = { "nodes": [ {"id": root_node_id, "data": node_data}, - {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + {"id": "answer-1", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "answer"}}, ], "edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}], } @@ -82,8 +87,8 @@ def test_publish_blocks_start_and_trigger_coexistence( graph = { "nodes": [ - {"id": "start", "data": {"type": NodeType.START.value}}, - {"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}}, + {"id": "start", "data": {"type": BuiltinNodeTypes.START}}, + {"id": "trig", "data": {"type": TRIGGER_WEBHOOK_NODE_TYPE}}, ], "edges": [], } @@ -152,7 +157,7 @@ def test_webhook_trigger_creates_trigger_log( tenant, account = tenant_and_account webhook_node_id = "webhook-node" - graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + graph_json = _build_workflow_graph(webhook_node_id, TRIGGER_WEBHOOK_NODE_TYPE) published_workflow = Workflow.new( tenant_id=tenant.id, app_id=app_model.id, @@ -282,7 +287,7 @@ def test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPa node_config = { "id": "schedule-visual", "data": { - "type": NodeType.TRIGGER_SCHEDULE.value, + "type": TRIGGER_SCHEDULE_NODE_TYPE, "mode": "visual", "frequency": "daily", "visual_config": {"time": "3:00 PM"}, @@ -372,7 +377,7 @@ def test_webhook_debug_dispatches_event( """Webhook single-step debug should dispatch debug event and be pollable.""" tenant, account = tenant_and_account webhook_node_id = "webhook-debug-node" - graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + graph_json = _build_workflow_graph(webhook_node_id, TRIGGER_WEBHOOK_NODE_TYPE) draft_workflow = Workflow.new( tenant_id=tenant.id, app_id=app_model.id, @@ -443,7 +448,7 @@ def test_plugin_single_step_debug_flow( node_config = { "id": node_id, "data": { - "type": NodeType.TRIGGER_PLUGIN.value, + "type": TRIGGER_PLUGIN_NODE_TYPE, "title": "plugin", "plugin_id": "plugin-1", "plugin_unique_identifier": "plugin-1", @@ -519,14 +524,14 @@ def test_schedule_trigger_creates_trigger_log( { "id": schedule_node_id, "data": { - "type": NodeType.TRIGGER_SCHEDULE.value, + "type": TRIGGER_SCHEDULE_NODE_TYPE, "title": "schedule", "mode": "cron", "cron_expression": "0 9 * * *", "timezone": "UTC", }, }, - {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + {"id": "answer-1", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "answer"}}, ], "edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}], } @@ -639,7 +644,7 @@ def test_schedule_visual_cron_conversion( node_config: dict[str, Any] = { "id": "schedule-node", "data": { - "type": NodeType.TRIGGER_SCHEDULE.value, + "type": TRIGGER_SCHEDULE_NODE_TYPE, "mode": mode, "timezone": "UTC", }, @@ -680,7 +685,7 @@ def test_plugin_trigger_full_chain_with_db_verification( { "id": plugin_node_id, "data": { - "type": NodeType.TRIGGER_PLUGIN.value, + "type": TRIGGER_PLUGIN_NODE_TYPE, "title": "plugin", "plugin_id": "test-plugin", "plugin_unique_identifier": "test-plugin", @@ -690,7 +695,7 @@ def test_plugin_trigger_full_chain_with_db_verification( "parameters": {}, }, }, - {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + {"id": "answer-1", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "answer"}}, ], "edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}], } @@ -826,7 +831,7 @@ def test_plugin_debug_via_http_endpoint( node_config = { "id": node_id, "data": { - "type": NodeType.TRIGGER_PLUGIN.value, + "type": TRIGGER_PLUGIN_NODE_TYPE, "title": "plugin-debug", "plugin_id": "debug-plugin", "plugin_unique_identifier": "debug-plugin", diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 67f87710a1..0a244b3fea 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -42,7 +42,7 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from models.enums import MessageStatus @@ -226,7 +226,7 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline._save_output_for_event = lambda event, node_execution_id: None event = SimpleNamespace( - node_type=NodeType.ANSWER, + node_type=BuiltinNodeTypes.ANSWER, outputs={"k": "v"}, node_execution_id="exec", node_id="node", @@ -254,7 +254,7 @@ class TestAdvancedChatGenerateTaskPipeline: iter_start = QueueIterationStartEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -263,14 +263,14 @@ class TestAdvancedChatGenerateTaskPipeline: index=1, node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", node_run_index=1, ) iter_done = QueueIterationCompletedEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -278,7 +278,7 @@ class TestAdvancedChatGenerateTaskPipeline: loop_start = QueueLoopStartEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -287,14 +287,14 @@ class TestAdvancedChatGenerateTaskPipeline: index=1, node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", node_run_index=1, ) loop_done = QueueLoopCompletedEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -358,7 +358,7 @@ class TestAdvancedChatGenerateTaskPipeline: failed_event = QueueNodeFailedEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.utcnow(), inputs={}, outputs={}, @@ -368,7 +368,7 @@ class TestAdvancedChatGenerateTaskPipeline: exc_event = QueueNodeExceptionEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.utcnow(), inputs={}, outputs={}, @@ -462,7 +462,7 @@ class TestAdvancedChatGenerateTaskPipeline: filled_event = QueueHumanInputFormFilledEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="title", rendered_content="content", action_id="action", @@ -470,7 +470,7 @@ class TestAdvancedChatGenerateTaskPipeline: ) timeout_event = QueueHumanInputFormTimeoutEvent( node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="title", expiration_time=datetime.utcnow(), ) @@ -589,7 +589,7 @@ class TestAdvancedChatGenerateTaskPipeline: event = QueueNodeExceptionEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.utcnow(), inputs={}, outputs={}, diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 69d476bd13..aba7dfff8c 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -24,7 +24,7 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account @@ -66,7 +66,7 @@ class TestWorkflowResponseConverter: node_execution_id=node_execution_id or str(uuid.uuid4()), node_id="test-node-id", node_title="Test Node", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, start_at=naive_utc_now(), in_iteration_id=None, in_loop_id=None, @@ -83,7 +83,7 @@ class TestWorkflowResponseConverter: """Create a QueueNodeSucceededEvent for testing.""" return QueueNodeSucceededEvent( node_id="test-node-id", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_execution_id=node_execution_id, start_at=naive_utc_now(), in_iteration_id=None, @@ -108,7 +108,7 @@ class TestWorkflowResponseConverter: error="oops", retry_index=1, node_id="test-node-id", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="test code", provider_type="built-in", provider_id="code", @@ -319,7 +319,7 @@ class TestWorkflowResponseConverter: iteration_event = QueueNodeSucceededEvent( node_id="iteration-node", - node_type=NodeType.ITERATION, + node_type=BuiltinNodeTypes.ITERATION, node_execution_id=str(uuid.uuid4()), start_at=naive_utc_now(), in_iteration_id=None, @@ -336,7 +336,7 @@ class TestWorkflowResponseConverter: ) assert response is None - loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP}) + loop_event = iteration_event.model_copy(update={"node_type": BuiltinNodeTypes.LOOP}) response = converter.workflow_node_finish_to_stream_response( event=loop_event, task_id="test-task-id", @@ -478,7 +478,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_value, process_data=large_value, @@ -523,7 +523,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_value, process_data=large_value, @@ -562,7 +562,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_value, process_data=large_value, @@ -600,7 +600,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: return QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=inputs, process_data=process_data, @@ -614,7 +614,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: return QueueNodeFailedEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=inputs, process_data=process_data, @@ -628,7 +628,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: return QueueNodeExceptionEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=inputs, process_data=process_data, @@ -690,7 +690,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: start_event = QueueNodeStartedEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="Test Node", node_run_index=1, start_at=naive_utc_now(), @@ -706,7 +706,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeRetryEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="Test Node", node_run_index=1, start_at=naive_utc_now(), @@ -748,7 +748,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: start_event = QueueIterationStartEvent( node_execution_id="test_iter_exec_id", node_id="test_iteration", - node_type=NodeType.ITERATION, + node_type=BuiltinNodeTypes.ITERATION, node_title="Test Iteration", node_run_index=0, start_at=naive_utc_now(), @@ -776,7 +776,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: start_event = QueueLoopStartEvent( node_execution_id="test_loop_exec_id", node_id="test_loop", - node_type=NodeType.LOOP, + node_type=BuiltinNodeTypes.LOOP, node_title="Test Loop", start_at=naive_utc_now(), inputs=large_inputs, @@ -806,7 +806,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_inputs, process_data=large_process_data, diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 8f1baaa1e4..a3ced02394 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -479,7 +479,7 @@ class TestBaseAppGeneratorExtras: def test_get_draft_var_saver_factory_debugger(self): from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.enums import NodeType + from dify_graph.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() @@ -492,7 +492,7 @@ class TestBaseAppGeneratorExtras: session=MagicMock(), app_id="app-id", node_id="node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_execution_id="node-exec-id", ) diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 4f67d9cb56..2f73a8cda8 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -12,7 +12,7 @@ from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -44,12 +44,12 @@ if "core.ops.ops_trace_manager" not in sys.modules: class _StubToolNodeData(BaseNodeData): - type: NodeType = NodeType.TOOL + type: NodeType = BuiltinNodeTypes.TOOL pause_on: bool = False class _StubToolNode(Node[_StubToolNodeData]): - node_type = NodeType.TOOL + node_type = BuiltinNodeTypes.TOOL @classmethod def version(cls) -> str: @@ -94,7 +94,7 @@ def _patch_tool_node(mocker): def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: typed_node_config = NodeConfigDictAdapter.validate_python(node_config) node_data = typed_node_config["data"] - if node_data.type == NodeType.TOOL: + if node_data.type == BuiltinNodeTypes.TOOL: return _StubToolNode( id=str(typed_node_config["id"]), config=typed_node_config, @@ -108,7 +108,7 @@ def _patch_tool_node(mocker): def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: node_data = data.model_dump() - node_data["type"] = node_type.value + node_data["type"] = str(node_type) return node_data @@ -124,11 +124,11 @@ def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: ) nodes = [ - {"id": "start", "data": _node_data(NodeType.START, start_data)}, - {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)}, - {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)}, - {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)}, - {"id": "end", "data": _node_data(NodeType.END, end_data)}, + {"id": "start", "data": _node_data(BuiltinNodeTypes.START, start_data)}, + {"id": "tool_a", "data": _node_data(BuiltinNodeTypes.TOOL, tool_data_a)}, + {"id": "tool_b", "data": _node_data(BuiltinNodeTypes.TOOL, tool_data_b)}, + {"id": "tool_c", "data": _node_data(BuiltinNodeTypes.TOOL, tool_data_c)}, + {"id": "end", "data": _node_data(BuiltinNodeTypes.END, end_data)}, ] edges = [ {"source": "start", "target": "tool_a"}, @@ -157,7 +157,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G graph_runtime_state=runtime_state, ) - return Graph.init(graph_config=graph_config, node_factory=node_factory) + return Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") def _build_runtime_state(run_id: str) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 108b740344..3f1dd14569 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -17,7 +17,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, @@ -193,7 +193,7 @@ class TestWorkflowBasedAppRunner: NodeRunStartedEvent( id="exec", node_id="node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_title="Start", start_at=datetime.utcnow(), ), @@ -203,7 +203,7 @@ class TestWorkflowBasedAppRunner: NodeRunStreamChunkEvent( id="exec", node_id="node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, selector=["node", "text"], chunk="hi", is_final=False, @@ -214,7 +214,7 @@ class TestWorkflowBasedAppRunner: NodeRunAgentLogEvent( id="exec", node_id="node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, message_id="msg", label="label", node_execution_id="exec", @@ -230,7 +230,7 @@ class TestWorkflowBasedAppRunner: NodeRunIterationSucceededEvent( id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="Iter", start_at=datetime.utcnow(), inputs={}, @@ -244,7 +244,7 @@ class TestWorkflowBasedAppRunner: NodeRunLoopFailedEvent( id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="Loop", start_at=datetime.utcnow(), inputs={}, diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index b37f7a8120..f35710d207 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -44,7 +44,7 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import NodeType, WorkflowExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from models.enums import CreatorUserRole @@ -190,7 +190,7 @@ class TestWorkflowGenerateTaskPipeline: event = QueueNodeSucceededEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, start_at=datetime.utcnow(), inputs={}, outputs={}, @@ -243,7 +243,7 @@ class TestWorkflowGenerateTaskPipeline: event = QueueNodeFailedEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, start_at=datetime.utcnow(), inputs={}, outputs={}, @@ -300,7 +300,7 @@ class TestWorkflowGenerateTaskPipeline: iter_start = QueueIterationStartEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -309,14 +309,14 @@ class TestWorkflowGenerateTaskPipeline: index=1, node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", node_run_index=1, ) iter_done = QueueIterationCompletedEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -324,7 +324,7 @@ class TestWorkflowGenerateTaskPipeline: loop_start = QueueLoopStartEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -333,14 +333,14 @@ class TestWorkflowGenerateTaskPipeline: index=1, node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", node_run_index=1, ) loop_done = QueueLoopCompletedEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="LLM", start_at=datetime.utcnow(), node_run_index=1, @@ -348,7 +348,7 @@ class TestWorkflowGenerateTaskPipeline: filled_event = QueueHumanInputFormFilledEvent( node_execution_id="exec", node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="title", rendered_content="content", action_id="action", @@ -356,7 +356,7 @@ class TestWorkflowGenerateTaskPipeline: ) timeout_event = QueueHumanInputFormTimeoutEvent( node_id="node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="title", expiration_time=datetime.utcnow(), ) @@ -645,7 +645,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_title="title", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_run_index=1, start_at=datetime.utcnow(), provider_type="provider", @@ -657,7 +657,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_title="title", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_run_index=1, start_at=datetime.utcnow(), provider_type="provider", @@ -683,7 +683,7 @@ class TestWorkflowGenerateTaskPipeline: event = QueueNodeExceptionEvent( node_execution_id="exec-id", node_id="node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, start_at=datetime.utcnow(), inputs={}, outputs={}, @@ -855,7 +855,7 @@ class TestWorkflowGenerateTaskPipeline: event = QueueNodeSucceededEvent( node_execution_id="exec-id", node_id="node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, in_loop_id="loop-id", start_at=datetime.utcnow(), process_data={"k": "v"}, diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index 7d0e1d25f6..bdc889d941 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -4,7 +4,7 @@ from unittest.mock import Mock from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events.node import NodeRunSucceededEvent from dify_graph.node_events import NodeRunResult @@ -78,7 +78,7 @@ def test_persists_conversation_variables_from_assigner_output(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) layer.on_event(event) updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) @@ -100,7 +100,7 @@ def test_skips_when_outputs_missing(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER) layer.on_event(event) updater.update.assert_not_called() @@ -112,7 +112,7 @@ def test_skips_non_assigner_nodes(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.LLM) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM) layer.on_event(event) updater.update.assert_not_called() @@ -137,7 +137,7 @@ def test_skips_non_conversation_variables(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) layer.on_event(event) updater.update.assert_not_called() diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index fac0597f5a..dfd61acfa7 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -35,7 +35,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: @@ -413,7 +413,7 @@ def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrac monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) - node_execution.node_type = NodeType.LLM + node_execution.node_type = BuiltinNodeTypes.LLM assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "llm" @@ -426,7 +426,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) - node_execution.node_type = NodeType.KNOWLEDGE_RETRIEVAL + node_execution.node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "retrieval" @@ -437,7 +437,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) - node_execution.node_type = NodeType.TOOL + node_execution.node_type = BuiltinNodeTypes.TOOL assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "tool" @@ -448,7 +448,7 @@ def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTra monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) - node_execution.node_type = NodeType.CODE + node_execution.node_type = BuiltinNodeTypes.CODE assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "task" @@ -460,7 +460,7 @@ def test_build_workflow_node_span_handles_errors( trace_metadata = MagicMock() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) - node_execution.node_type = NodeType.CODE + node_execution.node_type = BuiltinNodeTypes.CODE assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) is None assert "Error occurred in build_workflow_node_span" in caplog.text diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 8e036e4b52..0ff135562c 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -25,7 +25,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -147,7 +147,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_llm = MagicMock() node_llm.id = "node-llm" node_llm.title = "LLM Node" - node_llm.node_type = NodeType.LLM + node_llm.node_type = BuiltinNodeTypes.LLM node_llm.status = "succeeded" node_llm.process_data = { "model_mode": "chat", @@ -164,7 +164,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other = MagicMock() node_other.id = "node-other" node_other.title = "Other Node" - node_other.node_type = NodeType.CODE + node_other.node_type = BuiltinNodeTypes.CODE node_other.status = "failed" node_other.process_data = None node_other.inputs = {"code": "print"} @@ -664,7 +664,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node = MagicMock() node.id = "n1" node.title = "LLM Node" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.status = "succeeded" class BadDict(collections.UserDict): diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index 98f9dd00cf..f656f7435f 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -21,7 +21,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -145,7 +145,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_llm = MagicMock() node_llm.id = "node-llm" node_llm.title = "LLM Node" - node_llm.node_type = NodeType.LLM + node_llm.node_type = BuiltinNodeTypes.LLM node_llm.status = "succeeded" node_llm.process_data = { "model_mode": "chat", @@ -162,7 +162,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_other = MagicMock() node_other.id = "node-other" node_other.title = "Tool Node" - node_other.node_type = NodeType.TOOL + node_other.node_type = BuiltinNodeTypes.TOOL node_other.status = "succeeded" node_other.process_data = None node_other.inputs = {"tool_input": "val"} @@ -174,7 +174,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_retrieval = MagicMock() node_retrieval.id = "node-retrieval" node_retrieval.title = "Retrieval Node" - node_retrieval.node_type = NodeType.KNOWLEDGE_RETRIEVAL + node_retrieval.node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL node_retrieval.status = "succeeded" node_retrieval.process_data = None node_retrieval.inputs = {"query": "val"} @@ -555,7 +555,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl node_llm = MagicMock() node_llm.id = "node-llm" node_llm.title = "LLM Node" - node_llm.node_type = NodeType.LLM + node_llm.node_type = BuiltinNodeTypes.LLM node_llm.status = "succeeded" node_llm.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) node_llm.inputs = {} diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index 0657acc1d9..cccedaa08c 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -161,7 +161,7 @@ def _make_node(**overrides): "tenant_id": "t1", "app_id": "app-1", "title": "Node Title", - "node_type": NodeType.CODE, + "node_type": BuiltinNodeTypes.CODE, "status": "succeeded", "inputs": '{"key": "value"}', "outputs": '{"result": "ok"}', @@ -362,7 +362,7 @@ class TestWorkflowTrace: def test_workflow_with_llm_node(self, trace_instance, mock_tracing, mock_db): llm_node = _make_node( - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, process_data=json.dumps( { "prompts": [{"role": "user", "text": "hi"}], @@ -388,7 +388,7 @@ class TestWorkflowTrace: def test_workflow_with_question_classifier_node(self, trace_instance, mock_tracing, mock_db): qc_node = _make_node( - node_type=NodeType.QUESTION_CLASSIFIER, + node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER, process_data=json.dumps( { "prompts": "classify this", @@ -408,7 +408,7 @@ class TestWorkflowTrace: def test_workflow_with_http_request_node(self, trace_instance, mock_tracing, mock_db): http_node = _make_node( - node_type=NodeType.HTTP_REQUEST, + node_type=BuiltinNodeTypes.HTTP_REQUEST, process_data='{"url": "https://api.com"}', ) mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] @@ -424,7 +424,7 @@ class TestWorkflowTrace: def test_workflow_with_knowledge_retrieval_node(self, trace_instance, mock_tracing, mock_db): kr_node = _make_node( - node_type=NodeType.KNOWLEDGE_RETRIEVAL, + node_type=BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, outputs=json.dumps( { "result": [ @@ -846,13 +846,13 @@ class TestGetNodeSpanType: @pytest.mark.parametrize( ("node_type", "expected_contains"), [ - (NodeType.LLM, "LLM"), - (NodeType.QUESTION_CLASSIFIER, "LLM"), - (NodeType.KNOWLEDGE_RETRIEVAL, "RETRIEVER"), - (NodeType.TOOL, "TOOL"), - (NodeType.CODE, "TOOL"), - (NodeType.HTTP_REQUEST, "TOOL"), - (NodeType.AGENT, "AGENT"), + (BuiltinNodeTypes.LLM, "LLM"), + (BuiltinNodeTypes.QUESTION_CLASSIFIER, "LLM"), + (BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "RETRIEVER"), + (BuiltinNodeTypes.TOOL, "TOOL"), + (BuiltinNodeTypes.CODE, "TOOL"), + (BuiltinNodeTypes.HTTP_REQUEST, "TOOL"), + (BuiltinNodeTypes.AGENT, "AGENT"), ], ) def test_mapped_types(self, trace_instance, node_type, expected_contains): diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index 80a0331c4b..b2cb7d5109 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -18,7 +18,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -172,7 +172,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_llm = MagicMock() node_llm.id = LLM_NODE_ID node_llm.title = "LLM Node" - node_llm.node_type = NodeType.LLM + node_llm.node_type = BuiltinNodeTypes.LLM node_llm.status = "succeeded" node_llm.process_data = { "model_mode": "chat", @@ -189,7 +189,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other = MagicMock() node_other.id = CODE_NODE_ID node_other.title = "Other Node" - node_other.node_type = NodeType.CODE + node_other.node_type = BuiltinNodeTypes.CODE node_other.status = "failed" node_other.process_data = None node_other.inputs = {"code": "print"} @@ -641,7 +641,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node = MagicMock() node.id = "88e8e918-472e-4b69-8051-12502c34fc07" node.title = "LLM Node" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.status = "succeeded" class BadDict(collections.UserDict): diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index 077a92d866..f259e4639f 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -15,7 +15,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) @@ -320,10 +320,10 @@ class TestTencentDataTrace: node1 = MagicMock(spec=WorkflowNodeExecution) node1.id = "n1" - node1.node_type = NodeType.LLM + node1.node_type = BuiltinNodeTypes.LLM node2 = MagicMock(spec=WorkflowNodeExecution) node2.id = "n2" - node2.node_type = NodeType.TOOL + node2.node_type = BuiltinNodeTypes.TOOL with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]): with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]): @@ -359,10 +359,10 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) nodes = [ - (NodeType.LLM, mock_span_builder.build_workflow_llm_span), - (NodeType.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span), - (NodeType.TOOL, mock_span_builder.build_workflow_tool_span), - (NodeType.CODE, mock_span_builder.build_workflow_task_span), + (BuiltinNodeTypes.LLM, mock_span_builder.build_workflow_llm_span), + (BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span), + (BuiltinNodeTypes.TOOL, mock_span_builder.build_workflow_tool_span), + (BuiltinNodeTypes.CODE, mock_span_builder.build_workflow_task_span), ] for node_type, builder_method in nodes: @@ -377,7 +377,7 @@ class TestTencentDataTrace: def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder): node = MagicMock(spec=WorkflowNodeExecution) - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.id = "n1" mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 32389b4d64..49d6b698ef 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,29 +1,29 @@ from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from dify_graph.enums import NodeType +from dify_graph.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: """Tests for _get_node_span_kind helper.""" def test_all_node_types_are_mapped_correctly(self): - """Ensure every NodeType enum member is mapped to the correct span kind.""" + """Ensure every built-in node type is mapped to the correct span kind.""" # Mappings for node types that have a specialised span kind. special_mappings = { - NodeType.LLM: OpenInferenceSpanKindValues.LLM, - NodeType.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, - NodeType.TOOL: OpenInferenceSpanKindValues.TOOL, - NodeType.AGENT: OpenInferenceSpanKindValues.AGENT, + BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, + BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL, + BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT, } - # Test that every NodeType enum member is mapped to the correct span kind. + # Test that every built-in node type is mapped to the correct span kind. # Node types not in `special_mappings` should default to CHAIN. - for node_type in NodeType: + for node_type in BUILT_IN_NODE_TYPES: expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN) actual_span_kind = _get_node_span_kind(node_type) assert actual_span_kind == expected_span_kind, ( - f"NodeType.{node_type.name} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." + f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." ) def test_unknown_string_defaults_to_chain(self): diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index cdd97d5369..8057bbbad5 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -173,7 +173,7 @@ def _make_node(**overrides): defaults = { "id": "node-1", "title": "Node Title", - "node_type": NodeType.CODE, + "node_type": BuiltinNodeTypes.CODE, "status": "succeeded", "inputs": {"key": "value"}, "outputs": {"result": "ok"}, @@ -633,7 +633,7 @@ class TestWorkflowTrace: """Workflow trace iterates node executions and creates node runs.""" node = _make_node( id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, inputs={"k": "v"}, outputs={"r": "ok"}, elapsed_time=0.5, @@ -655,7 +655,7 @@ class TestWorkflowTrace: def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch): """LLM node uses process_data prompts as inputs.""" node = _make_node( - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, process_data={ "prompts": [{"role": "user", "content": "hi"}], "model_mode": "chat", @@ -683,7 +683,7 @@ class TestWorkflowTrace: def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch): """Non-LLM node uses node_execution.inputs directly.""" node = _make_node( - node_type=NodeType.TOOL, + node_type=BuiltinNodeTypes.TOOL, inputs={"tool_input": "val"}, process_data=None, ) @@ -743,7 +743,7 @@ class TestWorkflowTrace: def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch): """Chat mode LLM node adds ls_provider and ls_model_name to attributes.""" node = _make_node( - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, process_data={"model_mode": "chat", "model_provider": "openai", "model_name": "gpt-4", "prompts": []}, ) self._setup_repo(monkeypatch, nodes=[node]) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index de3ccc4518..d61f01c616 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -32,10 +32,10 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.nodes.knowledge_retrieval import exc +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.nodes.knowledge_retrieval import exc -from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest from models.dataset import Dataset # ==================== Helper Functions ==================== diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 682a451117..48782515d0 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -5,8 +5,8 @@ import pytest from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.nodes.knowledge_retrieval import exc -from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from core.workflow.nodes.knowledge_retrieval import exc +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset # ==================== Helper Functions ==================== diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index b613573927..2a83a4e802 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -15,7 +15,7 @@ from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser @@ -61,7 +61,7 @@ def sample_workflow_node_execution(): workflow_execution_id=str(uuid4()), index=1, node_id="test_node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", inputs={"input1": "value1"}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -259,7 +259,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=1, node_id="node1", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Node 1", inputs={"input1": "value1"}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -272,7 +272,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=2, node_id="node2", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Node 2", inputs={"input2": "value2"}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -310,7 +310,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=2, node_id="node2", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Node 2", inputs={}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -323,7 +323,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=1, node_id="node1", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Node 1", inputs={}, status=WorkflowNodeExecutionStatus.RUNNING, diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index bae5bae06d..456c3dde12 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -14,7 +14,7 @@ from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom @@ -70,7 +70,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.RUNNING, @@ -108,7 +108,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -153,7 +153,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.RUNNING, @@ -195,7 +195,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.RUNNING, diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index c880b8d41b..eeab81a178 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -13,6 +13,7 @@ from unittest.mock import MagicMock from sqlalchemy import Engine +from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) @@ -20,7 +21,7 @@ from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload @@ -41,7 +42,7 @@ class TruncationTestCase: def create_test_cases() -> list[TruncationTestCase]: """Create test cases for different truncation scenarios.""" # Create large data that will definitely exceed the threshold (10KB) - large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1000)} + large_data = {"data": "x" * (dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE + 1000)} small_data = {"data": "small"} return [ @@ -101,7 +102,7 @@ def create_workflow_node_execution( workflow_execution_id="test-workflow-execution-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", inputs=inputs, outputs=outputs, @@ -145,7 +146,7 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation: db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node-id" - db_model.node_type = NodeType.LLM + db_model.node_type = BuiltinNodeTypes.LLM db_model.title = "Test Node" db_model.inputs = json.dumps({"value": "inputs"}) db_model.process_data = json.dumps({"value": "process_data"}) diff --git a/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py b/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py index 14b42adbbe..2b508ca654 100644 --- a/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py +++ b/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py @@ -6,7 +6,7 @@ import pytest import pytz from core.trigger.debug import event_selectors -from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig +from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig class _DummyRedis: diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index 331bcd6c25..bcb1d745e3 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -13,6 +13,11 @@ from unittest.mock import MagicMock, patch import pytest from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) from core.trigger.debug.event_selectors import ( PluginTriggerDebugEventPoller, ScheduleTriggerDebugEventPoller, @@ -21,7 +26,7 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID @@ -215,24 +220,24 @@ class TestCreateEventPoller: return wf def test_creates_plugin_poller(self): - wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN) + wf = self._workflow_with_node(TRIGGER_PLUGIN_NODE_TYPE) poller = create_event_poller(wf, "t1", "u1", "a1", "n1") assert isinstance(poller, PluginTriggerDebugEventPoller) def test_creates_webhook_poller(self): - wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK) + wf = self._workflow_with_node(TRIGGER_WEBHOOK_NODE_TYPE) poller = create_event_poller(wf, "t1", "u1", "a1", "n1") assert isinstance(poller, WebhookTriggerDebugEventPoller) def test_creates_schedule_poller(self): - wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE) + wf = self._workflow_with_node(TRIGGER_SCHEDULE_NODE_TYPE) poller = create_event_poller(wf, "t1", "u1", "a1", "n1") assert isinstance(poller, ScheduleTriggerDebugEventPoller) def test_raises_for_unknown_type(self): wf = MagicMock() wf.get_node_config_by_id.return_value = {"data": {}} - wf.get_node_type_from_node_config.return_value = NodeType.START + wf.get_node_type_from_node_config.return_value = BuiltinNodeTypes.START with pytest.raises(ValueError): create_event_poller(wf, "t1", "u1", "a1", "n1") @@ -249,7 +254,7 @@ class TestSelectTriggerDebugEvents: def test_returns_first_non_none_event(self): wf = MagicMock() wf.get_node_config_by_id.return_value = {"data": {}} - wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK + wf.get_node_type_from_node_config.return_value = TRIGGER_WEBHOOK_NODE_TYPE app_model = MagicMock() app_model.tenant_id = "t1" app_model.id = "a1" @@ -265,7 +270,7 @@ class TestSelectTriggerDebugEvents: def test_returns_none_when_no_events(self): wf = MagicMock() wf.get_node_config_by_id.return_value = {"data": {}} - wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK + wf.get_node_type_from_node_config.return_value = TRIGGER_WEBHOOK_NODE_TYPE app_model = MagicMock() app_model.tenant_id = "t1" app_model.id = "a1" diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py index 4035c1a871..216e64db8d 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py @@ -9,7 +9,7 @@ from typing import Any import pytest from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes class TestWorkflowNodeExecutionProcessDataTruncation: @@ -25,7 +25,7 @@ class TestWorkflowNodeExecutionProcessDataTruncation: workflow_id="test-workflow-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", process_data=process_data, created_at=datetime.now(), @@ -212,7 +212,7 @@ class TestWorkflowNodeExecutionProcessDataScenarios: workflow_id="test-workflow-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", process_data=scenario.original_data, created_at=datetime.now(), diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py index c46b9e51fd..24bd9ccbed 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph.py @@ -2,7 +2,7 @@ from unittest.mock import Mock -from dify_graph.enums import NodeExecutionType, NodeState, NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState from dify_graph.graph.edge import Edge from dify_graph.graph.graph import Graph from dify_graph.nodes.base.node import Node @@ -14,7 +14,7 @@ def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: Nod node.id = node_id node.execution_type = execution_type node.state = state - node.node_type = NodeType.START + node.node_type = BuiltinNodeTypes.START return node diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py index bd4a0f32e2..64c2eee776 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.graph import Graph from dify_graph.nodes.base.node import Node -def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node: +def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: node = MagicMock(spec=Node) node.id = node_id node.node_type = node_type @@ -17,9 +17,9 @@ def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node: def test_graph_builder_creates_linear_graph(): builder = Graph.new() - root = _make_node("root", NodeType.START) - mid = _make_node("mid", NodeType.LLM) - end = _make_node("end", NodeType.END) + root = _make_node("root", BuiltinNodeTypes.START) + mid = _make_node("mid", BuiltinNodeTypes.LLM) + end = _make_node("end", BuiltinNodeTypes.END) graph = builder.add_root(root).add_node(mid).add_node(end).build() diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index b93f18c5bd..75de07bd8b 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -7,7 +7,7 @@ import pytest from core.workflow.node_factory import DifyNodeFactory from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -92,7 +92,7 @@ def test_iteration_root_requires_skip_validation(): ) assert graph.root_node.id == node_id - assert graph.root_node.node_type == NodeType.ITERATION + assert graph.root_node.node_type == BuiltinNodeTypes.ITERATION def test_loop_root_requires_skip_validation(): @@ -115,4 +115,4 @@ def test_loop_root_requires_skip_validation(): ) assert graph.root_node.id == node_id - assert graph.root_node.node_type == NodeType.LOOP + assert graph.root_node.node_type == BuiltinNodeTypes.LOOP diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index 9e9fc2e9ec..e94ad74eb0 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -8,7 +8,7 @@ import pytest from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes.base.node import Node @@ -18,12 +18,12 @@ from tests.workflow_test_utils import build_test_graph_init_params class _TestNodeData(BaseNodeData): - type: NodeType | str | None = None + type: NodeType | None = None execution_type: NodeExecutionType | str | None = None class _TestNode(Node[_TestNodeData]): - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER execution_type = NodeExecutionType.EXECUTABLE @classmethod @@ -46,13 +46,8 @@ class _TestNode(Node[_TestNodeData]): ) node_type_value = self.data.get("type") - if isinstance(node_type_value, NodeType): + if isinstance(node_type_value, str): self.node_type = node_type_value - elif isinstance(node_type_value, str): - try: - self.node_type = NodeType(node_type_value) - except ValueError: - pass def _run(self): raise NotImplementedError @@ -112,14 +107,17 @@ def test_graph_initialization_runs_default_validators( ): node_factory, graph_config = graph_init_dependencies graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, - {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, + {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, ] graph_config["edges"] = [ {"source": "start", "target": "answer", "sourceHandle": "success"}, ] - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert graph.root_node.id == "start" assert "answer" in graph.nodes @@ -130,14 +128,17 @@ def test_graph_validation_fails_for_unknown_edge_targets( ) -> None: node_factory, graph_config = graph_init_dependencies graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, ] graph_config["edges"] = [ {"source": "start", "target": "missing", "sourceHandle": "success"}, ] with pytest.raises(GraphValidationError) as exc: - Graph.init(graph_config=graph_config, node_factory=node_factory) + Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) @@ -147,11 +148,14 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( ) -> None: node_factory, graph_config = graph_init_dependencies graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, { "id": "branch", "data": { - "type": NodeType.IF_ELSE, + "type": BuiltinNodeTypes.IF_ELSE, "title": "Branch", "error_strategy": ErrorStrategy.FAIL_BRANCH, }, @@ -161,30 +165,11 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( {"source": "start", "target": "branch", "sourceHandle": "success"}, ] - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH -def test_graph_validation_blocks_start_and_trigger_coexistence( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, - { - "id": "trigger", - "data": {"type": NodeType.TRIGGER_WEBHOOK, "title": "Webhook", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [] - - with pytest.raises(GraphValidationError) as exc_info: - Graph.init(graph_config=graph_config, node_factory=node_factory) - - assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues) - - def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], ) -> None: @@ -192,9 +177,9 @@ def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( graph_config["nodes"] = [ { "id": "start", - "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, }, - {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, { "id": "note", "type": "custom-note", @@ -211,8 +196,24 @@ def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( {"source": "start", "target": "answer", "sourceHandle": "success"}, ] - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert graph.root_node.id == "start" assert "answer" in graph.nodes assert "note" not in graph.nodes + + +def test_graph_init_fails_for_unknown_root_node_id( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, + ] + graph_config["edges"] = [] + + with pytest.raises(ValueError, match="Root node id missing not found in the graph"): + Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="missing") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index 2b926d754c..6f821ba799 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -3,7 +3,7 @@ from __future__ import annotations from dify_graph.entities.base_node_data import RetryConfig -from dify_graph.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine.domain.graph_execution import GraphExecution from dify_graph.graph_engine.event_management.event_handlers import EventHandler @@ -73,7 +73,7 @@ def test_retry_does_not_emit_additional_start_event() -> None: handler, event_manager, graph_execution = _build_event_handler(node_id) execution_id = "exec-1" - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE start_time = naive_utc_now() start_event = NodeRunStartedEvent( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 3d8de0a00d..9e7b3654b7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -10,7 +10,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes @pytest.fixture @@ -44,7 +44,7 @@ def mock_start_node(): node.id = "test-start-node-id" node.title = "Start Node" node.execution_id = "test-start-execution-id" - node.node_type = NodeType.START + node.node_type = BuiltinNodeTypes.START return node @@ -55,7 +55,7 @@ def mock_llm_node(): node.id = "test-llm-node-id" node.title = "LLM Node" node.execution_id = "test-llm-execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM return node @@ -69,7 +69,7 @@ def mock_tool_node(): node.id = "test-tool-node-id" node.title = "Test Tool Node" node.execution_id = "test-tool-execution-id" - node.node_type = NodeType.TOOL + node.node_type = BuiltinNodeTypes.TOOL tool_data = ToolNodeData( title="Test Tool Node", @@ -108,7 +108,7 @@ def mock_retrieval_node(): node.id = "test-retrieval-node-id" node.title = "Retrieval Node" node.execution_id = "test-retrieval-execution-id" - node.node_type = NodeType.KNOWLEDGE_RETRIEVAL + node.node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL return node @@ -130,7 +130,7 @@ def mock_result_event(): return NodeRunSucceededEvent( id="test-execution-id", node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.now(), node_run_result=node_run_result, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 819fd67f9d..2a36f712fd 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_engine.entities.commands import CommandType from dify_graph.graph_events.node import NodeRunSucceededEvent from dify_graph.model_runtime.entities.llm_entities import LLMUsage @@ -15,7 +15,7 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="execution-id", node_id="llm-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.now(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -30,7 +30,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node = MagicMock() node.id = "llm-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() @@ -51,7 +51,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node = MagicMock() node.id = "question-classifier-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.QUESTION_CLASSIFIER + node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() @@ -72,7 +72,7 @@ def test_non_llm_node_is_ignored() -> None: node = MagicMock() node.id = "start-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.START + node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" node.require_dify_context.return_value.tenant_id = "tenant-id" node._model_instance = object() @@ -89,7 +89,7 @@ def test_quota_error_is_handled_in_layer() -> None: node = MagicMock() node.id = "llm-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() @@ -111,7 +111,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() @@ -140,7 +140,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -166,7 +166,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index b4a7cec494..478a2b592e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -16,7 +16,7 @@ import pytest from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -29,7 +29,7 @@ class TestObservabilityLayerInitialization: layer = ObservabilityLayer() assert not layer._is_disabled assert layer._tracer is not None - assert NodeType.TOOL in layer._parsers + assert BuiltinNodeTypes.TOOL in layer._parsers assert layer._default_parser is not None @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) @@ -39,7 +39,7 @@ class TestObservabilityLayerInitialization: layer = ObservabilityLayer() assert not layer._is_disabled assert layer._tracer is not None - assert NodeType.TOOL in layer._parsers + assert BuiltinNodeTypes.TOOL in layer._parsers assert layer._default_parser is not None @@ -117,7 +117,7 @@ class TestObservabilityLayerParserIntegration: attrs = spans[0].attributes assert attrs["node.id"] == mock_start_node.id assert attrs["node.execution_id"] == mock_start_node.execution_id - assert attrs["node.type"] == mock_start_node.node_type.value + assert attrs["node.type"] == mock_start_node.node_type @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index 50d14ff48f..548c10ce8d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -6,7 +6,7 @@ import queue from unittest import mock from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_engine.event_management.event_handlers import EventHandler from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator @@ -26,7 +26,7 @@ def test_dispatcher_should_consume_remains_events_after_pause(): GraphNodeEventBase( id="test", node_id="test", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, ) ) event_handler = mock.Mock(spec=EventHandler) @@ -107,7 +107,7 @@ def _make_started_event() -> NodeRunStartedEvent: return NodeRunStartedEvent( id="start-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="Test Node", start_at=naive_utc_now(), ) @@ -117,7 +117,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="success-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="Test Node", start_at=naive_utc_now(), node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), @@ -151,20 +151,20 @@ def test_dispatcher_drain_event_queue(): NodeRunStartedEvent( id="start-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="Code", start_at=naive_utc_now(), ), NodeRunPauseRequestedEvent( id="pause-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, reason=SchedulingPause(message="test pause"), ), NodeRunSucceededEvent( id="success-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, start_at=naive_utc_now(), node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), ), diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py index f886ae1c2b..fc0d22f739 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py @@ -7,7 +7,7 @@ for workflows containing nodes that require third-party services. import pytest -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig @@ -227,23 +227,23 @@ def test_mock_factory_node_type_detection(): ) # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(NodeType.LLM) - assert factory.should_mock_node(NodeType.AGENT) - assert factory.should_mock_node(NodeType.TOOL) - assert factory.should_mock_node(NodeType.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(NodeType.HTTP_REQUEST) - assert factory.should_mock_node(NodeType.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(NodeType.DOCUMENT_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.LLM) + assert factory.should_mock_node(BuiltinNodeTypes.AGENT) + assert factory.should_mock_node(BuiltinNodeTypes.TOOL) + assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) + assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) + assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(NodeType.CODE) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Test that non-service nodes are not mocked - assert not factory.should_mock_node(NodeType.START) - assert not factory.should_mock_node(NodeType.END) - assert not factory.should_mock_node(NodeType.IF_ELSE) - assert not factory.should_mock_node(NodeType.VARIABLE_AGGREGATOR) + assert not factory.should_mock_node(BuiltinNodeTypes.START) + assert not factory.should_mock_node(BuiltinNodeTypes.END) + assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) + assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) def test_custom_mock_handler(): @@ -341,15 +341,15 @@ def test_register_custom_mock_node(): ) # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Unregister mock - factory.unregister_mock_node_type(NodeType.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) + assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Re-register custom mock - factory.register_mock_node_type(NodeType.TEMPLATE_TRANSFORM, MockTemplateTransformNode) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, MockTemplateTransformNode) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) def test_default_config_by_node_type(): @@ -358,7 +358,7 @@ def test_default_config_by_node_type(): # Set default config for all LLM nodes mock_config.set_default_config( - NodeType.LLM, + BuiltinNodeTypes.LLM, { "default_response": "Default LLM response for all nodes", "temperature": 0.7, @@ -367,23 +367,23 @@ def test_default_config_by_node_type(): # Set default config for all HTTP nodes mock_config.set_default_config( - NodeType.HTTP_REQUEST, + BuiltinNodeTypes.HTTP_REQUEST, { "default_status": 200, "default_timeout": 30, }, ) - llm_config = mock_config.get_default_config(NodeType.LLM) + llm_config = mock_config.get_default_config(BuiltinNodeTypes.LLM) assert llm_config["default_response"] == "Default LLM response for all nodes" assert llm_config["temperature"] == 0.7 - http_config = mock_config.get_default_config(NodeType.HTTP_REQUEST) + http_config = mock_config.get_default_config(BuiltinNodeTypes.HTTP_REQUEST) assert http_config["default_status"] == 200 assert http_config["default_timeout"] == 30 # Non-configured node type should return empty dict - tool_config = mock_config.get_default_config(NodeType.TOOL) + tool_config = mock_config.get_default_config(BuiltinNodeTypes.TOOL) assert tool_config == {} diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index cde99196c8..76bf179f33 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -6,7 +6,7 @@ This test validates that: - When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) """ -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel from dify_graph.graph_events import ( @@ -74,7 +74,11 @@ def test_streaming_output_with_blocking_equals_one(): # Find indices of first LLM success event and first stream chunk event llm2_start_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM), + ( + i + for i, e in enumerate(events) + if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM + ), -1, ) first_chunk_index = next( @@ -96,16 +100,16 @@ def test_streaming_output_with_blocking_equals_one(): # Check all Template's NodeRunStreamChunkEvent should has same id with Template's NodeRunStartedEvent start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.TEMPLATE_TRANSFORM + e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM ] - template_chunk_events = [e for e in stream_chunk_events if e.node_type == NodeType.TEMPLATE_TRANSFORM] + template_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM] assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}" assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), ( "Expected all Template chunk events to have same id with Template's NodeRunStartedEvent" ) # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.END] + end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" @@ -168,7 +172,11 @@ def test_streaming_output_with_blocking_not_equals_one(): # Find indices of first LLM success event and first stream chunk event llm2_start_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM), + ( + i + for i, e in enumerate(events) + if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM + ), -1, ) first_chunk_index = next( @@ -194,15 +202,15 @@ def test_streaming_output_with_blocking_not_equals_one(): assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" # Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.LLM] - llm_chunk_events = [e for e in stream_chunk_events if e.node_type == NodeType.LLM] + start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.LLM] + llm_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.LLM] llm_node_ids = {se.node_id for se in start_events} assert all(e.node_id in llm_node_ids for e in llm_chunk_events), ( "Expected all LLM chunk events to be from LLM nodes" ) # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.END] + end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py index b88c15ea2a..778dad5952 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py @@ -1,7 +1,7 @@ import queue from datetime import datetime -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher from dify_graph.graph_events import NodeRunSucceededEvent from dify_graph.node_events import NodeRunResult @@ -51,7 +51,7 @@ def test_dispatcher_drains_events_when_paused() -> None: event = NodeRunSucceededEvent( id="exec-1", node_id="node-1", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, start_at=datetime.utcnow(), node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py index 805e7dbbce..255784b77d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -6,7 +6,7 @@ import json from collections import deque from unittest.mock import MagicMock -from dify_graph.enums import NodeExecutionType, NodeState, NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState from dify_graph.graph_engine.domain import GraphExecution from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator from dify_graph.graph_engine.response_coordinator.path import Path @@ -101,7 +101,9 @@ def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> No class DummyNode: def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: self.id = node_id - self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM + self.node_type = ( + BuiltinNodeTypes.ANSWER if execution_type == NodeExecutionType.RESPONSE else BuiltinNodeTypes.LLM + ) self.execution_type = execution_type self.state = NodeState.UNKNOWN self.title = node_id @@ -160,7 +162,7 @@ def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> No event = NodeRunStreamChunkEvent( id="exec-1", node_id="response-1", - node_type=NodeType.ANSWER, + node_type=BuiltinNodeTypes.ANSWER, selector=["node-source", "text"], chunk="chunk-1", is_final=False, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index 6041c6ff30..8a4649693d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -11,8 +11,6 @@ from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from dify_graph.enums import NodeType - @dataclass class NodeMockConfig: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 338db9076e..93010eea54 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.node import Node from .test_mock_nodes import ( @@ -61,18 +61,18 @@ class MockNodeFactory(DifyNodeFactory): # Map of node types that should be mocked self._mock_node_types = { - NodeType.LLM: MockLLMNode, - NodeType.AGENT: MockAgentNode, - NodeType.TOOL: MockToolNode, - NodeType.KNOWLEDGE_RETRIEVAL: MockKnowledgeRetrievalNode, - NodeType.HTTP_REQUEST: MockHttpRequestNode, - NodeType.QUESTION_CLASSIFIER: MockQuestionClassifierNode, - NodeType.PARAMETER_EXTRACTOR: MockParameterExtractorNode, - NodeType.DOCUMENT_EXTRACTOR: MockDocumentExtractorNode, - NodeType.ITERATION: MockIterationNode, - NodeType.LOOP: MockLoopNode, - NodeType.TEMPLATE_TRANSFORM: MockTemplateTransformNode, - NodeType.CODE: MockCodeNode, + BuiltinNodeTypes.LLM: MockLLMNode, + BuiltinNodeTypes.AGENT: MockAgentNode, + BuiltinNodeTypes.TOOL: MockToolNode, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: MockKnowledgeRetrievalNode, + BuiltinNodeTypes.HTTP_REQUEST: MockHttpRequestNode, + BuiltinNodeTypes.QUESTION_CLASSIFIER: MockQuestionClassifierNode, + BuiltinNodeTypes.PARAMETER_EXTRACTOR: MockParameterExtractorNode, + BuiltinNodeTypes.DOCUMENT_EXTRACTOR: MockDocumentExtractorNode, + BuiltinNodeTypes.ITERATION: MockIterationNode, + BuiltinNodeTypes.LOOP: MockLoopNode, + BuiltinNodeTypes.TEMPLATE_TRANSFORM: MockTemplateTransformNode, + BuiltinNodeTypes.CODE: MockCodeNode, } def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: @@ -92,7 +92,7 @@ class MockNodeFactory(DifyNodeFactory): # Create mock node instance mock_class = self._mock_node_types[node_type] - if node_type == NodeType.CODE: + if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( id=node_id, config=typed_node_config, @@ -102,7 +102,7 @@ class MockNodeFactory(DifyNodeFactory): code_executor=self._code_executor, code_limits=self._code_limits, ) - elif node_type == NodeType.HTTP_REQUEST: + elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( id=node_id, config=typed_node_config, @@ -114,7 +114,11 @@ class MockNodeFactory(DifyNodeFactory): tool_file_manager_factory=self._http_request_tool_file_manager_factory, file_manager=self._http_request_file_manager, ) - elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}: + elif node_type in { + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + }: mock_instance = mock_class( id=node_id, config=typed_node_config, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 8c8e5977c8..3e4247f33f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -3,7 +3,7 @@ Simple test to verify MockNodeFactory works with iteration nodes. """ from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -42,20 +42,20 @@ def test_mock_factory_registers_iteration_node(): ) # Check that iteration node is registered - assert NodeType.ITERATION in factory._mock_node_types + assert BuiltinNodeTypes.ITERATION in factory._mock_node_types print("✓ Iteration node is registered in MockNodeFactory") # Check that loop node is registered - assert NodeType.LOOP in factory._mock_node_types + assert BuiltinNodeTypes.LOOP in factory._mock_node_types print("✓ Loop node is registered in MockNodeFactory") # Check the class types from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode - assert factory._mock_node_types[NodeType.ITERATION] == MockIterationNode + assert factory._mock_node_types[BuiltinNodeTypes.ITERATION] == MockIterationNode print("✓ Iteration node maps to MockIterationNode class") - assert factory._mock_node_types[NodeType.LOOP] == MockLoopNode + assert factory._mock_node_types[BuiltinNodeTypes.LOOP] == MockLoopNode print("✓ Loop node maps to MockLoopNode class") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 9e3574266c..e117f81ff9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -12,13 +12,13 @@ from unittest.mock import MagicMock from core.model_manager import ModelInstance from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.code import CodeNode from dify_graph.nodes.document_extractor import DocumentExtractorNode from dify_graph.nodes.http_request import HttpRequestNode -from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode from dify_graph.nodes.llm import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor import ParameterExtractorNode diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 1550dca402..a8398e8f79 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -7,7 +7,7 @@ to ensure they work correctly with the TableTestRunner. from configs import dify_config from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -550,12 +550,12 @@ class TestMockNodeFactory: ) # Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy) - assert factory.should_mock_node(NodeType.CODE) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Verify that other third-party service nodes ARE also mocked by default - assert factory.should_mock_node(NodeType.LLM) - assert factory.should_mock_node(NodeType.AGENT) + assert factory.should_mock_node(BuiltinNodeTypes.LLM) + assert factory.should_mock_node(BuiltinNodeTypes.AGENT) def test_factory_creates_mock_template_transform_node(self): """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" @@ -610,7 +610,7 @@ class TestMockNodeFactory: # Verify the correct mock type was created assert isinstance(node, MockTemplateTransformNode) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) def test_factory_creates_mock_code_node(self): """Test that MockNodeFactory creates MockCodeNode for code type.""" @@ -667,4 +667,4 @@ class TestMockNodeFactory: # Verify the correct mock type was created assert isinstance(node, MockCodeNode) - assert factory.should_mock_node(NodeType.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 693cdf9276..5b35b3310a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -5,7 +5,7 @@ Simple test to validate the auto-mock system without external dependencies. import sys from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -64,8 +64,8 @@ def test_mock_config_operations(): assert error_config.error == "Test error" # Test default configs by node type - config.set_default_config(NodeType.LLM, {"temperature": 0.7}) - llm_config = config.get_default_config(NodeType.LLM) + config.set_default_config(BuiltinNodeTypes.LLM, {"temperature": 0.7}) + llm_config = config.get_default_config(BuiltinNodeTypes.LLM) assert llm_config == {"temperature": 0.7} print("✓ MockConfig operations test passed") @@ -130,23 +130,23 @@ def test_mock_factory_detection(): ) # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(NodeType.LLM) - assert factory.should_mock_node(NodeType.AGENT) - assert factory.should_mock_node(NodeType.TOOL) - assert factory.should_mock_node(NodeType.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(NodeType.HTTP_REQUEST) - assert factory.should_mock_node(NodeType.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(NodeType.DOCUMENT_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.LLM) + assert factory.should_mock_node(BuiltinNodeTypes.AGENT) + assert factory.should_mock_node(BuiltinNodeTypes.TOOL) + assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) + assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) + assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(NodeType.CODE) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Test that non-service nodes are not mocked - assert not factory.should_mock_node(NodeType.START) - assert not factory.should_mock_node(NodeType.END) - assert not factory.should_mock_node(NodeType.IF_ELSE) - assert not factory.should_mock_node(NodeType.VARIABLE_AGGREGATOR) + assert not factory.should_mock_node(BuiltinNodeTypes.START) + assert not factory.should_mock_node(BuiltinNodeTypes.END) + assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) + assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) print("✓ MockNodeFactory detection test passed") @@ -186,18 +186,18 @@ def test_mock_factory_registration(): ) # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Unregister mock - factory.unregister_mock_node_type(NodeType.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) + assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Register custom mock (using a dummy class for testing) class DummyMockNode: pass - factory.register_mock_node_type(NodeType.TEMPLATE_TRANSFORM, DummyMockNode) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, DummyMockNode) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) print("✓ MockNodeFactory registration test passed") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 0ac9d6618d..b954a4faac 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -14,8 +14,8 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel @@ -118,7 +118,11 @@ def test_parallel_streaming_workflow(): with patch.object( DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True ): - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=get_default_root_node_id(graph_config), + ) # Create the graph engine engine = GraphEngine( @@ -164,7 +168,9 @@ def test_parallel_streaming_workflow(): stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] # Get Answer node start event - answer_start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.ANSWER] + answer_start_events = [ + e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.ANSWER + ] assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}" answer_start_event = answer_start_events[0] @@ -211,7 +217,9 @@ def test_parallel_streaming_workflow(): # Get LLM completion events llm_completed_events = [ - (i, e) for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM + (i, e) + for i, e in enumerate(events) + if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM ] # Check LLM completion order - in the current implementation, LLMs run sequentially @@ -263,7 +271,7 @@ def test_parallel_streaming_workflow(): # According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}' # This means LLM 2 output should come first, then LLM 1 output answer_complete_events = [ - e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.ANSWER + e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.ANSWER ] assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py new file mode 100644 index 0000000000..198e133454 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py @@ -0,0 +1,71 @@ +"""Unit tests for response session creation.""" + +from __future__ import annotations + +import pytest + +import dify_graph.graph_engine.response_coordinator.session as response_session_module +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType +from dify_graph.graph_engine.response_coordinator import RESPONSE_SESSION_NODE_TYPES +from dify_graph.graph_engine.response_coordinator.session import ResponseSession +from dify_graph.nodes.base.template import Template, TextSegment + + +class DummyResponseNode: + """Minimal response-capable node for session tests.""" + + def __init__(self, *, node_id: str, node_type: NodeType, template: Template) -> None: + self.id = node_id + self.node_type = node_type + self.execution_type = NodeExecutionType.RESPONSE + self.state = NodeState.UNKNOWN + self._template = template + + def get_streaming_template(self) -> Template: + return self._template + + +class DummyNodeWithoutStreamingTemplate: + """Minimal node that violates the response-session contract.""" + + def __init__(self, *, node_id: str, node_type: NodeType) -> None: + self.id = node_id + self.node_type = node_type + self.execution_type = NodeExecutionType.RESPONSE + self.state = NodeState.UNKNOWN + + +def test_response_session_from_node_rejects_node_types_outside_allowlist() -> None: + """Unsupported node types are rejected even if they expose a template.""" + node = DummyResponseNode( + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, + template=Template(segments=[TextSegment(text="hello")]), + ) + + with pytest.raises(TypeError, match="RESPONSE_SESSION_NODE_TYPES"): + ResponseSession.from_node(node) + + +def test_response_session_from_node_supports_downstream_allowlist_extension(monkeypatch) -> None: + """Downstream applications can extend the supported node-type list.""" + node = DummyResponseNode( + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, + template=Template(segments=[TextSegment(text="hello")]), + ) + extended_node_types = [*RESPONSE_SESSION_NODE_TYPES, BuiltinNodeTypes.LLM] + monkeypatch.setattr(response_session_module, "RESPONSE_SESSION_NODE_TYPES", extended_node_types) + + session = ResponseSession.from_node(node) + + assert session.node_id == "llm-node" + assert session.template.segments == [TextSegment(text="hello")] + + +def test_response_session_from_node_requires_streaming_template_method() -> None: + """Allowed node types still need to implement the streaming-template contract.""" + node = DummyNodeWithoutStreamingTemplate(node_id="answer-node", node_type=BuiltinNodeTypes.ANSWER) + + with pytest.raises(TypeError, match="get_streaming_template"): + ResponseSession.from_node(node) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 767a8f60ce..ab8fb346b8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -21,7 +21,7 @@ from typing import Any, cast from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig @@ -257,7 +257,11 @@ class WorkflowRunner: else: node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=get_default_root_node_id(graph_config), + ) return graph, graph_runtime_state diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index f0d80af1ed..fd563d1be2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -64,7 +64,7 @@ def test_execute_answer(): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "answer", diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 3fb775f934..81d3f5be9c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,14 +1,12 @@ import pytest +from core.workflow.node_factory import get_node_type_classes_mapping from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.node import Node -# Ensures that all node classes are imported. -from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING - -# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed. -_ = NODE_TYPE_CLASSES_MAPPING +# Ensures that all production node classes are imported and registered. +_ = get_node_type_classes_mapping() class _TestNodeData(BaseNodeData): @@ -43,7 +41,7 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined node_type = cls.node_type node_version = cls.version() - assert isinstance(cls.node_type, NodeType) + assert isinstance(cls.node_type, str) assert isinstance(node_version, str) node_type_and_version = (node_type, node_version) assert node_type_and_version not in type_version_set, ( @@ -56,7 +54,7 @@ def test_extract_node_data_type_from_generic_extracts_type(): """When a class inherits from Node[T], it should extract T.""" class _ConcreteNode(Node[_TestNodeData]): - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE @staticmethod def version() -> str: @@ -108,7 +106,7 @@ def test_init_subclass_rejects_explicit_node_data_type_without_generic(): class _ExplicitNode(Node): _node_data_type = _TestNodeData - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE @staticmethod def version() -> str: @@ -119,7 +117,7 @@ def test_init_subclass_sets_node_data_type_from_generic(): """Verify that __init_subclass__ sets _node_data_type from the generic parameter.""" class _AutoNode(Node[_TestNodeData]): - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE @staticmethod def version() -> str: @@ -132,13 +130,13 @@ def test_validate_node_data_uses_declared_node_data_type(): """Public validation should hydrate the subclass-declared node data model.""" class _AutoNode(Node[_TestNodeData]): - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE @staticmethod def version() -> str: return "1" - base_node_data = BaseNodeData.model_validate({"type": NodeType.CODE, "title": "Test"}) + base_node_data = BaseNodeData.model_validate({"type": BuiltinNodeTypes.CODE, "title": "Test"}) validated = _AutoNode.validate_node_data(base_node_data) diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index 86d326aead..972a945ca0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,8 +1,9 @@ import types from collections.abc import Mapping +from core.workflow.node_factory import get_node_type_classes_mapping from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) @@ -16,11 +17,11 @@ from dify_graph.nodes.variable_assigner.v2.node import ( def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act - mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping() # Assert basic presence - assert NodeType.VARIABLE_ASSIGNER in mapping - va_versions = mapping[NodeType.VARIABLE_ASSIGNER] + assert BuiltinNodeTypes.VARIABLE_ASSIGNER in mapping + va_versions = mapping[BuiltinNodeTypes.VARIABLE_ASSIGNER] # Both concrete versions must be present assert va_versions.get("1") is VariableAssignerV1 @@ -34,7 +35,7 @@ def test_latest_prefers_highest_numeric_version(): # Arrange: define two ephemeral subclasses with numeric versions under a NodeType # that has no concrete implementations in production to avoid interference. class _Version1(Node[BaseNodeData]): # type: ignore[misc] - node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR + node_type = BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR def init_node_data(self, data): pass @@ -73,11 +74,11 @@ def test_latest_prefers_highest_numeric_version(): return "version2" # Act: build a fresh mapping (it should now see our ephemeral subclasses) - mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping() # Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version - assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping - legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR] + assert BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR in mapping + legacy_versions = mapping[BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR] assert legacy_versions.get("1") is _Version1 assert legacy_versions.get("2") is _Version2 diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index db096b1aed..859115ceb3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,7 @@ +from core.workflow.nodes.datasource.datasource_node import DatasourceNode from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.datasource.datasource_node import DatasourceNode class _VarSeg: @@ -74,6 +74,8 @@ def test_datasource_node_delegates_to_manager_stream(mocker): def get_upload_file_by_id(cls, **_): raise AssertionError("not called") + mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) + node = DatasourceNode( id="n", config={ @@ -90,7 +92,6 @@ def test_datasource_node_delegates_to_manager_stream(mocker): }, graph_init_params=gp, graph_runtime_state=gs, - datasource_manager=_Mgr, ) evts = list(node._run()) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 1fea19e795..b0ed47158d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, @@ -47,7 +47,7 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# config = { "id": "node-1", - "type": NodeType.HUMAN_INPUT.value, + "type": BuiltinNodeTypes.HUMAN_INPUT, "data": { "title": "Human Input", "form_content": form_content, @@ -111,7 +111,7 @@ def _build_timeout_node() -> HumanInputNode: config = { "id": "node-1", - "type": NodeType.HUMAN_INPUT.value, + "type": BuiltinNodeTypes.HUMAN_INPUT, "data": { "title": "Human Input", "form_content": "Please enter your name:\n\n{{#$output.name#}}", diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index 490df52533..fdf5f4d1f8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -1,5 +1,5 @@ from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from dify_graph.nodes.iteration.exc import ( InvalidIteratorValueError, @@ -91,7 +91,7 @@ class TestIterationNodeClassAttributes: def test_node_type(self): """Test IterationNode node_type attribute.""" - assert IterationNode.node_type == NodeType.ITERATION + assert IterationNode.node_type == BuiltinNodeTypes.ITERATION def test_version(self): """Test IterationNode version method.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 8116fc8b3c..33f7ace5ab 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -5,12 +5,16 @@ from unittest.mock import Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData +from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError +from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode +from core.workflow.nodes.knowledge_index.protocols import ( + IndexProcessorProtocol, + Preview, + PreviewItem, + SummaryIndexServiceProtocol, +) from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.nodes.knowledge_index.entities import KnowledgeIndexNodeData -from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError -from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode -from dify_graph.repositories.index_processor_protocol import IndexProcessorProtocol, Preview, PreviewItem -from dify_graph.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import StringSegment @@ -45,16 +49,24 @@ def mock_graph_runtime_state(): @pytest.fixture -def mock_index_processor(): +def mock_index_processor(mocker): """Create mock IndexProcessorProtocol.""" mock_processor = Mock(spec=IndexProcessorProtocol) + mocker.patch( + "core.workflow.nodes.knowledge_index.knowledge_index_node.IndexProcessor", + return_value=mock_processor, + ) return mock_processor @pytest.fixture -def mock_summary_index_service(): +def mock_summary_index_service(mocker): """Create mock SummaryIndexServiceProtocol.""" mock_service = Mock(spec=SummaryIndexServiceProtocol) + mocker.patch( + "core.workflow.nodes.knowledge_index.knowledge_index_node.SummaryIndex", + return_value=mock_service, + ) return mock_service @@ -107,8 +119,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Assert @@ -137,8 +147,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act & Assert @@ -172,8 +180,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act & Assert @@ -210,8 +216,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -269,8 +273,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -334,8 +336,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -387,8 +387,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -446,8 +444,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -506,8 +502,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -546,8 +540,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -595,8 +587,6 @@ class TestKnowledgeIndexNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act @@ -637,8 +627,6 @@ class TestInvokeKnowledgeIndex: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - index_processor=mock_index_processor, - summary_index_service=mock_summary_index_service, ) # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index b7a7a9c938..99997db6b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -5,9 +5,7 @@ from unittest.mock import Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.nodes.knowledge_retrieval.entities import ( +from core.workflow.nodes.knowledge_retrieval.entities import ( Condition, KnowledgeRetrievalNodeData, MetadataFilteringCondition, @@ -15,9 +13,11 @@ from dify_graph.nodes.knowledge_retrieval.entities import ( RerankingModelConfig, SingleRetrievalConfig, ) -from dify_graph.nodes.knowledge_retrieval.exc import RateLimitExceededError -from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from dify_graph.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source +from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import StringSegment @@ -52,11 +52,15 @@ def mock_graph_runtime_state(): @pytest.fixture -def mock_rag_retrieval(): +def mock_rag_retrieval(mocker): """Create mock RAGRetrievalProtocol.""" mock_retrieval = Mock(spec=RAGRetrievalProtocol) mock_retrieval.knowledge_retrieval.return_value = [] mock_retrieval.llm_usage = LLMUsage.empty_usage() + mocker.patch( + "core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node.DatasetRetrieval", + return_value=mock_retrieval, + ) return mock_retrieval @@ -106,7 +110,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Assert @@ -136,7 +139,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -196,7 +198,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -241,7 +242,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -278,7 +278,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -314,7 +313,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -356,7 +354,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -396,7 +393,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -478,7 +474,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -516,7 +511,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -572,7 +566,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -621,7 +614,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) conditions = MetadataFilteringCondition( @@ -683,7 +675,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) mock_rag_retrieval.knowledge_retrieval.return_value = [] diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 25760ba352..d71e0921c1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -4,7 +4,7 @@ import pytest from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.nodes.list_operator.node import ListOperatorNode from dify_graph.runtime import GraphRuntimeState from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment @@ -71,7 +71,7 @@ class TestListOperatorNode: graph_runtime_state=mock_graph_runtime_state, ) - assert node.node_type == NodeType.LIST_OPERATOR + assert node.node_type == BuiltinNodeTypes.LIST_OPERATOR assert node._node_data.title == "List Operator" def test_version(self): diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 6831626f58..332a8761f9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode @@ -65,7 +65,7 @@ class TestTemplateTransformNode: template_renderer=mock_renderer, ) - assert node.node_type == NodeType.TEMPLATE_TRANSFORM + assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM assert node._node_data.title == "Template Transform" assert len(node._node_data.variables) == 2 assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 0d81e7762b..2b0205fb7b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -6,7 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -18,7 +18,7 @@ class _SampleNodeData(BaseNodeData): class _SampleNode(Node[_SampleNodeData]): - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER @classmethod def version(cls) -> str: @@ -46,7 +46,7 @@ def _build_node_config() -> NodeConfigDict: { "id": "node-1", "data": { - "type": NodeType.ANSWER.value, + "type": BuiltinNodeTypes.ANSWER, "title": "Sample", "foo": "bar", }, @@ -105,7 +105,7 @@ def test_missing_generic_argument_raises_type_error(): with pytest.raises(TypeError): class _InvalidNode(Node): # type: ignore[type-abstract] - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER @classmethod def version(cls) -> str: @@ -118,7 +118,7 @@ def test_missing_generic_argument_raises_type_error(): def test_base_node_data_keeps_dict_style_access_compatibility(): node_data = _SampleNodeData.model_validate( { - "type": NodeType.ANSWER.value, + "type": BuiltinNodeTypes.ANSWER, "title": "Sample", "foo": "bar", } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 13275d4be6..40754974c1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -7,7 +7,7 @@ from docx.oxml.text.paragraph import CT_P from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod from dify_graph.node_events import NodeRunResult from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData @@ -250,7 +250,7 @@ def test_extract_text_from_docx(mock_document): def test_node_type(document_extractor_node): - assert document_extractor_node.node_type == NodeType.DOCUMENT_EXTRACTOR + assert document_extractor_node.node_type == BuiltinNodeTypes.DOCUMENT_EXTRACTOR @patch("pandas.ExcelFile") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 041bd66d03..c746a945fe 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -60,7 +60,7 @@ def test_execute_if_else_result_true(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "if-else", @@ -154,7 +154,7 @@ def test_execute_if_else_result_false(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "if-else", @@ -328,7 +328,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_data = { "title": "Boolean Test", @@ -382,7 +382,7 @@ def test_execute_if_else_boolean_false_conditions(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_data = { "title": "Boolean False Test", @@ -450,7 +450,7 @@ def test_execute_if_else_boolean_cases_structure(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_data = { "title": "Boolean Cases Test", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 2cd3a38fa6..e69c05dc0b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -88,7 +88,7 @@ def test_overwrite_string_variable(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -188,7 +188,7 @@ def test_append_variable_to_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -277,7 +277,7 @@ def test_clear_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 5b285c2681..6874f3fef1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -118,7 +118,7 @@ def test_remove_first_from_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -206,7 +206,7 @@ def test_remove_last_from_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -291,7 +291,7 @@ def test_remove_first_from_empty_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -376,7 +376,7 @@ def test_remove_last_from_empty_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 61b18566b0..6be5bb23e8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from dify_graph.nodes.trigger_webhook.entities import ( +from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, WebhookBodyParameter, diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index a821e361c5..ddf1af5a59 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,12 +1,12 @@ import pytest -from dify_graph.entities.exc import BaseNodeError -from dify_graph.nodes.trigger_webhook.exc import ( +from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, WebhookNodeError, WebhookNotFoundError, WebhookTimeoutError, ) +from dify_graph.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): @@ -149,7 +149,7 @@ def test_webhook_error_attributes(): assert WebhookConfigError.__name__ == "WebhookConfigError" # Test that all error classes have proper __module__ - expected_module = "dify_graph.nodes.trigger_webhook.exc" + expected_module = "core.workflow.nodes.trigger_webhook.exc" assert WebhookNodeError.__module__ == expected_module assert WebhookTimeoutError.__module__ == expected_module assert WebhookNotFoundError.__module__ == expected_module diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index c750e74182..78dd7ce0f3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -9,15 +9,15 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.nodes.trigger_webhook.entities import ( +from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, WebhookBodyParameter, WebhookData, ) -from dify_graph.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool from dify_graph.system_variable import SystemVariable @@ -130,8 +130,8 @@ def test_webhook_node_file_conversion_to_file_variable(): # Mock the file factory and variable factory with ( patch("factories.file_factory.build_from_mapping") as mock_file_factory, - patch("dify_graph.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, - patch("dify_graph.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): # Setup mocks mock_file_obj = Mock() @@ -322,8 +322,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): with ( patch("factories.file_factory.build_from_mapping") as mock_file_factory, - patch("dify_graph.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, - patch("dify_graph.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): # Setup mocks for file mock_file_obj = Mock() @@ -390,8 +390,8 @@ def test_webhook_node_different_file_types(): with ( patch("factories.file_factory.build_from_mapping") as mock_file_factory, - patch("dify_graph.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, - patch("dify_graph.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): # Setup mocks for all files mock_file_objs = [Mock() for _ in range(3)] diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index df13bbb92f..139f65d6c3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -3,17 +3,18 @@ from unittest.mock import patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.trigger_webhook.entities import ( +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, WebhookBodyParameter, WebhookData, WebhookParameter, ) -from dify_graph.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool from dify_graph.system_variable import SystemVariable @@ -82,7 +83,7 @@ def test_webhook_node_basic_initialization(): node = create_webhook_node(data, variable_pool) - assert node.node_type.value == "trigger-webhook" + assert node.node_type == TRIGGER_WEBHOOK_NODE_TYPE assert node.version() == "1" assert node._get_title() == "Test Webhook" assert node._node_data.method == Method.POST diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 934e29546c..ab46126ca6 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -5,9 +5,10 @@ import pytest from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.workflow import node_factory +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey from dify_graph.nodes.code.entities import CodeLanguage from dify_graph.variables.segments import StringSegment @@ -145,7 +146,6 @@ class TestDifyNodeFactoryInit: graph_runtime_state = sentinel.graph_runtime_state dify_context = SimpleNamespace(tenant_id="tenant-id") template_renderer = sentinel.template_renderer - rag_retrieval = sentinel.rag_retrieval unstructured_api_config = sentinel.unstructured_api_config http_request_config = sentinel.http_request_config credentials_provider = sentinel.credentials_provider @@ -162,7 +162,6 @@ class TestDifyNodeFactoryInit: "CodeExecutorJinja2TemplateRenderer", return_value=template_renderer, ) as renderer_factory, - patch.object(node_factory, "DatasetRetrieval", return_value=rag_retrieval), patch.object( node_factory, "UnstructuredApiConfig", @@ -192,7 +191,6 @@ class TestDifyNodeFactoryInit: assert factory.graph_runtime_state is graph_runtime_state assert factory._dify_context is dify_context assert factory._template_renderer is template_renderer - assert factory._rag_retrieval is rag_retrieval assert factory._document_extractor_unstructured_api_config is unstructured_api_config assert factory._http_request_config is http_request_config assert factory._llm_credentials_provider is credentials_provider @@ -248,7 +246,6 @@ class TestDifyNodeFactoryCreateNode: factory._http_request_http_client = sentinel.http_client factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory factory._http_request_file_manager = sentinel.file_manager - factory._rag_retrieval = sentinel.rag_retrieval factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config factory._http_request_config = sentinel.http_request_config factory._llm_credentials_provider = sentinel.credentials_provider @@ -256,46 +253,46 @@ class TestDifyNodeFactoryCreateNode: return factory def test_rejects_unknown_node_type(self, factory): - with pytest.raises(ValueError, match="Input should be"): + with pytest.raises(ValueError, match="No class mapping found for node type: missing"): factory.create_node({"id": "node-id", "data": {"type": "missing"}}) def test_rejects_missing_class_mapping(self, monkeypatch, factory): monkeypatch.setattr( - node_factory, - "resolve_workflow_node_class", + factory, + "_resolve_node_class", MagicMock(side_effect=ValueError("No class mapping found for node type: start")), ) with pytest.raises(ValueError, match="No class mapping found for node type: start"): - factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}}) + factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START}}) def test_rejects_missing_latest_class(self, monkeypatch, factory): monkeypatch.setattr( - node_factory, - "resolve_workflow_node_class", + factory, + "_resolve_node_class", MagicMock(side_effect=ValueError("No latest version class found for node type: start")), ) with pytest.raises(ValueError, match="No latest version class found for node type: start"): - factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}}) + factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START}}) def test_uses_version_specific_class_when_available(self, monkeypatch, factory): matched_node = sentinel.matched_node latest_node_class = MagicMock(return_value=sentinel.latest_node) matched_node_class = MagicMock(return_value=matched_node) monkeypatch.setattr( - node_factory, - "resolve_workflow_node_class", + factory, + "_resolve_node_class", MagicMock(return_value=matched_node_class), ) - result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) + result = factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START, "version": "9"}}) assert result is matched_node matched_node_class.assert_called_once() kwargs = matched_node_class.call_args.kwargs assert kwargs["id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9") + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state latest_node_class.assert_not_called() @@ -304,40 +301,40 @@ class TestDifyNodeFactoryCreateNode: latest_node = sentinel.latest_node latest_node_class = MagicMock(return_value=latest_node) monkeypatch.setattr( - node_factory, - "resolve_workflow_node_class", + factory, + "_resolve_node_class", MagicMock(return_value=latest_node_class), ) - result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) + result = factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START, "version": "9"}}) assert result is latest_node latest_node_class.assert_called_once() kwargs = latest_node_class.call_args.kwargs assert kwargs["id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9") + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state @pytest.mark.parametrize( ("node_type", "constructor_name"), [ - (NodeType.CODE, "CodeNode"), - (NodeType.TEMPLATE_TRANSFORM, "TemplateTransformNode"), - (NodeType.HTTP_REQUEST, "HttpRequestNode"), - (NodeType.HUMAN_INPUT, "HumanInputNode"), - (NodeType.KNOWLEDGE_INDEX, "KnowledgeIndexNode"), - (NodeType.DATASOURCE, "DatasourceNode"), - (NodeType.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"), - (NodeType.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"), + (BuiltinNodeTypes.CODE, "CodeNode"), + (BuiltinNodeTypes.TEMPLATE_TRANSFORM, "TemplateTransformNode"), + (BuiltinNodeTypes.HTTP_REQUEST, "HttpRequestNode"), + (BuiltinNodeTypes.HUMAN_INPUT, "HumanInputNode"), + (KNOWLEDGE_INDEX_NODE_TYPE, "KnowledgeIndexNode"), + (BuiltinNodeTypes.DATASOURCE, "DatasourceNode"), + (BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"), + (BuiltinNodeTypes.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"), ], ) def test_creates_specialized_nodes(self, monkeypatch, factory, node_type, constructor_name): created_node = object() constructor = MagicMock(name=constructor_name, return_value=created_node) monkeypatch.setattr( - node_factory, - "resolve_workflow_node_class", + factory, + "_resolve_node_class", MagicMock(return_value=constructor), ) @@ -349,13 +346,8 @@ class TestDifyNodeFactoryCreateNode: "HumanInputFormRepositoryImpl", form_repository_impl, ) - elif constructor_name == "KnowledgeIndexNode": - index_processor = sentinel.index_processor - summary_index = sentinel.summary_index - monkeypatch.setattr(node_factory, "IndexProcessor", MagicMock(return_value=index_processor)) - monkeypatch.setattr(node_factory, "SummaryIndex", MagicMock(return_value=summary_index)) - node_config = {"id": "node-id", "data": {"type": node_type.value}} + node_config = {"id": "node-id", "data": {"type": node_type}} result = factory.create_node(node_config) assert result is created_node @@ -379,13 +371,6 @@ class TestDifyNodeFactoryCreateNode: elif constructor_name == "HumanInputNode": assert kwargs["form_repository"] is form_repository form_repository_impl.assert_called_once_with(tenant_id="tenant-id") - elif constructor_name == "KnowledgeIndexNode": - assert kwargs["index_processor"] is index_processor - assert kwargs["summary_index_service"] is summary_index - elif constructor_name == "DatasourceNode": - assert kwargs["datasource_manager"] is node_factory.DatasourceManager - elif constructor_name == "KnowledgeRetrievalNode": - assert kwargs["rag_retrieval"] is sentinel.rag_retrieval elif constructor_name == "DocumentExtractorNode": assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config assert kwargs["http_client"] is sentinel.http_client @@ -393,9 +378,9 @@ class TestDifyNodeFactoryCreateNode: @pytest.mark.parametrize( ("node_type", "constructor_name", "expected_extra_kwargs"), [ - (NodeType.LLM, "LLMNode", {"http_client": sentinel.http_client}), - (NodeType.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}), - (NodeType.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}), + (BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}), + (BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}), + (BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}), ], ) def test_creates_model_backed_nodes( @@ -409,8 +394,8 @@ class TestDifyNodeFactoryCreateNode: created_node = object() constructor = MagicMock(name=constructor_name, return_value=created_node) monkeypatch.setattr( - node_factory, - "resolve_workflow_node_class", + factory, + "_resolve_node_class", MagicMock(return_value=constructor), ) llm_init_kwargs = { @@ -423,7 +408,7 @@ class TestDifyNodeFactoryCreateNode: build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs) factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs - node_config = {"id": "node-id", "data": {"type": node_type.value}} + node_config = {"id": "node-id", "data": {"type": node_type}} result = factory.create_node(node_config) assert result is created_node @@ -432,7 +417,7 @@ class TestDifyNodeFactoryCreateNode: assert helper_kwargs["node_class"] is constructor assert isinstance(helper_kwargs["node_data"], BaseNodeData) assert helper_kwargs["node_data"].type == node_type - assert helper_kwargs["include_http_client"] is (node_type != NodeType.PARAMETER_EXTRACTOR) + assert helper_kwargs["include_http_client"] is (node_type != BuiltinNodeTypes.PARAMETER_EXTRACTOR) constructor_kwargs = constructor.call_args.kwargs assert constructor_kwargs["id"] == "node-id" diff --git a/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py b/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py new file mode 100644 index 0000000000..8de45257ec --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py @@ -0,0 +1,43 @@ +import os +import subprocess +import sys +import textwrap +from pathlib import Path + + +def test_moved_core_nodes_resolve_after_importing_production_entrypoints(): + api_root = Path(__file__).resolve().parents[4] + script = textwrap.dedent( + """ + from core.app.apps import workflow_app_runner + from core.workflow import workflow_entry + from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE + from core.workflow.node_factory import DifyNodeFactory, NODE_TYPE_CLASSES_MAPPING + from dify_graph.enums import BuiltinNodeTypes + from services import workflow_service + from services.rag_pipeline import rag_pipeline + + _ = workflow_entry, workflow_app_runner, workflow_service, rag_pipeline + + expected = ( + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + KNOWLEDGE_INDEX_NODE_TYPE, + BuiltinNodeTypes.DATASOURCE, + ) + + for node_type in expected: + assert node_type in NODE_TYPE_CLASSES_MAPPING, node_type + resolved = DifyNodeFactory._resolve_node_class(node_type=node_type, node_version="1") + assert resolved.__module__.startswith("core.workflow.nodes."), resolved.__module__ + """ + ) + completed = subprocess.run( + [sys.executable, "-c", script], + cwd=api_root, + env=os.environ.copy(), + capture_output=True, + text=True, + check=False, + ) + + assert completed.returncode == 0, completed.stderr or completed.stdout diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 68e42894fc..dc4c7a00c5 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -8,11 +8,12 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow import workflow_entry from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.enums import NodeType from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.file.models import File from dify_graph.graph_events import GraphRunFailedEvent -from dify_graph.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from dify_graph.runtime import ChildGraphNotFoundError @@ -240,7 +241,7 @@ class TestWorkflowEntrySingleStepRun: app_id="app-id", id="workflow-id", graph_dict={"nodes": [], "edges": []}, - get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START), + get_node_config_by_id=lambda _node_id: _build_typed_node_config(BuiltinNodeTypes.START), ) node, generator = workflow_entry.WorkflowEntry.single_step_run( @@ -302,7 +303,7 @@ class TestWorkflowEntrySingleStepRun: app_id="app-id", id="workflow-id", graph_dict={"nodes": [], "edges": []}, - get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.DATASOURCE), + get_node_config_by_id=lambda _node_id: _build_typed_node_config(BuiltinNodeTypes.DATASOURCE), ) node, generator = workflow_entry.WorkflowEntry.single_step_run( @@ -352,7 +353,7 @@ class TestWorkflowEntrySingleStepRun: app_id="app-id", id="workflow-id", graph_dict={"nodes": [], "edges": []}, - get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START), + get_node_config_by_id=lambda _node_id: _build_typed_node_config(BuiltinNodeTypes.START), ) with pytest.raises(WorkflowNodeRunFailedError): @@ -369,7 +370,7 @@ class TestWorkflowEntryHelpers: def test_create_single_node_graph_builds_start_edge(self): graph = workflow_entry.WorkflowEntry._create_single_node_graph( node_id="target-node", - node_data={"type": NodeType.PARAMETER_EXTRACTOR}, + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}, node_width=320, node_height=180, ) @@ -390,7 +391,7 @@ class TestWorkflowEntryHelpers: def test_run_free_node_rejects_unsupported_types(self): with pytest.raises(ValueError, match="Node type start not supported"): workflow_entry.WorkflowEntry.run_free_node( - node_data={"type": NodeType.START.value}, + node_data={"type": BuiltinNodeTypes.START}, node_id="node-id", tenant_id="tenant-id", user_id="user-id", @@ -406,7 +407,7 @@ class TestWorkflowEntryHelpers: with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"): workflow_entry.WorkflowEntry.run_free_node( - node_data={"type": NodeType.PARAMETER_EXTRACTOR.value}, + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}, node_id="node-id", tenant_id="tenant-id", user_id="user-id", @@ -459,7 +460,7 @@ class TestWorkflowEntryHelpers: ), ): node, generator = workflow_entry.WorkflowEntry.run_free_node( - node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}, + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}, node_id="node-id", tenant_id="tenant-id", user_id="user-id", @@ -483,7 +484,7 @@ class TestWorkflowEntryHelpers: graph_init_params.assert_called_once_with( workflow_id="", graph_config=workflow_entry.WorkflowEntry._create_single_node_graph( - "node-id", {"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"} + "node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"} ), run_context={"_dify": "context"}, call_depth=0, @@ -538,7 +539,7 @@ class TestWorkflowEntryHelpers: ): with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"): workflow_entry.WorkflowEntry.run_free_node( - node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}, + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}, node_id="node-id", tenant_id="tenant-id", user_id="user-id", diff --git a/api/tests/unit_tests/libs/test_cron_compatibility.py b/api/tests/unit_tests/libs/test_cron_compatibility.py index 61103d7935..6f3a94f6dc 100644 --- a/api/tests/unit_tests/libs/test_cron_compatibility.py +++ b/api/tests/unit_tests/libs/test_cron_compatibility.py @@ -294,7 +294,7 @@ class TestFrontendBackendIntegration(unittest.TestCase): def test_schedule_service_integration(self): """Test integration with ScheduleService patterns.""" - from dify_graph.nodes.trigger_schedule.entities import VisualConfig + from core.workflow.nodes.trigger_schedule.entities import VisualConfig from services.trigger.schedule_service import ScheduleService # Test enhanced syntax through visual config conversion diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index f66f0b657d..4fcef34549 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -15,7 +15,7 @@ from uuid import uuid4 import pytest from dify_graph.enums import ( - NodeType, + BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, ) @@ -471,7 +471,7 @@ class TestNodeExecutionRelationships: workflow_run_id=workflow_run_id, index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start Node", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -484,7 +484,7 @@ class TestNodeExecutionRelationships: assert node_execution.workflow_id == workflow_id assert node_execution.workflow_run_id == workflow_run_id assert node_execution.node_id == "start" - assert node_execution.node_type == NodeType.START.value + assert node_execution.node_type == BuiltinNodeTypes.START assert node_execution.index == 1 def test_node_execution_with_predecessor_relationship(self): @@ -503,7 +503,7 @@ class TestNodeExecutionRelationships: index=2, predecessor_node_id=predecessor_node_id, node_id=current_node_id, - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="LLM Node", status=WorkflowNodeExecutionStatus.RUNNING.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -526,7 +526,7 @@ class TestNodeExecutionRelationships: workflow_run_id=None, # Single-step has no workflow run index=1, node_id="llm_test", - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="Test LLM", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -553,7 +553,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="llm_1", - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="LLM Node", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -579,7 +579,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="code_1", - node_type=NodeType.CODE.value, + node_type=BuiltinNodeTypes.CODE, title="Code Node", status=WorkflowNodeExecutionStatus.RUNNING.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -610,7 +610,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=3, node_id="code_1", - node_type=NodeType.CODE.value, + node_type=BuiltinNodeTypes.CODE, title="Code Node", status=WorkflowNodeExecutionStatus.FAILED.value, error=error_message, @@ -641,7 +641,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="llm_1", - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="LLM Node", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -664,7 +664,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -682,12 +682,12 @@ class TestNodeExecutionRelationships: """Test node execution with different node types.""" # Test various node types node_types = [ - (NodeType.START, "Start Node"), - (NodeType.LLM, "LLM Node"), - (NodeType.CODE, "Code Node"), - (NodeType.TOOL, "Tool Node"), - (NodeType.IF_ELSE, "Conditional Node"), - (NodeType.END, "End Node"), + (BuiltinNodeTypes.START, "Start Node"), + (BuiltinNodeTypes.LLM, "LLM Node"), + (BuiltinNodeTypes.CODE, "Code Node"), + (BuiltinNodeTypes.TOOL, "Tool Node"), + (BuiltinNodeTypes.IF_ELSE, "Conditional Node"), + (BuiltinNodeTypes.END, "End Node"), ] for node_type, title in node_types: @@ -699,8 +699,8 @@ class TestNodeExecutionRelationships: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, workflow_run_id=str(uuid4()), index=1, - node_id=f"{node_type.value}_1", - node_type=node_type.value, + node_id=f"{node_type}_1", + node_type=node_type, title=title, status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -708,7 +708,7 @@ class TestNodeExecutionRelationships: ) # Assert - assert node_execution.node_type == node_type.value + assert node_execution.node_type == node_type assert node_execution.title == title @@ -1004,7 +1004,7 @@ class TestGraphConfigurationValidation: workflow_run_id=str(uuid4()), index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -1029,7 +1029,7 @@ class TestGraphConfigurationValidation: workflow_run_id=str(uuid4()), index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index 06703b8e38..086d1ac52e 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -17,7 +17,7 @@ from dify_graph.entities import ( WorkflowNodeExecution, ) from dify_graph.enums import ( - NodeType, + BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -230,7 +230,7 @@ def test_to_db_model(repository): index=1, predecessor_node_id="test-predecessor-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", inputs={"input_key": "input_value"}, process_data={"process_key": "process_value"}, @@ -298,7 +298,7 @@ def test_to_domain_model(repository): db_model.predecessor_node_id = "test-predecessor-id" db_model.node_execution_id = "test-node-execution-id" db_model.node_id = "test-node-id" - db_model.node_type = NodeType.START + db_model.node_type = BuiltinNodeTypes.START db_model.title = "Test Node" db_model.inputs = json.dumps(inputs_dict) db_model.process_data = json.dumps(process_data_dict) @@ -324,7 +324,7 @@ def test_to_domain_model(repository): assert domain_model.predecessor_node_id == db_model.predecessor_node_id assert domain_model.node_execution_id == db_model.node_execution_id assert domain_model.node_id == db_model.node_id - assert domain_model.node_type == NodeType(db_model.node_type) + assert domain_model.node_type == db_model.node_type assert domain_model.title == db_model.title assert domain_model.inputs == inputs_dict assert domain_model.process_data == process_data_dict diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py index 95a7751273..e01fb8456f 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py @@ -12,7 +12,7 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom @@ -62,7 +62,7 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData: workflow_id="test-workflow-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", process_data=process_data, created_at=datetime.now(), diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index 33d26f4bcb..7e82f79860 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -5,7 +5,12 @@ from unittest.mock import MagicMock import pytest import yaml -from dify_graph.enums import NodeType +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) +from dify_graph.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import IconType from services import app_dsl_service @@ -522,7 +527,7 @@ def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(monkey "conversation_variables": [{"y": 2}], "graph": { "nodes": [ - {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["enc-1", "enc-2"]}}, + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["enc-1", "enc-2"]}}, ] }, "features": {}, @@ -671,17 +676,17 @@ def test_append_workflow_export_data_filters_and_overrides(monkeypatch): workflow_dict = { "graph": { "nodes": [ - {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["d1", "d2"]}}, - {"data": {"type": NodeType.TOOL, "credential_id": "secret"}}, + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["d1", "d2"]}}, + {"data": {"type": BuiltinNodeTypes.TOOL, "credential_id": "secret"}}, { "data": { - "type": NodeType.AGENT, + "type": BuiltinNodeTypes.AGENT, "agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}}, } }, - {"data": {"type": NodeType.TRIGGER_SCHEDULE.value, "config": {"x": 1}}}, - {"data": {"type": NodeType.TRIGGER_WEBHOOK.value, "webhook_url": "x", "webhook_debug_url": "y"}}, - {"data": {"type": NodeType.TRIGGER_PLUGIN.value, "subscription_id": "s"}}, + {"data": {"type": TRIGGER_SCHEDULE_NODE_TYPE, "config": {"x": 1}}}, + {"data": {"type": TRIGGER_WEBHOOK_NODE_TYPE, "webhook_url": "x", "webhook_debug_url": "y"}}, + {"data": {"type": TRIGGER_PLUGIN_NODE_TYPE, "subscription_id": "s"}}, ] } } @@ -809,11 +814,11 @@ def test_extract_dependencies_from_workflow_graph_covers_all_node_types(monkeypa graph = { "nodes": [ - {"data": {"type": NodeType.TOOL}}, - {"data": {"type": NodeType.LLM}}, - {"data": {"type": NodeType.QUESTION_CLASSIFIER}}, - {"data": {"type": NodeType.PARAMETER_EXTRACTOR}}, - {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL}}, + {"data": {"type": BuiltinNodeTypes.TOOL}}, + {"data": {"type": BuiltinNodeTypes.LLM}}, + {"data": {"type": BuiltinNodeTypes.QUESTION_CLASSIFIER}}, + {"data": {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}}, + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}, {"data": {"type": "unknown"}}, ] } @@ -826,7 +831,9 @@ def test_extract_dependencies_from_workflow_graph_handles_exceptions(monkeypatch monkeypatch.setattr( app_dsl_service.ToolNodeData, "model_validate", lambda _d: (_ for _ in ()).throw(ValueError("bad")) ) - deps = AppDslService._extract_dependencies_from_workflow_graph({"nodes": [{"data": {"type": NodeType.TOOL}}]}) + deps = AppDslService._extract_dependencies_from_workflow_graph( + {"nodes": [{"data": {"type": BuiltinNodeTypes.TOOL}}]} + ) assert deps == [] diff --git a/api/tests/unit_tests/services/test_schedule_service.py b/api/tests/unit_tests/services/test_schedule_service.py index 5e3dd157e6..e28965ea2c 100644 --- a/api/tests/unit_tests/services/test_schedule_service.py +++ b/api/tests/unit_tests/services/test_schedule_service.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy.orm import Session -from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig -from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError +from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig +from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError from events.event_handlers.sync_workflow_schedule_when_app_published import ( sync_schedule_from_workflow, ) @@ -136,7 +136,7 @@ class TestScheduleService(unittest.TestCase): def test_update_schedule_not_found(self): """Test updating a non-existent schedule raises exception.""" - from dify_graph.nodes.trigger_schedule.exc import ScheduleNotFoundError + from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError mock_session = MagicMock(spec=Session) mock_session.get.return_value = None @@ -172,7 +172,7 @@ class TestScheduleService(unittest.TestCase): def test_delete_schedule_not_found(self): """Test deleting a non-existent schedule raises exception.""" - from dify_graph.nodes.trigger_schedule.exc import ScheduleNotFoundError + from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError mock_session = MagicMock(spec=Session) mock_session.get.return_value = None diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 5ce0e6f140..57c0464dc6 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -14,7 +14,7 @@ from unittest.mock import MagicMock, patch import pytest -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from libs.datetime_utils import naive_utc_now from models.model import App, AppMode @@ -134,7 +134,7 @@ class TestWorkflowAssociatedDataFactory: return ( (node["id"], node["data"]) for node in nodes - if node.get("data", {}).get("type") == specific_node_type.value + if node.get("data", {}).get("type") == str(specific_node_type) ) # Return all nodes if no filter specified return ((node["id"], node["data"]) for node in nodes) @@ -179,7 +179,7 @@ class TestWorkflowAssociatedDataFactory: { "id": "start", "data": { - "type": NodeType.START.value, + "type": BuiltinNodeTypes.START, "title": "START", "variables": [], }, @@ -204,7 +204,7 @@ class TestWorkflowAssociatedDataFactory: { "id": "llm-1", "data": { - "type": NodeType.LLM.value, + "type": BuiltinNodeTypes.LLM, "title": "LLM", "model": { "provider": "openai", @@ -1001,12 +1001,12 @@ class TestWorkflowService: Used by the UI to populate the node palette and provide sensible defaults when users add new nodes to their workflow. """ - with patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping: + with patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping: # Mock node class with default config mock_node_class = MagicMock() mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}} - mock_mapping.return_value = {NodeType.LLM: {"latest": mock_node_class}} + mock_mapping.return_value = {BuiltinNodeTypes.LLM: {"latest": mock_node_class}} with patch("services.workflow_service.LATEST_VERSION", "latest"): result = workflow_service.get_default_block_configs() @@ -1025,7 +1025,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1037,8 +1037,8 @@ class TestWorkflowService: mock_llm_node_class = MagicMock() mock_llm_node_class.get_default_config.return_value = {"type": "llm", "config": {}} mock_mapping.return_value = { - NodeType.HTTP_REQUEST: {"latest": mock_http_node_class}, - NodeType.LLM: {"latest": mock_llm_node_class}, + BuiltinNodeTypes.HTTP_REQUEST: {"latest": mock_http_node_class}, + BuiltinNodeTypes.LLM: {"latest": mock_llm_node_class}, } result = workflow_service.get_default_block_configs() @@ -1060,7 +1060,7 @@ class TestWorkflowService: This includes default values for all required and optional parameters. """ with ( - patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), ): # Mock node class with default config @@ -1068,21 +1068,21 @@ class TestWorkflowService: mock_config = {"type": "llm", "config": {"provider": "openai"}} mock_node_class.get_default_config.return_value = mock_config - # Create a mock mapping that includes NodeType.LLM - mock_mapping.return_value = {NodeType.LLM: {"latest": mock_node_class}} + # Create a mock mapping that includes BuiltinNodeTypes.LLM + mock_mapping.return_value = {BuiltinNodeTypes.LLM: {"latest": mock_node_class}} - result = workflow_service.get_default_block_config(NodeType.LLM.value) + result = workflow_service.get_default_block_config(BuiltinNodeTypes.LLM) assert result == mock_config mock_node_class.get_default_config.assert_called_once() def test_get_default_block_config_invalid_node_type(self, workflow_service): """Test get_default_block_config returns empty dict for invalid node type.""" - with patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping: + with patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping: mock_mapping.return_value = {} # Use a valid NodeType but one that's not in the mapping - result = workflow_service.get_default_block_config(NodeType.LLM.value) + result = workflow_service.get_default_block_config(BuiltinNodeTypes.LLM) assert result == {} @@ -1098,7 +1098,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1108,9 +1108,9 @@ class TestWorkflowService: mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.return_value = {NodeType.HTTP_REQUEST: {"latest": mock_node_class}} + mock_mapping.return_value = {BuiltinNodeTypes.HTTP_REQUEST: {"latest": mock_node_class}} - result = workflow_service.get_default_block_config(NodeType.HTTP_REQUEST.value) + result = workflow_service.get_default_block_config(BuiltinNodeTypes.HTTP_REQUEST) assert result == expected mock_build_config.assert_called_once() @@ -1129,17 +1129,17 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch("services.workflow_service.build_http_request_config") as mock_build_config, ): mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.return_value = {NodeType.HTTP_REQUEST: {"latest": mock_node_class}} + mock_mapping.return_value = {BuiltinNodeTypes.HTTP_REQUEST: {"latest": mock_node_class}} result = workflow_service.get_default_block_config( - NodeType.HTTP_REQUEST.value, + BuiltinNodeTypes.HTTP_REQUEST, filters={HTTP_REQUEST_CONFIG_FILTER_KEY: provided_config}, ) @@ -1151,14 +1151,14 @@ class TestWorkflowService: def test_get_default_block_config_http_request_malformed_config_raises_value_error(self, workflow_service): with ( patch( - "services.workflow_service.get_workflow_node_type_classes_mapping", - return_value={NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}}, + "services.workflow_service.get_node_type_classes_mapping", + return_value={BuiltinNodeTypes.HTTP_REQUEST: {"latest": HttpRequestNode}}, ), patch("services.workflow_service.LATEST_VERSION", "latest"), ): with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"): workflow_service.get_default_block_config( - NodeType.HTTP_REQUEST.value, + BuiltinNodeTypes.HTTP_REQUEST, filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"}, ) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 4042e05565..9f3874b8f1 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import SegmentType from libs.uuid_utils import uuidv7 @@ -54,12 +54,12 @@ class TestDraftVariableSaver: session=mock_session, app_id=test_app_id, node_id="test_node_id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_execution_id="test_execution_id", user=mock_user, ) - assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False - assert saver._should_variable_be_visible("123", NodeType.START, "output") == True + assert saver._should_variable_be_visible("123_456", BuiltinNodeTypes.IF_ELSE, "output") == False + assert saver._should_variable_be_visible("123", BuiltinNodeTypes.START, "output") == True def test__normalize_variable_for_start_node(self): @dataclasses.dataclass(frozen=True) @@ -102,7 +102,7 @@ class TestDraftVariableSaver: session=mock_session, app_id=test_app_id, node_id=_NODE_ID, - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_execution_id="test_execution_id", user=mock_user, ) @@ -134,7 +134,7 @@ class TestDraftVariableSaver: session=mock_session, app_id="test-app-id", node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_execution_id="test-execution-id", user=mock_user, ) @@ -331,7 +331,7 @@ class TestWorkflowDraftVariableService: mock_node_config = {"type": "test_node"} with ( patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config, autospec=True), - patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM, autospec=True), + patch.object(workflow, "get_node_type_from_node_config", return_value=BuiltinNodeTypes.LLM, autospec=True), ): result = service._reset_node_var_or_sys_var(workflow, variable) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index fcdd1c2368..c890ab6a65 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -6,7 +6,7 @@ import pytest from sqlalchemy.orm import sessionmaker from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, @@ -31,7 +31,7 @@ def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfi inputs=[], user_actions=[], ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value + node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 9ee8f88e71..ed26bcec01 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock import pytest from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction from dify_graph.nodes.human_input.enums import FormInputType from models.model import App @@ -209,7 +209,7 @@ class TestWorkflowService: workflow = MagicMock() node_config = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} ) workflow.get_node_config_by_id.return_value = node_config workflow.get_enclosing_node_type_and_id.return_value = None @@ -279,7 +279,7 @@ class TestWorkflowService: workflow = MagicMock() workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} ) service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -312,7 +312,7 @@ class TestWorkflowService: # Mock node config mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": NodeType.LLM.value}} + {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} ) mock_workflow.get_enclosing_node_type_and_id.return_value = None @@ -379,7 +379,7 @@ class TestWorkflowService: mock_workflow.environment_variables = [] mock_workflow.conversation_variables = [] mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": NodeType.LLM.value}} + {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} ) mock_workflow.get_enclosing_node_type_and_id.return_value = None diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py index 54be8379d5..a223f0119e 100644 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -15,7 +15,7 @@ # WorkflowNodeExecution, # WorkflowNodeExecutionStatus, # ) -# from dify_graph.enums import NodeType +# from dify_graph.enums import BuiltinNodeTypes # from libs.datetime_utils import naive_utc_now # from models import WorkflowNodeExecutionModel # from models.enums import ExecutionOffLoadType @@ -41,7 +41,7 @@ # workflow_execution_id=str(uuid4()), # index=1, # node_id="test_node", -# node_type=NodeType.LLM, +# node_type=BuiltinNodeTypes.LLM, # title="Test Node", # inputs={"input_key": "input_value"}, # outputs={"output_key": "output_value"}, @@ -134,7 +134,7 @@ # workflow_execution_id=str(uuid4()), # index=1, # node_id="test_node", -# node_type=NodeType.LLM, +# node_type=BuiltinNodeTypes.LLM, # title="Test Node", # inputs=large_data, # outputs=large_data, diff --git a/dev/pyrefly-check-local b/dev/pyrefly-check-local index 80f90927bb..8fa5f121fc 100755 --- a/dev/pyrefly-check-local +++ b/dev/pyrefly-check-local @@ -10,6 +10,8 @@ EXCLUDES_FILE="api/pyrefly-local-excludes.txt" pyrefly_args=( "--summary=none" + "--use-ignore-files=false" + "--disable-project-excludes-heuristics=true" "--project-excludes=.venv" "--project-excludes=migrations/" "--project-excludes=tests/"