mirror of
https://github.com/langgenius/dify.git
synced 2026-03-21 14:28:26 +08:00
Merge branch 'feat/model-plugins-implementing' into deploy/dev
This commit is contained in:
199
.github/dependabot.yml
vendored
199
.github/dependabot.yml
vendored
@ -3,19 +3,210 @@ version: 2
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
python-dependencies:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
uv-dependencies:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
open-pull-requests-limit: 5
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
2
.github/workflows/api-tests.yml
vendored
2
.github/workflows/api-tests.yml
vendored
@ -27,7 +27,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@ -39,7 +39,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
@ -32,6 +32,7 @@ from core.app.entities.queue_entities import (
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -62,7 +63,6 @@ from dify_graph.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.graph_events.graph import GraphRunAbortedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@ -303,9 +303,11 @@ class WorkflowBasedAppRunner:
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
|
||||
@ -19,6 +19,7 @@ from core.trigger.debug.events import (
|
||||
build_plugin_pool_key,
|
||||
build_webhook_pool_key,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
|
||||
app_id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
node_config: Mapping[str, Any]
|
||||
node_config: NodeConfigDict
|
||||
node_id: str
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.app_id = app_id
|
||||
@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
def poll(self) -> TriggerDebugEvent | None:
|
||||
from services.trigger.trigger_service import TriggerService
|
||||
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
|
||||
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
|
||||
pool_key: str = build_plugin_pool_key(
|
||||
name=plugin_trigger_data.event_name,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, cast, final
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -22,7 +22,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.file.file_manager import file_manager
|
||||
@ -31,26 +32,19 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
|
||||
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
from dify_graph.nodes.code.limits import CodeNodeLimits
|
||||
from dify_graph.nodes.datasource import DatasourceNode
|
||||
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
|
||||
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
|
||||
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm.entities import ModelConfig
|
||||
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import build_http_request_config
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
@ -60,6 +54,9 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
@ -157,178 +154,128 @@ class DifyNodeFactory(NodeFactory):
|
||||
return DifyRunContext.model_validate(raw_ctx)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
:raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config["id"]
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
NodeType.CODE: lambda: {
|
||||
"code_executor": self._code_executor,
|
||||
"code_limits": self._code_limits,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: lambda: {
|
||||
"template_renderer": self._template_renderer,
|
||||
"max_output_length": self._template_transform_max_output_length,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: lambda: {
|
||||
"http_request_config": self._http_request_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
|
||||
"file_manager": self._http_request_file_manager,
|
||||
},
|
||||
NodeType.HUMAN_INPUT: lambda: {
|
||||
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: lambda: {
|
||||
"index_processor": IndexProcessor(),
|
||||
"summary_index_service": SummaryIndex(),
|
||||
},
|
||||
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
NodeType.DATASOURCE: lambda: {
|
||||
"datasource_manager": DatasourceManager,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
|
||||
"rag_retrieval": self._rag_retrieval,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: lambda: {
|
||||
"unstructured_api_config": self._document_extractor_unstructured_api_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=False,
|
||||
),
|
||||
NodeType.TOOL: lambda: {
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
|
||||
},
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
)
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config["data"]
|
||||
try:
|
||||
node_type = NodeType(node_data["type"])
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_data['type']}")
|
||||
@staticmethod
|
||||
def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
|
||||
"""
|
||||
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
|
||||
"""
|
||||
return node_class.validate_node_data(node_data)
|
||||
|
||||
# Get node class
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
|
||||
# Create node instance
|
||||
if node_type == NodeType.CODE:
|
||||
return CodeNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
max_output_length=self._template_transform_max_output_length,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
return HttpRequestNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
http_request_config=self._http_request_config,
|
||||
http_client=self._http_request_http_client,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HUMAN_INPUT:
|
||||
return HumanInputNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_INDEX:
|
||||
return KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
index_processor=IndexProcessor(),
|
||||
summary_index_service=SummaryIndex(),
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return LLMNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
return DatasourceNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
datasource_manager=DatasourceManager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DOCUMENT_EXTRACTOR:
|
||||
return DocumentExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TOOL:
|
||||
return ToolNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
def _build_llm_compatible_node_init_kwargs(
|
||||
self,
|
||||
*,
|
||||
node_class: type[Node],
|
||||
node_data: BaseNodeData,
|
||||
include_http_client: bool,
|
||||
) -> dict[str, object]:
|
||||
validated_node_data = cast(
|
||||
LLMCompatibleNodeData,
|
||||
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
|
||||
)
|
||||
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
|
||||
node_init_kwargs: dict[str, object] = {
|
||||
"credentials_provider": self._llm_credentials_provider,
|
||||
"model_factory": self._llm_model_factory,
|
||||
"model_instance": model_instance,
|
||||
"memory": self._build_memory_for_llm_node(
|
||||
node_data=validated_node_data,
|
||||
model_instance=model_instance,
|
||||
),
|
||||
}
|
||||
if include_http_client:
|
||||
node_init_kwargs["http_client"] = self._http_request_http_client
|
||||
return node_init_kwargs
|
||||
|
||||
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
|
||||
node_data_model = ModelConfig.model_validate(node_data["model"])
|
||||
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
|
||||
node_data_model = node_data.model
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
@ -364,14 +311,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
def _build_memory_for_llm_node(
|
||||
self,
|
||||
*,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LLMCompatibleNodeData,
|
||||
model_instance: ModelInstance,
|
||||
) -> PromptMessageMemory | None:
|
||||
raw_memory_config = node_data.get("memory")
|
||||
if raw_memory_config is None:
|
||||
if node_data.memory is None:
|
||||
return None
|
||||
|
||||
node_memory = MemoryConfig.model_validate(raw_memory_config)
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
@ -381,6 +326,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
return fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id=self._dify_context.app_id,
|
||||
node_data_memory=node_memory,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
@ -11,7 +11,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph import Graph
|
||||
@ -212,7 +212,7 @@ class WorkflowEntry:
|
||||
node_config_data = node_config["data"]
|
||||
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data["type"])
|
||||
node_type = node_config_data.type
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
@ -234,8 +234,7 @@ class WorkflowEntry:
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
typed_node_config = cast(dict[str, object], node_config)
|
||||
node = cast(Any, node_factory).create_node(typed_node_config)
|
||||
node = node_factory.create_node(node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
@ -371,10 +370,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config: NodeConfigDict = {
|
||||
"id": node_id,
|
||||
"data": cast(NodeConfigData, node_data),
|
||||
}
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
176
api/dify_graph/entities/base_node_data.py
Normal file
176
api/dify_graph/entities/base_node_data.py
Normal file
@ -0,0 +1,176 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from dify_graph.entities.exc import DefaultValueTypeError
|
||||
from dify_graph.enums import ErrorStrategy, NodeType
|
||||
|
||||
# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
|
||||
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
|
||||
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
|
||||
# and persisted templates/workflows also carry undeclared compatibility keys such as
|
||||
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
|
||||
# here until graph parsing becomes discriminated by node type or those legacy payloads
|
||||
# are normalized.
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
type: NodeType
|
||||
title: str = ""
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = Field(default_factory=RetryConfig)
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""
|
||||
Dict-style access without calling model_dump() on every lookup.
|
||||
Prefer using model fields and Pydantic's extra storage.
|
||||
"""
|
||||
# First, check declared model fields
|
||||
if key in self.__class__.model_fields:
|
||||
return getattr(self, key)
|
||||
|
||||
# Then, check undeclared compatibility fields stored in Pydantic's extra dict.
|
||||
extras = getattr(self, "__pydantic_extra__", None)
|
||||
if extras is None:
|
||||
extras = getattr(self, "model_extra", None)
|
||||
if extras is not None and key in extras:
|
||||
return extras[key]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Dict-style .get() without calling model_dump() on every lookup.
|
||||
"""
|
||||
if key in self.__class__.model_fields:
|
||||
return getattr(self, key)
|
||||
|
||||
extras = getattr(self, "__pydantic_extra__", None)
|
||||
if extras is None:
|
||||
extras = getattr(self, "model_extra", None)
|
||||
if extras is not None and key in extras:
|
||||
return extras.get(key, default)
|
||||
|
||||
return default
|
||||
@ -4,21 +4,20 @@ import sys
|
||||
|
||||
from pydantic import TypeAdapter, with_config
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigData(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
data: NodeConfigData
|
||||
# This is the permissive raw graph boundary. Node factories re-validate `data`
|
||||
# with the concrete `NodeData` subtype after resolving the node implementation.
|
||||
data: BaseNodeData
|
||||
|
||||
|
||||
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Protocol, cast, final
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from libs.typing import is_str
|
||||
|
||||
@ -34,7 +34,8 @@ class NodeFactory(Protocol):
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
|
||||
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
|
||||
"""
|
||||
...
|
||||
|
||||
@ -115,10 +116,7 @@ class Graph:
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
node_type = node_data["type"]
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
if node_data.type.is_start_node:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
@ -203,6 +201,23 @@ class Graph:
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@staticmethod
|
||||
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Remove editor-only nodes before `NodeConfigDict` validation.
|
||||
|
||||
Persisted note widgets use a top-level `type == "custom-note"` but leave
|
||||
`data.type` empty because they are never executable graph nodes. Filter
|
||||
them while configs are still raw dicts so Pydantic does not validate
|
||||
their placeholder payloads against `BaseNodeData.type: NodeType`.
|
||||
"""
|
||||
filtered_node_configs: list[dict[str, object]] = []
|
||||
for node_config in node_configs:
|
||||
if node_config.get("type", "") == "custom-note":
|
||||
continue
|
||||
filtered_node_configs.append(dict(node_config))
|
||||
return filtered_node_configs
|
||||
|
||||
@classmethod
|
||||
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||
"""
|
||||
@ -302,13 +317,13 @@ class Graph:
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
node_configs = cls._filter_canvas_only_nodes(node_configs)
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
|
||||
@ -374,12 +374,11 @@ class AgentNode(Node[AgentNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AgentNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
|
||||
@ -5,10 +5,12 @@ from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.AGENT
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
|
||||
@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: AnswerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AnswerNodeData.model_validate(node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
@ -3,7 +3,8 @@ from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ANSWER
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
|
||||
__all__ = [
|
||||
@ -6,6 +6,5 @@ __all__ = [
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
]
|
||||
|
||||
@ -1,31 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from dify_graph.enums import ErrorStrategy
|
||||
|
||||
from .exc import DefaultValueTypeError
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
@ -76,112 +57,6 @@ class OutputVariableEntity(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, ge
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import (
|
||||
ErrorStrategy,
|
||||
@ -62,8 +64,6 @@ from dify_graph.node_events import (
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
_MISSING_RUN_CONTEXT_VALUE = object()
|
||||
|
||||
@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]):
|
||||
Later, in __init__:
|
||||
::
|
||||
|
||||
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
config["data"] ──► _node_data_type.model_validate(..., from_attributes=True)
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
@ -254,22 +254,21 @@ class Node(Generic[NodeDataT]):
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
node_id = config["id"]
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
self._node_data = self.validate_node_data(config["data"])
|
||||
|
||||
self.post_init()
|
||||
|
||||
@classmethod
|
||||
def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT:
|
||||
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
|
||||
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
|
||||
|
||||
def post_init(self) -> None:
|
||||
"""Optional hook for subclasses requiring extra initialization."""
|
||||
return
|
||||
@ -342,9 +341,6 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
@ -389,8 +385,6 @@ class Node(Generic[NodeDataT]):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from dify_graph.nodes.agent.agent_node import AgentNode
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData
|
||||
|
||||
@ -442,7 +436,7 @@ class Node(Generic[NodeDataT]):
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
@ -480,13 +474,12 @@ class Node(Generic[NodeDataT]):
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
node_id = config["id"]
|
||||
node_data = cls.validate_node_data(config["data"])
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
)
|
||||
return data
|
||||
|
||||
@ -496,7 +489,7 @@ class Node(Generic[NodeDataT]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: NodeDataT,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from decimal import Decimal
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: CodeNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = CodeNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@ -3,7 +3,8 @@ from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.CODE
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
|
||||
@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
datasource_manager: DatasourceManagerProtocol,
|
||||
@ -181,7 +182,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DatasourceNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -190,11 +191,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
typed_node_data = DatasourceNodeData.model_validate(node_data)
|
||||
result = {}
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
input = node_data.datasource_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
|
||||
@ -3,7 +3,8 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel):
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
type: NodeType = NodeType.DATASOURCE
|
||||
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
|
||||
variable_selector: Sequence[str]
|
||||
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod, file_manager
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@ -54,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@ -136,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DocumentExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
|
||||
|
||||
return {node_id + ".files": typed_node_data.variable_selector}
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
return {node_id + ".files": node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.END
|
||||
outputs: list[OutputVariableEntity]
|
||||
|
||||
|
||||
|
||||
@ -8,7 +8,8 @@ import charset_normalizer
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
|
||||
|
||||
@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.HTTP_REQUEST
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
|
||||
@ -3,6 +3,7 @@ import mimetypes
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HttpRequestNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = HttpRequestNodeData.model_validate(node_data)
|
||||
|
||||
selectors: list[VariableSelector] = []
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
|
||||
if typed_node_data.body:
|
||||
body_type = typed_node_data.body.type
|
||||
data = typed_node_data.body.data
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
|
||||
if node_data.body:
|
||||
body_type = node_data.body.type
|
||||
data = node_data.body.data
|
||||
match body_type:
|
||||
case "none":
|
||||
pass
|
||||
|
||||
@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.consts import SELECTORS_LENGTH
|
||||
@ -214,6 +215,7 @@ class UserAction(BaseModel):
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Human Input node data."""
|
||||
|
||||
type: NodeType = NodeType.HUMAN_INPUT
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import (
|
||||
@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository,
|
||||
@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HumanInputNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input default values.
|
||||
@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input default values
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
return node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
||||
@ -2,7 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
|
||||
|
||||
@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
|
||||
If Else Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.IF_ELSE
|
||||
|
||||
class Case(BaseModel):
|
||||
"""
|
||||
Case entity representing a single logical condition group
|
||||
|
||||
@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IfElseNodeData.model_validate(node_data)
|
||||
|
||||
var_mapping: dict[str, list[str]] = {}
|
||||
for case in typed_node_data.cases or []:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
for case in node_data.cases or []:
|
||||
for condition in case.conditions:
|
||||
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
|
||||
var_mapping[key] = condition.variable_selector
|
||||
|
||||
@ -3,7 +3,9 @@ from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class ErrorHandleMode(StrEnum):
|
||||
@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ITERATION
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData):
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.ITERATION_START
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IterationNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IterationNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
f"{node_id}.input_selector": node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
node_config_data = node.get("data", {})
|
||||
if node_config_data.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
@ -490,14 +488,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
|
||||
@ -3,7 +3,8 @@ from typing import Literal, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-index"
|
||||
type: NodeType = NodeType.KNOWLEDGE_INDEX
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@ -30,7 +31,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
index_processor: IndexProcessorProtocol,
|
||||
|
||||
@ -3,7 +3,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-retrieval"
|
||||
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
query_variable_selector: list[str] | None | str = None
|
||||
query_attachment_selector: list[str] | None | str = None
|
||||
dataset_ids: list[str]
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@ -49,7 +50,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
@ -301,15 +302,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
if typed_node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||
if typed_node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
|
||||
if node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||
if node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
|
||||
return variable_mapping
|
||||
|
||||
@ -3,7 +3,8 @@ from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
|
||||
|
||||
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LIST_OPERATOR
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderByConfig
|
||||
|
||||
@ -4,8 +4,9 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LLM
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
|
||||
@ -21,6 +21,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@ -121,7 +122,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
@ -954,14 +955,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LLMNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LLMNodeData.model_validate(node_data)
|
||||
|
||||
prompt_template = typed_node_data.prompt_template
|
||||
prompt_template = node_data.prompt_template
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
@ -979,7 +977,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
memory = typed_node_data.memory
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
@ -987,16 +985,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if typed_node_data.context.enabled:
|
||||
variable_mapping["#context#"] = typed_node_data.context.variable_selector
|
||||
if node_data.context.enabled:
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if typed_node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
||||
|
||||
if typed_node_data.memory:
|
||||
if node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
|
||||
|
||||
if typed_node_data.prompt_config:
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
@ -1009,7 +1007,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
break
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@ -39,6 +41,7 @@ class LoopVariableData(BaseModel):
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
type: NodeType = NodeType.LOOP
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData):
|
||||
Loop Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_START
|
||||
|
||||
|
||||
class LoopEndNodeData(BaseNodeData):
|
||||
@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData):
|
||||
Loop End Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_END
|
||||
|
||||
|
||||
class LoopState(BaseLoopState):
|
||||
|
||||
@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LoopNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
@ -320,14 +318,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
@ -342,7 +341,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in typed_node_data.loop_variables or []:
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
|
||||
@ -8,7 +8,8 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
Parameter Extractor Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.PARAMETER_EXTRACTOR
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
|
||||
@ -10,6 +10,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: ParameterExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
|
||||
|
||||
if typed_node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
|
||||
for selector in selectors:
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
|
||||
|
||||
|
||||
class QuestionClassifierNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.QUESTION_CLASSIFIER
|
||||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
|
||||
@ -7,6 +7,7 @@ from core.model_manager import ModelInstance
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@ -62,7 +63,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@ -251,16 +252,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: QuestionClassifierNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {"query": typed_node_data.query_variable_selector}
|
||||
variable_mapping = {"query": node_data.query_variable_selector}
|
||||
variable_selectors: list[VariableSelector] = []
|
||||
if typed_node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
|
||||
|
||||
@ -2,7 +2,8 @@ from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
|
||||
Start Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.START
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
|
||||
Template Transform Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TEMPLATE_TRANSFORM
|
||||
variables: list[VariableSelector]
|
||||
template: str
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@ -25,7 +26,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@ -86,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@ -4,7 +4,8 @@ from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
@ -32,6 +33,8 @@ class ToolEntity(BaseModel):
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = NodeType.TOOL
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
|
||||
@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@ -46,7 +47,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@ -484,7 +485,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: ToolNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -493,9 +494,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
typed_node_data = node_data
|
||||
result = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
|
||||
@ -4,13 +4,16 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.trigger.entities.entities import EventParameter
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError
|
||||
|
||||
|
||||
class TriggerEventNodeData(BaseNodeData):
|
||||
"""Plugin trigger node data"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_PLUGIN
|
||||
|
||||
class TriggerEventInput(BaseModel):
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
@ -38,8 +41,6 @@ class TriggerEventNodeData(BaseNodeData):
|
||||
raise ValueError("value must be a string, int, float, bool or dict")
|
||||
return type
|
||||
|
||||
title: str
|
||||
desc: str | None = None
|
||||
plugin_id: str = Field(..., description="Plugin ID")
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
event_name: str = Field(..., description="Event name")
|
||||
|
||||
@ -2,7 +2,8 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class TriggerScheduleNodeData(BaseNodeData):
|
||||
@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
|
||||
Trigger Schedule Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_SCHEDULE
|
||||
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
|
||||
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
|
||||
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from dify_graph.nodes.base.exc import BaseNodeError
|
||||
from dify_graph.entities.exc import BaseNodeError
|
||||
|
||||
|
||||
class ScheduleNodeError(BaseNodeError):
|
||||
|
||||
@ -1,10 +1,41 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
}
|
||||
)
|
||||
|
||||
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
)
|
||||
|
||||
_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES
|
||||
|
||||
_WEBHOOK_BODY_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.FILE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Method(StrEnum):
|
||||
@ -25,29 +56,34 @@ class ContentType(StrEnum):
|
||||
|
||||
|
||||
class WebhookParameter(BaseModel):
|
||||
"""Parameter definition for headers, query params, or body."""
|
||||
"""Parameter definition for headers or query params."""
|
||||
|
||||
name: str
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookBodyParameter(BaseModel):
|
||||
"""Body parameter with type information."""
|
||||
|
||||
name: str
|
||||
type: Literal[
|
||||
"string",
|
||||
"number",
|
||||
"boolean",
|
||||
"object",
|
||||
"array[string]",
|
||||
"array[number]",
|
||||
"array[boolean]",
|
||||
"array[object]",
|
||||
"file",
|
||||
] = "string"
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_BODY_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook body parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookData(BaseNodeData):
|
||||
"""
|
||||
@ -57,6 +93,7 @@ class WebhookData(BaseNodeData):
|
||||
class SyncMode(StrEnum):
|
||||
SYNC = "async" # only support
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_WEBHOOK
|
||||
method: Method = Method.GET
|
||||
content_type: ContentType = Field(default=ContentType.JSON)
|
||||
headers: Sequence[WebhookParameter] = Field(default_factory=list)
|
||||
@ -71,6 +108,22 @@ class WebhookData(BaseNodeData):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
@field_validator("headers", mode="after")
|
||||
@classmethod
|
||||
def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
|
||||
for param in v:
|
||||
if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook header parameter type: {param.type}")
|
||||
return v
|
||||
|
||||
@field_validator("params", mode="after")
|
||||
@classmethod
|
||||
def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
|
||||
for param in v:
|
||||
if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook query parameter type: {param.type}")
|
||||
return v
|
||||
|
||||
status_code: int = 200 # Expected status code for response
|
||||
response_body: str = "" # Template for response body
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from dify_graph.nodes.base.exc import BaseNodeError
|
||||
from dify_graph.entities.exc import BaseNodeError
|
||||
|
||||
|
||||
class WebhookNodeError(BaseNodeError):
|
||||
|
||||
@ -152,7 +152,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
outputs[param_name] = raw_data
|
||||
continue
|
||||
|
||||
if param_type == "file":
|
||||
if param_type == SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
|
||||
@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData):
|
||||
Variable Aggregator Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.VARIABLE_AGGREGATOR
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: AdvancedSettings | None = None
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: VariableAssignerData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = VariableAssignerData.model_validate(node_data)
|
||||
|
||||
mapping = {}
|
||||
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
|
||||
assigned_variable_node_id = node_data.assigned_variable_selector[0]
|
||||
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
selector_key = ".".join(typed_node_data.assigned_variable_selector)
|
||||
selector_key = ".".join(node_data.assigned_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = typed_node_data.assigned_variable_selector
|
||||
mapping[key] = node_data.assigned_variable_selector
|
||||
|
||||
selector_key = ".".join(typed_node_data.input_variable_selector)
|
||||
selector_key = ".".join(node_data.input_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = typed_node_data.input_variable_selector
|
||||
mapping[key] = node_data.input_variable_selector
|
||||
return mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class WriteMode(StrEnum):
|
||||
@ -11,6 +12,7 @@ class WriteMode(StrEnum):
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
type: NodeType = NodeType.VARIABLE_ASSIGNER
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
|
||||
@ -3,7 +3,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
from .enums import InputType, Operation
|
||||
|
||||
@ -22,5 +23,6 @@ class VariableOperationItem(BaseModel):
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.VARIABLE_ASSIGNER
|
||||
version: str = "2"
|
||||
items: Sequence[VariableOperationItem] = Field(default_factory=list)
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@ -56,7 +57,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
@ -94,13 +95,10 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: VariableAssignerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = VariableAssignerNodeData.model_validate(node_data)
|
||||
|
||||
var_mapping: dict[str, Sequence[str]] = {}
|
||||
for item in typed_node_data.items:
|
||||
for item in node_data.items:
|
||||
_target_mapping_from_item(var_mapping, node_id, item)
|
||||
_source_mapping_from_item(var_mapping, node_id, item)
|
||||
return var_mapping
|
||||
|
||||
@ -7,7 +7,7 @@ from celery.signals import worker_init
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
from opentelemetry.propagators.b3 import B3Format
|
||||
from opentelemetry.propagators.b3 import B3MultiFormat
|
||||
from opentelemetry.propagators.composite import CompositePropagator
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
@ -24,7 +24,7 @@ def setup_context_propagation() -> None:
|
||||
CompositePropagator(
|
||||
[
|
||||
TraceContextTextMapPropagator(),
|
||||
B3Format(),
|
||||
B3MultiFormat(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@ -233,8 +233,11 @@ class Workflow(Base): # bug
|
||||
|
||||
def get_node_config_by_id(self, node_id: str) -> NodeConfigDict:
|
||||
"""Extract a node configuration from the workflow graph by node ID.
|
||||
A node configuration is a dictionary containing the node's properties, including
|
||||
the node's id, title, and its data as a dict.
|
||||
|
||||
A node configuration includes the node id and a typed `BaseNodeData` for `data`.
|
||||
`BaseNodeData` keeps a dict-like `get`/`__getitem__` compatibility layer backed by
|
||||
model fields plus Pydantic extra storage for legacy consumers, but callers should
|
||||
prefer attribute access.
|
||||
"""
|
||||
workflow_graph = self.graph_dict
|
||||
|
||||
@ -252,12 +255,9 @@ class Workflow(Base): # bug
|
||||
return NodeConfigDictAdapter.validate_python(node_config)
|
||||
|
||||
@staticmethod
|
||||
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
|
||||
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
|
||||
"""Extract type of a node from the node configuration returned by `get_node_config_by_id`."""
|
||||
node_config_data = node_config.get("data", {})
|
||||
# Get node class
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
return node_type
|
||||
return node_config["data"].type
|
||||
|
||||
@staticmethod
|
||||
def get_enclosing_node_type_and_id(
|
||||
|
||||
@ -5,42 +5,42 @@ requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
"aliyun-log-python-sdk~=0.9.37",
|
||||
"arize-phoenix-otel~=0.9.2",
|
||||
"azure-identity==1.16.1",
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.2",
|
||||
"beautifulsoup4==4.12.2",
|
||||
"boto3==1.35.99",
|
||||
"boto3==1.42.65",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.5.2",
|
||||
"charset-normalizer>=3.4.4",
|
||||
"flask~=3.1.2",
|
||||
"flask-compress>=1.17,<1.18",
|
||||
"flask-compress>=1.17,<1.24",
|
||||
"flask-cors~=6.0.0",
|
||||
"flask-login~=0.6.3",
|
||||
"flask-migrate~=4.0.7",
|
||||
"flask-migrate~=4.1.0",
|
||||
"flask-orjson~=2.0.0",
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.3.0",
|
||||
"google-api-core>=2.19.1",
|
||||
"google-api-python-client==2.189.0",
|
||||
"google-api-python-client==2.192.0",
|
||||
"google-auth>=2.47.0",
|
||||
"google-auth-httplib2==0.2.0",
|
||||
"google-auth-httplib2==0.3.0",
|
||||
"google-cloud-aiplatform>=1.123.0",
|
||||
"googleapis-common-protos>=1.65.0",
|
||||
"gunicorn~=23.0.0",
|
||||
"gunicorn~=25.1.0",
|
||||
"httpx[socks]~=0.28.0",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.55.1",
|
||||
"jsonschema>=4.25.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.1.77",
|
||||
"langsmith~=0.7.16",
|
||||
"markdown~=3.8.1",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.8.72",
|
||||
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.1", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
@ -53,7 +53,7 @@ dependencies = [
|
||||
"opentelemetry-instrumentation-httpx==0.49b0",
|
||||
"opentelemetry-instrumentation-redis==0.49b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
|
||||
"opentelemetry-propagator-b3==1.28.0",
|
||||
"opentelemetry-propagator-b3==1.40.0",
|
||||
"opentelemetry-proto==1.28.0",
|
||||
"opentelemetry-sdk==1.28.0",
|
||||
"opentelemetry-semantic-conventions==0.49b0",
|
||||
@ -63,21 +63,21 @@ dependencies = [
|
||||
"psycopg2-binary~=2.9.6",
|
||||
"pycryptodome==3.23.0",
|
||||
"pydantic~=2.12.5",
|
||||
"pydantic-extra-types~=2.10.3",
|
||||
"pydantic-settings~=2.12.0",
|
||||
"pydantic-extra-types~=2.11.0",
|
||||
"pydantic-settings~=2.13.1",
|
||||
"pyjwt~=2.11.0",
|
||||
"pypdfium2==5.2.0",
|
||||
"python-docx~=1.2.0",
|
||||
"python-dotenv==1.0.1",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.2.0",
|
||||
"redis[hiredis]~=7.3.0",
|
||||
"resend~=2.9.0",
|
||||
"sentry-sdk[flask]~=2.28.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.49.1",
|
||||
"tiktoken~=0.9.0",
|
||||
"transformers~=4.56.1",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
|
||||
"yarl~=1.18.3",
|
||||
"webvtt-py~=0.5.1",
|
||||
@ -109,46 +109,46 @@ package = false
|
||||
# Required for development and running tests
|
||||
############################################################
|
||||
dev = [
|
||||
"coverage~=7.2.4",
|
||||
"dotenv-linter~=0.5.0",
|
||||
"faker~=38.2.0",
|
||||
"coverage~=7.13.4",
|
||||
"dotenv-linter~=0.7.0",
|
||||
"faker~=40.8.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"basedpyright~=1.38.2",
|
||||
"ruff~=0.14.0",
|
||||
"pytest~=8.3.2",
|
||||
"pytest-benchmark~=4.0.0",
|
||||
"pytest-cov~=4.1.0",
|
||||
"ruff~=0.15.5",
|
||||
"pytest~=9.0.2",
|
||||
"pytest-benchmark~=5.2.3",
|
||||
"pytest-cov~=7.0.0",
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.14.0",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.13.2",
|
||||
"types-aiofiles~=25.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=5.5.0",
|
||||
"types-cachetools~=6.2.0",
|
||||
"types-colorama~=0.4.15",
|
||||
"types-defusedxml~=0.7.0",
|
||||
"types-deprecated~=1.2.15",
|
||||
"types-docutils~=0.21.0",
|
||||
"types-jsonschema~=4.23.0",
|
||||
"types-flask-cors~=5.0.0",
|
||||
"types-deprecated~=1.3.1",
|
||||
"types-docutils~=0.22.3",
|
||||
"types-jsonschema~=4.26.0",
|
||||
"types-flask-cors~=6.0.0",
|
||||
"types-flask-migrate~=4.1.0",
|
||||
"types-gevent~=25.9.0",
|
||||
"types-greenlet~=3.3.0",
|
||||
"types-html5lib~=1.1.11",
|
||||
"types-markdown~=3.10.2",
|
||||
"types-oauthlib~=3.2.0",
|
||||
"types-oauthlib~=3.3.0",
|
||||
"types-objgraph~=3.6.0",
|
||||
"types-olefile~=0.47.0",
|
||||
"types-openpyxl~=3.1.5",
|
||||
"types-pexpect~=4.9.0",
|
||||
"types-protobuf~=5.29.1",
|
||||
"types-protobuf~=6.32.1",
|
||||
"types-psutil~=7.2.2",
|
||||
"types-psycopg2~=2.9.21",
|
||||
"types-pygments~=2.19.0",
|
||||
"types-pymysql~=1.1.0",
|
||||
"types-python-dateutil~=2.9.0",
|
||||
"types-pywin32~=310.0.0",
|
||||
"types-pywin32~=311.0.0",
|
||||
"types-pyyaml~=6.0.12",
|
||||
"types-regex~=2024.11.6",
|
||||
"types-regex~=2026.2.28",
|
||||
"types-shapely~=2.1.0",
|
||||
"types-simplejson>=3.20.0",
|
||||
"types-six>=1.17.0",
|
||||
@ -161,7 +161,7 @@ dev = [
|
||||
"types_pyOpenSSL>=24.1.0",
|
||||
"types_cffi>=1.17.0",
|
||||
"types_setuptools>=80.9.0",
|
||||
"pandas-stubs~=2.2.3",
|
||||
"pandas-stubs~=3.0.0",
|
||||
"scipy-stubs>=1.15.3.0",
|
||||
"types-python-http-client>=3.3.7.20240910",
|
||||
"import-linter>=2.3",
|
||||
@ -180,13 +180,13 @@ dev = [
|
||||
# Required for storage clients
|
||||
############################################################
|
||||
storage = [
|
||||
"azure-storage-blob==12.26.0",
|
||||
"azure-storage-blob==12.28.0",
|
||||
"bce-python-sdk~=0.9.23",
|
||||
"cos-python-sdk-v5==1.9.38",
|
||||
"esdk-obs-python==3.25.8",
|
||||
"cos-python-sdk-v5==1.9.41",
|
||||
"esdk-obs-python==3.26.2",
|
||||
"google-cloud-storage>=3.0.0",
|
||||
"opendal~=0.46.0",
|
||||
"oss2==2.18.5",
|
||||
"oss2==2.19.1",
|
||||
"supabase~=2.18.1",
|
||||
"tos~=2.9.0",
|
||||
]
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
|
||||
from dify_graph.nodes.trigger_schedule.entities import (
|
||||
ScheduleConfig,
|
||||
SchedulePlanUpdate,
|
||||
TriggerScheduleNodeData,
|
||||
VisualConfig,
|
||||
)
|
||||
from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
|
||||
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
from models.account import Account, TenantAccountJoin
|
||||
@ -176,26 +180,26 @@ class ScheduleService:
|
||||
return next_run_at
|
||||
|
||||
@staticmethod
|
||||
def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig:
|
||||
def to_schedule_config(node_config: NodeConfigDict) -> ScheduleConfig:
|
||||
"""
|
||||
Converts user-friendly visual schedule settings to cron expression.
|
||||
Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
|
||||
"""
|
||||
node_data = node_config.get("data", {})
|
||||
mode = node_data.get("mode", "visual")
|
||||
timezone = node_data.get("timezone", "UTC")
|
||||
node_id = node_config.get("id", "start")
|
||||
node_data = TriggerScheduleNodeData.model_validate(node_config["data"], from_attributes=True)
|
||||
mode = node_data.mode
|
||||
timezone = node_data.timezone
|
||||
node_id = node_config["id"]
|
||||
|
||||
cron_expression = None
|
||||
if mode == "cron":
|
||||
cron_expression = node_data.get("cron_expression")
|
||||
cron_expression = node_data.cron_expression
|
||||
if not cron_expression:
|
||||
raise ScheduleConfigError("Cron expression is required for cron mode")
|
||||
elif mode == "visual":
|
||||
frequency = str(node_data.get("frequency"))
|
||||
frequency = str(node_data.frequency or "")
|
||||
if not frequency:
|
||||
raise ScheduleConfigError("Frequency is required for visual mode")
|
||||
visual_config = VisualConfig(**node_data.get("visual_config", {}))
|
||||
visual_config = VisualConfig.model_validate(node_data.visual_config or {})
|
||||
cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config)
|
||||
if not cron_expression:
|
||||
raise ScheduleConfigError("Cron expression is required for visual mode")
|
||||
@ -239,19 +243,21 @@ class ScheduleService:
|
||||
if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
|
||||
continue
|
||||
|
||||
mode = node_data.get("mode", "visual")
|
||||
timezone = node_data.get("timezone", "UTC")
|
||||
node_id = node.get("id", "start")
|
||||
trigger_data = TriggerScheduleNodeData.model_validate(node_data)
|
||||
mode = trigger_data.mode
|
||||
timezone = trigger_data.timezone
|
||||
|
||||
cron_expression = None
|
||||
if mode == "cron":
|
||||
cron_expression = node_data.get("cron_expression")
|
||||
cron_expression = trigger_data.cron_expression
|
||||
if not cron_expression:
|
||||
raise ScheduleConfigError("Cron expression is required for cron mode")
|
||||
elif mode == "visual":
|
||||
frequency = node_data.get("frequency")
|
||||
visual_config_dict = node_data.get("visual_config", {})
|
||||
visual_config = VisualConfig(**visual_config_dict)
|
||||
frequency = trigger_data.frequency
|
||||
if not frequency:
|
||||
raise ScheduleConfigError("Frequency is required for visual mode")
|
||||
visual_config = VisualConfig.model_validate(trigger_data.visual_config or {})
|
||||
cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
|
||||
else:
|
||||
raise ScheduleConfigError(f"Invalid schedule mode: {mode}")
|
||||
|
||||
@ -16,6 +16,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from extensions.ext_database import db
|
||||
@ -41,7 +42,7 @@ class TriggerService:
|
||||
|
||||
@classmethod
|
||||
def invoke_trigger_event(
|
||||
cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent
|
||||
cls, tenant_id: str, user_id: str, node_config: NodeConfigDict, event: PluginTriggerDebugEvent
|
||||
) -> TriggerInvokeEventResponse:
|
||||
"""Invoke a trigger event."""
|
||||
subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
|
||||
@ -50,7 +51,7 @@ class TriggerService:
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError("Subscription not found")
|
||||
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {}))
|
||||
node_data = TriggerEventNodeData.model_validate(node_config["data"], from_attributes=True)
|
||||
request = TriggerHttpRequestCachingService.get_request(event.request_id)
|
||||
payload = TriggerHttpRequestCachingService.get_payload(event.request_id)
|
||||
# invoke triger
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import secrets
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
@ -16,9 +16,16 @@ from werkzeug.exceptions import RequestEntityTooLarge
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.file.models import FileTransferMethod
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from dify_graph.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from dify_graph.variables.types import ArrayValidation, SegmentType
|
||||
from enums.quota_type import QuotaType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -57,7 +64,7 @@ class WebhookService:
|
||||
@classmethod
|
||||
def get_webhook_trigger_and_workflow(
|
||||
cls, webhook_id: str, is_debug: bool = False
|
||||
) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]:
|
||||
) -> tuple[WorkflowWebhookTrigger, Workflow, NodeConfigDict]:
|
||||
"""Get webhook trigger, workflow, and node configuration.
|
||||
|
||||
Args:
|
||||
@ -135,7 +142,7 @@ class WebhookService:
|
||||
|
||||
@classmethod
|
||||
def extract_and_validate_webhook_data(
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any]
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
|
||||
) -> dict[str, Any]:
|
||||
"""Extract and validate webhook data in a single unified process.
|
||||
|
||||
@ -153,7 +160,7 @@ class WebhookService:
|
||||
raw_data = cls.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Validate HTTP metadata (method, content-type)
|
||||
node_data = node_config.get("data", {})
|
||||
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
|
||||
validation_result = cls._validate_http_metadata(raw_data, node_data)
|
||||
if not validation_result["valid"]:
|
||||
raise ValueError(validation_result["error"])
|
||||
@ -192,7 +199,7 @@ class WebhookService:
|
||||
content_type = cls._extract_content_type(dict(request.headers))
|
||||
|
||||
# Route to appropriate extractor based on content type
|
||||
extractors = {
|
||||
extractors: dict[str, Callable[[], tuple[dict[str, Any], dict[str, Any]]]] = {
|
||||
"application/json": cls._extract_json_body,
|
||||
"application/x-www-form-urlencoded": cls._extract_form_body,
|
||||
"multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger),
|
||||
@ -214,7 +221,7 @@ class WebhookService:
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
|
||||
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
|
||||
"""Process and validate webhook data according to node configuration.
|
||||
|
||||
Args:
|
||||
@ -230,18 +237,13 @@ class WebhookService:
|
||||
result = raw_data.copy()
|
||||
|
||||
# Validate and process headers
|
||||
cls._validate_required_headers(raw_data["headers"], node_data.get("headers", []))
|
||||
cls._validate_required_headers(raw_data["headers"], node_data.headers)
|
||||
|
||||
# Process query parameters with type conversion and validation
|
||||
result["query_params"] = cls._process_parameters(
|
||||
raw_data["query_params"], node_data.get("params", []), is_form_data=True
|
||||
)
|
||||
result["query_params"] = cls._process_parameters(raw_data["query_params"], node_data.params, is_form_data=True)
|
||||
|
||||
# Process body parameters based on content type
|
||||
configured_content_type = node_data.get("content_type", "application/json").lower()
|
||||
result["body"] = cls._process_body_parameters(
|
||||
raw_data["body"], node_data.get("body", []), configured_content_type
|
||||
)
|
||||
result["body"] = cls._process_body_parameters(raw_data["body"], node_data.body, node_data.content_type)
|
||||
|
||||
return result
|
||||
|
||||
@ -424,7 +426,11 @@ class WebhookService:
|
||||
|
||||
@classmethod
|
||||
def _process_parameters(
|
||||
cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False
|
||||
cls,
|
||||
raw_params: dict[str, str],
|
||||
param_configs: Sequence[WebhookParameter],
|
||||
*,
|
||||
is_form_data: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Process parameters with unified validation and type conversion.
|
||||
|
||||
@ -440,13 +446,13 @@ class WebhookService:
|
||||
ValueError: If required parameters are missing or validation fails
|
||||
"""
|
||||
processed = {}
|
||||
configured_params = {config.get("name", ""): config for config in param_configs}
|
||||
configured_params = {config.name: config for config in param_configs}
|
||||
|
||||
# Process configured parameters
|
||||
for param_config in param_configs:
|
||||
name = param_config.get("name", "")
|
||||
param_type = param_config.get("type", SegmentType.STRING)
|
||||
required = param_config.get("required", False)
|
||||
name = param_config.name
|
||||
param_type = param_config.type
|
||||
required = param_config.required
|
||||
|
||||
# Check required parameters
|
||||
if required and name not in raw_params:
|
||||
@ -465,7 +471,10 @@ class WebhookService:
|
||||
|
||||
@classmethod
|
||||
def _process_body_parameters(
|
||||
cls, raw_body: dict[str, Any], body_configs: list, content_type: str
|
||||
cls,
|
||||
raw_body: dict[str, Any],
|
||||
body_configs: Sequence[WebhookBodyParameter],
|
||||
content_type: ContentType,
|
||||
) -> dict[str, Any]:
|
||||
"""Process body parameters based on content type and configuration.
|
||||
|
||||
@ -480,25 +489,28 @@ class WebhookService:
|
||||
Raises:
|
||||
ValueError: If required body parameters are missing or validation fails
|
||||
"""
|
||||
if content_type in ["text/plain", "application/octet-stream"]:
|
||||
# For text/plain and octet-stream, validate required content exists
|
||||
if body_configs and any(config.get("required", False) for config in body_configs):
|
||||
raw_content = raw_body.get("raw")
|
||||
if not raw_content:
|
||||
raise ValueError(f"Required body content missing for {content_type} request")
|
||||
return raw_body
|
||||
match content_type:
|
||||
case ContentType.TEXT | ContentType.BINARY:
|
||||
# For text/plain and octet-stream, validate required content exists
|
||||
if body_configs and any(config.required for config in body_configs):
|
||||
raw_content = raw_body.get("raw")
|
||||
if not raw_content:
|
||||
raise ValueError(f"Required body content missing for {content_type} request")
|
||||
return raw_body
|
||||
case _:
|
||||
pass
|
||||
|
||||
# For structured data (JSON, form-data, etc.)
|
||||
processed = {}
|
||||
configured_params = {config.get("name", ""): config for config in body_configs}
|
||||
configured_params: dict[str, WebhookBodyParameter] = {config.name: config for config in body_configs}
|
||||
|
||||
for body_config in body_configs:
|
||||
name = body_config.get("name", "")
|
||||
param_type = body_config.get("type", SegmentType.STRING)
|
||||
required = body_config.get("required", False)
|
||||
name = body_config.name
|
||||
param_type = body_config.type
|
||||
required = body_config.required
|
||||
|
||||
# Handle file parameters for multipart data
|
||||
if param_type == SegmentType.FILE and content_type == "multipart/form-data":
|
||||
if param_type == SegmentType.FILE and content_type == ContentType.FORM_DATA:
|
||||
# File validation is handled separately in extract phase
|
||||
continue
|
||||
|
||||
@ -508,7 +520,7 @@ class WebhookService:
|
||||
|
||||
if name in raw_body:
|
||||
raw_value = raw_body[name]
|
||||
is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"]
|
||||
is_form_data = content_type in [ContentType.FORM_URLENCODED, ContentType.FORM_DATA]
|
||||
processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)
|
||||
|
||||
# Include unconfigured parameters
|
||||
@ -519,7 +531,9 @@ class WebhookService:
|
||||
return processed
|
||||
|
||||
@classmethod
|
||||
def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any:
|
||||
def _validate_and_convert_value(
|
||||
cls, param_name: str, value: Any, param_type: SegmentType | str, is_form_data: bool
|
||||
) -> Any:
|
||||
"""Unified validation and type conversion for parameter values.
|
||||
|
||||
Args:
|
||||
@ -532,7 +546,8 @@ class WebhookService:
|
||||
Any: The validated and converted value
|
||||
|
||||
Raises:
|
||||
ValueError: If validation or conversion fails
|
||||
ValueError: If validation or conversion fails. The original validation
|
||||
error is preserved as ``__cause__`` for debugging.
|
||||
"""
|
||||
try:
|
||||
if is_form_data:
|
||||
@ -542,10 +557,10 @@ class WebhookService:
|
||||
# JSON data should already be in correct types, just validate
|
||||
return cls._validate_json_value(param_name, value, param_type)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}")
|
||||
raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") from e
|
||||
|
||||
@classmethod
|
||||
def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any:
|
||||
def _convert_form_value(cls, param_name: str, value: str, param_type: SegmentType | str) -> Any:
|
||||
"""Convert form data string values to specified types.
|
||||
|
||||
Args:
|
||||
@ -576,7 +591,7 @@ class WebhookService:
|
||||
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
|
||||
|
||||
@classmethod
|
||||
def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any:
|
||||
def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any:
|
||||
"""Validate JSON values against expected types.
|
||||
|
||||
Args:
|
||||
@ -590,43 +605,43 @@ class WebhookService:
|
||||
Raises:
|
||||
ValueError: If the value type doesn't match the expected type
|
||||
"""
|
||||
type_validators = {
|
||||
SegmentType.STRING: (lambda v: isinstance(v, str), "string"),
|
||||
SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"),
|
||||
SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"),
|
||||
SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"),
|
||||
SegmentType.ARRAY_STRING: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v),
|
||||
"array of strings",
|
||||
),
|
||||
SegmentType.ARRAY_NUMBER: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v),
|
||||
"array of numbers",
|
||||
),
|
||||
SegmentType.ARRAY_BOOLEAN: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v),
|
||||
"array of booleans",
|
||||
),
|
||||
SegmentType.ARRAY_OBJECT: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v),
|
||||
"array of objects",
|
||||
),
|
||||
}
|
||||
|
||||
validator_info = type_validators.get(SegmentType(param_type))
|
||||
if not validator_info:
|
||||
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
|
||||
param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name)
|
||||
if param_type_enum is None:
|
||||
return value
|
||||
|
||||
validator, expected_type = validator_info
|
||||
if not validator(value):
|
||||
if not param_type_enum.is_valid(value, array_validation=ArrayValidation.ALL):
|
||||
actual_type = type(value).__name__
|
||||
expected_type = cls._expected_type_label(param_type_enum)
|
||||
raise ValueError(f"Expected {expected_type}, got {actual_type}")
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None:
|
||||
def _coerce_segment_type(cls, param_type: SegmentType | str, *, param_name: str) -> SegmentType | None:
|
||||
if isinstance(param_type, SegmentType):
|
||||
return param_type
|
||||
try:
|
||||
return SegmentType(param_type)
|
||||
except Exception:
|
||||
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _expected_type_label(param_type: SegmentType) -> str:
|
||||
match param_type:
|
||||
case SegmentType.ARRAY_STRING:
|
||||
return "array of strings"
|
||||
case SegmentType.ARRAY_NUMBER:
|
||||
return "array of numbers"
|
||||
case SegmentType.ARRAY_BOOLEAN:
|
||||
return "array of booleans"
|
||||
case SegmentType.ARRAY_OBJECT:
|
||||
return "array of objects"
|
||||
case _:
|
||||
return param_type.value
|
||||
|
||||
@classmethod
|
||||
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: Sequence[WebhookParameter]) -> None:
|
||||
"""Validate required headers are present.
|
||||
|
||||
Args:
|
||||
@ -639,14 +654,14 @@ class WebhookService:
|
||||
headers_lower = {k.lower(): v for k, v in headers.items()}
|
||||
headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()}
|
||||
for header_config in header_configs:
|
||||
if header_config.get("required", False):
|
||||
header_name = header_config.get("name", "")
|
||||
if header_config.required:
|
||||
header_name = header_config.name
|
||||
sanitized_name = cls._sanitize_key(header_name).lower()
|
||||
if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized:
|
||||
raise ValueError(f"Required header missing: {header_name}")
|
||||
|
||||
@classmethod
|
||||
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
|
||||
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
|
||||
"""Validate HTTP method and content-type.
|
||||
|
||||
Args:
|
||||
@ -657,13 +672,13 @@ class WebhookService:
|
||||
dict[str, Any]: Validation result with 'valid' key and optional 'error' key
|
||||
"""
|
||||
# Validate HTTP method
|
||||
configured_method = node_data.get("method", "get").upper()
|
||||
configured_method = node_data.method.value.upper()
|
||||
request_method = webhook_data["method"].upper()
|
||||
if configured_method != request_method:
|
||||
return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}")
|
||||
|
||||
# Validate Content-type
|
||||
configured_content_type = node_data.get("content_type", "application/json").lower()
|
||||
configured_content_type = node_data.content_type.value.lower()
|
||||
request_content_type = cls._extract_content_type(webhook_data["headers"])
|
||||
|
||||
if configured_content_type != request_content_type:
|
||||
@ -788,7 +803,7 @@ class WebhookService:
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]:
|
||||
def generate_webhook_response(cls, node_config: NodeConfigDict) -> tuple[dict[str, Any], int]:
|
||||
"""Generate HTTP response based on node configuration.
|
||||
|
||||
Args:
|
||||
@ -797,11 +812,11 @@ class WebhookService:
|
||||
Returns:
|
||||
tuple[dict[str, Any], int]: Response data and HTTP status code
|
||||
"""
|
||||
node_data = node_config.get("data", {})
|
||||
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
|
||||
|
||||
# Get configured status code and response body
|
||||
status_code = node_data.get("status_code", 200)
|
||||
response_body = node_data.get("response_body", "")
|
||||
status_code = node_data.status_code
|
||||
response_body = node_data.response_body
|
||||
|
||||
# Parse response body as JSON if it's valid JSON, otherwise return as text
|
||||
try:
|
||||
|
||||
@ -16,6 +16,7 @@ from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams, WorkflowNodeExecution
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
@ -693,7 +694,7 @@ class WorkflowService:
|
||||
|
||||
node_config = draft_workflow.get_node_config_by_id(node_id)
|
||||
node_type = Workflow.get_node_type_from_node_config(node_config)
|
||||
node_data = node_config.get("data", {})
|
||||
node_data = node_config["data"]
|
||||
if node_type.is_start_node:
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
@ -703,7 +704,7 @@ class WorkflowService:
|
||||
workflow=draft_workflow,
|
||||
)
|
||||
if node_type is NodeType.START:
|
||||
start_data = StartNodeData.model_validate(node_data)
|
||||
start_data = StartNodeData.model_validate(node_data, from_attributes=True)
|
||||
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
|
||||
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
|
||||
)
|
||||
@ -941,7 +942,7 @@ class WorkflowService:
|
||||
if node_type is not NodeType.HUMAN_INPUT:
|
||||
raise ValueError("Node type must be human-input.")
|
||||
|
||||
node_data = HumanInputNodeData.model_validate(node_config.get("data", {}))
|
||||
node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True)
|
||||
delivery_method = self._resolve_human_input_delivery_method(
|
||||
node_data=node_data,
|
||||
delivery_method_id=delivery_method_id,
|
||||
@ -1059,7 +1060,7 @@ class WorkflowService:
|
||||
*,
|
||||
workflow: Workflow,
|
||||
account: Account,
|
||||
node_config: Mapping[str, Any],
|
||||
node_config: NodeConfigDict,
|
||||
variable_pool: VariablePool,
|
||||
) -> HumanInputNode:
|
||||
graph_init_params = GraphInitParams(
|
||||
@ -1079,7 +1080,7 @@ class WorkflowService:
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
node = HumanInputNode(
|
||||
id=node_config.get("id", str(uuid.uuid4())),
|
||||
id=node_config["id"],
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@ -1092,7 +1093,7 @@ class WorkflowService:
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_config: Mapping[str, Any],
|
||||
node_config: NodeConfigDict,
|
||||
manual_inputs: Mapping[str, Any],
|
||||
) -> VariablePool:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
|
||||
|
||||
@ -189,6 +189,7 @@ def test_custom_authorization_header(setup_http_mock):
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
@ -209,6 +210,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
|
||||
# Create node data with custom auth and empty api_key
|
||||
node_data = HttpRequestNodeData(
|
||||
type=NodeType.HTTP_REQUEST,
|
||||
title="http",
|
||||
desc="",
|
||||
url="http://example.com",
|
||||
|
||||
@ -173,7 +173,7 @@ class TestWebhookService:
|
||||
assert workflow.app_id == test_data["app"].id
|
||||
assert node_config is not None
|
||||
assert node_config["id"] == "webhook_node"
|
||||
assert node_config["data"]["title"] == "Test Webhook"
|
||||
assert node_config["data"].title == "Test Webhook"
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers):
|
||||
"""Test webhook trigger not found scenario."""
|
||||
|
||||
@ -25,7 +25,8 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing
|
||||
|
||||
# load dotenv file with pydantic-settings
|
||||
config = DifyConfig()
|
||||
# Disable `.env` loading to ensure test stability across environments
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
# constant values
|
||||
assert config.COMMIT_SHA == ""
|
||||
@ -59,7 +60,8 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
|
||||
config = DifyConfig()
|
||||
# Disable `.env` loading to ensure test stability across environments
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
# Verify default timeout values
|
||||
assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10
|
||||
@ -86,7 +88,8 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*")
|
||||
monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/")
|
||||
|
||||
flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore
|
||||
# Disable `.env` loading to ensure test stability across environments
|
||||
flask_app.config.from_mapping(DifyConfig(_env_file=None).model_dump()) # pyright: ignore
|
||||
config = flask_app.config
|
||||
|
||||
# configs read from pydantic-settings
|
||||
|
||||
@ -51,7 +51,7 @@ def bypass_decorators(mocker):
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
|
||||
return_value=lambda *_: (lambda f: f),
|
||||
return_value=lambda *_: lambda f: f,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import SchedulingPause
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -22,7 +24,7 @@ from dify_graph.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig
|
||||
from dify_graph.nodes.base.entities import OutputVariableEntity
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.end.entities import EndNodeData
|
||||
from dify_graph.nodes.start.entities import StartNodeData
|
||||
@ -42,6 +44,7 @@ if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
|
||||
|
||||
class _StubToolNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.TOOL
|
||||
pause_on: bool = False
|
||||
|
||||
|
||||
@ -88,16 +91,17 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
def _patch_tool_node(mocker):
|
||||
original_create_node = DifyNodeFactory.create_node
|
||||
|
||||
def _patched_create_node(self, node_config: dict[str, object]) -> Node:
|
||||
node_data = node_config.get("data", {})
|
||||
if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value:
|
||||
def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_data = typed_node_config["data"]
|
||||
if node_data.type == NodeType.TOOL:
|
||||
return _StubToolNode(
|
||||
id=str(node_config["id"]),
|
||||
config=node_config,
|
||||
id=str(typed_node_config["id"]),
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
return original_create_node(self, node_config)
|
||||
return original_create_node(self, typed_node_config)
|
||||
|
||||
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
|
||||
|
||||
|
||||
@ -7,7 +7,9 @@ import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from models.workflow import Workflow
|
||||
@ -105,3 +107,57 @@ def test_run_uses_single_node_execution_branch(
|
||||
assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER
|
||||
assert entry_kwargs["variable_pool"] is variable_pool
|
||||
assert entry_kwargs["graph_runtime_state"] is graph_runtime_state
|
||||
|
||||
|
||||
def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
runner = WorkflowBasedAppRunner(
|
||||
queue_manager=MagicMock(spec=AppQueueManager),
|
||||
variable_loader=MagicMock(),
|
||||
app_id="app",
|
||||
)
|
||||
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.id = "workflow"
|
||||
workflow.tenant_id = "tenant"
|
||||
workflow.graph_dict = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "loop-node",
|
||||
"data": {
|
||||
"type": "loop",
|
||||
"title": "Loop",
|
||||
"loop_count": 1,
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
_, _, graph_runtime_state = _make_graph_state()
|
||||
seen_configs: list[object] = []
|
||||
original_validate_python = NodeConfigDictAdapter.validate_python
|
||||
|
||||
def record_validate_python(value: object):
|
||||
seen_configs.append(value)
|
||||
return original_validate_python(value)
|
||||
|
||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
|
||||
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
|
||||
):
|
||||
runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id="loop-node",
|
||||
user_inputs={},
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
||||
|
||||
@ -7,10 +7,10 @@ from dataclasses import dataclass
|
||||
import pytest
|
||||
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph.validation import GraphValidationError
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@ -183,3 +183,36 @@ def test_graph_validation_blocks_start_and_trigger_coexistence(
|
||||
Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues)
|
||||
|
||||
|
||||
def test_graph_init_ignores_custom_note_nodes_before_node_data_validation(
|
||||
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
|
||||
) -> None:
|
||||
node_factory, graph_config = graph_init_dependencies
|
||||
graph_config["nodes"] = [
|
||||
{
|
||||
"id": "start",
|
||||
"data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
|
||||
},
|
||||
{"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}},
|
||||
{
|
||||
"id": "note",
|
||||
"type": "custom-note",
|
||||
"data": {
|
||||
"type": "",
|
||||
"title": "",
|
||||
"desc": "",
|
||||
"text": "{}",
|
||||
"theme": "blue",
|
||||
},
|
||||
},
|
||||
]
|
||||
graph_config["edges"] = [
|
||||
{"source": "start", "target": "answer", "sourceHandle": "success"},
|
||||
]
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
assert graph.root_node.id == "start"
|
||||
assert "answer" in graph.nodes
|
||||
assert "note" not in graph.nodes
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dify_graph.entities.base_node_data import RetryConfig
|
||||
from dify_graph.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
|
||||
@ -12,7 +13,6 @@ from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
|
||||
from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.entities import RetryConfig
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ import time
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType
|
||||
from dify_graph.enums import ErrorStrategy
|
||||
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from dify_graph.graph_engine.command_channels import InMemoryChannel
|
||||
@ -18,7 +19,6 @@ from dify_graph.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.nodes.base.entities import DefaultValue, DefaultValueType
|
||||
|
||||
# Import the test framework from the new module
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
|
||||
@ -5,10 +5,10 @@ This module provides a MockNodeFactory that automatically detects and mocks node
|
||||
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
||||
@ -75,39 +75,27 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
NodeType.CODE: MockCodeNode,
|
||||
}
|
||||
|
||||
def create_node(self, node_config: Mapping[str, Any]) -> Node:
|
||||
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a node instance, using mock implementations for third-party service nodes.
|
||||
|
||||
:param node_config: Node configuration dictionary
|
||||
:return: Node instance (real or mocked)
|
||||
"""
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
node_type_str = node_data.get("type")
|
||||
|
||||
if not node_type_str:
|
||||
# Fall back to parent implementation for nodes without type
|
||||
return super().create_node(node_config)
|
||||
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
# Unknown node type, use parent implementation
|
||||
return super().create_node(node_config)
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_data = typed_node_config["data"]
|
||||
node_type = node_data.type
|
||||
|
||||
# Check if this node type should be mocked
|
||||
if node_type in self._mock_node_types:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node config missing id")
|
||||
node_id = typed_node_config["id"]
|
||||
|
||||
# Create mock node instance
|
||||
mock_class = self._mock_node_types[node_type]
|
||||
if node_type == NodeType.CODE:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -117,7 +105,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
elif node_type == NodeType.HTTP_REQUEST:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -129,7 +117,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -139,7 +127,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -148,7 +136,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
return mock_instance
|
||||
|
||||
# For non-mocked node types, use parent implementation
|
||||
return super().create_node(node_config)
|
||||
return super().create_node(typed_node_config)
|
||||
|
||||
def should_mock_node(self, node_type: NodeType) -> bool:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
||||
# Ensures that all node classes are imported.
|
||||
@ -126,3 +126,20 @@ def test_init_subclass_sets_node_data_type_from_generic():
|
||||
return "1"
|
||||
|
||||
assert _AutoNode._node_data_type is _TestNodeData
|
||||
|
||||
|
||||
def test_validate_node_data_uses_declared_node_data_type():
|
||||
"""Public validation should hydrate the subclass-declared node data model."""
|
||||
|
||||
class _AutoNode(Node[_TestNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
@staticmethod
|
||||
def version() -> str:
|
||||
return "1"
|
||||
|
||||
base_node_data = BaseNodeData.model_validate({"type": NodeType.CODE, "title": "Test"})
|
||||
|
||||
validated = _AutoNode.validate_node_data(base_node_data)
|
||||
|
||||
assert isinstance(validated, _TestNodeData)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import types
|
||||
from collections.abc import Mapping
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
||||
# Import concrete nodes we will assert on (numeric version path)
|
||||
|
||||
@ -272,7 +272,7 @@ class TestCodeNodeExtractVariableSelector:
|
||||
result = CodeNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config={},
|
||||
node_id="node_1",
|
||||
node_data=node_data,
|
||||
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
@ -292,7 +292,7 @@ class TestCodeNodeExtractVariableSelector:
|
||||
result = CodeNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config={},
|
||||
node_id="node_1",
|
||||
node_data=node_data,
|
||||
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
|
||||
)
|
||||
|
||||
assert "node_1.input_text" in result
|
||||
@ -315,7 +315,7 @@ class TestCodeNodeExtractVariableSelector:
|
||||
result = CodeNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config={},
|
||||
node_id="code_node",
|
||||
node_data=node_data,
|
||||
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
@ -338,7 +338,7 @@ class TestCodeNodeExtractVariableSelector:
|
||||
result = CodeNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config={},
|
||||
node_id="node_x",
|
||||
node_data=node_data,
|
||||
node_data=CodeNodeData.model_validate(node_data, from_attributes=True),
|
||||
)
|
||||
|
||||
assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"]
|
||||
@ -437,7 +437,7 @@ class TestCodeNodeInitialization:
|
||||
"outputs": {"x": {"type": "number"}},
|
||||
}
|
||||
|
||||
node._node_data = node._hydrate_node_data(data)
|
||||
node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True)
|
||||
|
||||
assert node._node_data.title == "Test Node"
|
||||
assert node._node_data.code_language == CodeLanguage.PYTHON3
|
||||
@ -453,7 +453,7 @@ class TestCodeNodeInitialization:
|
||||
"outputs": {"x": {"type": "number"}},
|
||||
}
|
||||
|
||||
node._node_data = node._hydrate_node_data(data)
|
||||
node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True)
|
||||
|
||||
assert node._node_data.code_language == CodeLanguage.JAVASCRIPT
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from dify_graph.nodes.iteration.exc import (
|
||||
@ -388,3 +389,50 @@ class TestIterationNodeErrorStrategies:
|
||||
result = node._get_default_value_dict()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
|
||||
seen_configs: list[object] = []
|
||||
original_validate_python = NodeConfigDictAdapter.validate_python
|
||||
|
||||
def record_validate_python(value: object):
|
||||
seen_configs.append(value)
|
||||
return original_validate_python(value)
|
||||
|
||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||
|
||||
child_node_config = {
|
||||
"id": "answer-node",
|
||||
"data": {
|
||||
"type": "answer",
|
||||
"title": "Answer",
|
||||
"answer": "",
|
||||
"iteration_id": "iteration-node",
|
||||
},
|
||||
}
|
||||
|
||||
IterationNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "iteration-node",
|
||||
"data": {
|
||||
"type": "iteration",
|
||||
"title": "Iteration",
|
||||
"iterator_selector": ["start", "items"],
|
||||
"output_selector": ["iteration", "result"],
|
||||
},
|
||||
},
|
||||
child_node_config,
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
node_id="iteration-node",
|
||||
node_data=IterationNodeData(
|
||||
title="Iteration",
|
||||
iterator_selector=["start", "items"],
|
||||
output_selector=["iteration", "result"],
|
||||
),
|
||||
)
|
||||
|
||||
assert seen_configs == [child_node_config]
|
||||
|
||||
@ -410,14 +410,14 @@ class TestKnowledgeRetrievalNode:
|
||||
"""Test _extract_variable_selector_to_variable_mapping class method."""
|
||||
# Arrange
|
||||
node_id = "knowledge_node_1"
|
||||
node_data = {
|
||||
"type": "knowledge-retrieval",
|
||||
"title": "Knowledge Retrieval",
|
||||
"dataset_ids": [str(uuid.uuid4())],
|
||||
"retrieval_mode": "multiple",
|
||||
"query_variable_selector": ["start", "query"],
|
||||
"query_attachment_selector": ["start", "attachments"],
|
||||
}
|
||||
node_data = KnowledgeRetrievalNodeData(
|
||||
type="knowledge-retrieval",
|
||||
title="Knowledge Retrieval",
|
||||
dataset_ids=[str(uuid.uuid4())],
|
||||
retrieval_mode="multiple",
|
||||
query_variable_selector=["start", "query"],
|
||||
query_attachment_selector=["start", "attachments"],
|
||||
)
|
||||
graph_config = {}
|
||||
|
||||
# Act
|
||||
|
||||
@ -4,8 +4,9 @@ import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@ -40,13 +41,26 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
|
||||
return init_params, runtime_state
|
||||
|
||||
|
||||
def _build_node_config() -> NodeConfigDict:
|
||||
return NodeConfigDictAdapter.validate_python(
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": NodeType.ANSWER.value,
|
||||
"title": "Sample",
|
||||
"foo": "bar",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_node_hydrates_data_during_initialization():
|
||||
graph_config: dict[str, object] = {}
|
||||
init_params, runtime_state = _build_context(graph_config)
|
||||
|
||||
node = _SampleNode(
|
||||
id="node-1",
|
||||
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
|
||||
config=_build_node_config(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@ -72,7 +86,7 @@ def test_node_accepts_invoke_from_enum():
|
||||
|
||||
node = _SampleNode(
|
||||
id="node-1",
|
||||
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
|
||||
config=_build_node_config(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@ -99,3 +113,17 @@ def test_missing_generic_argument_raises_type_error():
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_base_node_data_keeps_dict_style_access_compatibility():
|
||||
node_data = _SampleNodeData.model_validate(
|
||||
{
|
||||
"type": NodeType.ANSWER.value,
|
||||
"title": "Sample",
|
||||
"foo": "bar",
|
||||
}
|
||||
)
|
||||
|
||||
assert node_data["foo"] == "bar"
|
||||
assert node_data.get("foo") == "bar"
|
||||
assert node_data.get("missing", "fallback") == "fallback"
|
||||
|
||||
52
api/tests/unit_tests/core/workflow/nodes/test_loop_node.py
Normal file
52
api/tests/unit_tests/core/workflow/nodes/test_loop_node.py
Normal file
@ -0,0 +1,52 @@
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.nodes.loop.entities import LoopNodeData
|
||||
from dify_graph.nodes.loop.loop_node import LoopNode
|
||||
|
||||
|
||||
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
|
||||
seen_configs: list[object] = []
|
||||
original_validate_python = NodeConfigDictAdapter.validate_python
|
||||
|
||||
def record_validate_python(value: object):
|
||||
seen_configs.append(value)
|
||||
return original_validate_python(value)
|
||||
|
||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||
|
||||
child_node_config = {
|
||||
"id": "answer-node",
|
||||
"data": {
|
||||
"type": "answer",
|
||||
"title": "Answer",
|
||||
"answer": "",
|
||||
"loop_id": "loop-node",
|
||||
},
|
||||
}
|
||||
|
||||
LoopNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "loop-node",
|
||||
"data": {
|
||||
"type": "loop",
|
||||
"title": "Loop",
|
||||
"loop_count": 1,
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
},
|
||||
},
|
||||
child_node_config,
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
node_id="loop-node",
|
||||
node_data=LoopNodeData(
|
||||
title="Loop",
|
||||
loop_count=1,
|
||||
break_conditions=[],
|
||||
logical_operator="and",
|
||||
),
|
||||
)
|
||||
|
||||
assert seen_configs == [child_node_config]
|
||||
@ -210,9 +210,6 @@ def test_webhook_data_model_dump_with_alias():
|
||||
|
||||
def test_webhook_data_validation_errors():
|
||||
"""Test WebhookData validation errors."""
|
||||
# Title is required (inherited from BaseNodeData)
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData()
|
||||
|
||||
# Invalid method
|
||||
with pytest.raises(ValidationError):
|
||||
@ -254,6 +251,36 @@ def test_webhook_data_sequence_fields():
|
||||
assert len(data.headers) == 1 # Should still be 1
|
||||
|
||||
|
||||
def test_webhook_data_rejects_non_string_header_types():
|
||||
"""Headers should stay string-only because runtime does not coerce header values."""
|
||||
for param_type in ["number", "boolean", "object", "array[string]", "file"]:
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData(
|
||||
title="Test",
|
||||
headers=[WebhookParameter(name="X-Test", type=param_type)],
|
||||
)
|
||||
|
||||
|
||||
def test_webhook_data_limits_query_param_types_to_scalar_values():
|
||||
"""Query params only support scalar conversions in the current runtime."""
|
||||
data = WebhookData(
|
||||
title="Test",
|
||||
params=[
|
||||
WebhookParameter(name="count", type="number"),
|
||||
WebhookParameter(name="enabled", type="boolean"),
|
||||
],
|
||||
)
|
||||
assert data.params[0].type == "number"
|
||||
assert data.params[1].type == "boolean"
|
||||
|
||||
for param_type in ["object", "array[string]", "array[number]", "array[boolean]", "array[object]", "file"]:
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData(
|
||||
title="Test",
|
||||
params=[WebhookParameter(name="test", type=param_type)],
|
||||
)
|
||||
|
||||
|
||||
def test_webhook_data_sync_mode():
|
||||
"""Test WebhookData SyncMode nested enum."""
|
||||
# Test that SyncMode enum exists and has expected value
|
||||
@ -297,7 +324,7 @@ def test_webhook_body_parameter_edge_cases():
|
||||
|
||||
def test_webhook_data_inheritance():
|
||||
"""Test WebhookData inherits from BaseNodeData correctly."""
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
||||
# Test that WebhookData is a subclass of BaseNodeData
|
||||
assert issubclass(WebhookData, BaseNodeData)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.nodes.base.exc import BaseNodeError
|
||||
from dify_graph.entities.exc import BaseNodeError
|
||||
from dify_graph.nodes.trigger_webhook.exc import (
|
||||
WebhookConfigError,
|
||||
WebhookNodeError,
|
||||
|
||||
82
api/tests/unit_tests/core/workflow/test_node_factory.py
Normal file
82
api/tests/unit_tests/core/workflow/test_node_factory.py
Normal file
@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
|
||||
graph_init_params = build_test_graph_init_params(
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
|
||||
|
||||
def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch):
|
||||
class _FactoryLLMNodeData(LLMNodeData):
|
||||
pass
|
||||
|
||||
llm_node_config = {
|
||||
"id": "llm-node",
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "LLM",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [],
|
||||
"context": {
|
||||
"enabled": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
graph_config = {"nodes": [llm_node_config], "edges": []}
|
||||
factory = _build_factory(graph_config)
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
monkeypatch.setattr(LLMNode, "_node_data_type", _FactoryLLMNodeData)
|
||||
|
||||
def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance:
|
||||
captured["node_data"] = node_data
|
||||
return object() # type: ignore[return-value]
|
||||
|
||||
def _capture_memory(
|
||||
self: DifyNodeFactory,
|
||||
*,
|
||||
node_data: object,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
captured["memory_node_data"] = node_data
|
||||
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance)
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory)
|
||||
|
||||
node = factory.create_node(llm_node_config)
|
||||
|
||||
assert isinstance(captured["node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(node.node_data, _FactoryLLMNodeData)
|
||||
@ -9,6 +9,7 @@ from dify_graph.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.file.enums import FileType
|
||||
from dify_graph.file.models import File, FileTransferMethod
|
||||
from dify_graph.nodes.code.code_node import CodeNode
|
||||
@ -124,7 +125,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def get_node_config_by_id(self, target_id: str):
|
||||
assert target_id == node_id
|
||||
return node_config
|
||||
return NodeConfigDictAdapter.validate_python(node_config)
|
||||
|
||||
workflow = StubWorkflow()
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={})
|
||||
|
||||
@ -258,38 +258,38 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
|
||||
return q
|
||||
|
||||
msg_session_1 = MagicMock()
|
||||
msg_session_1.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
|
||||
msg_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_1.commit.return_value = None
|
||||
|
||||
msg_session_2 = MagicMock()
|
||||
msg_session_2.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
|
||||
msg_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_2.commit.return_value = None
|
||||
|
||||
conv_session_1 = MagicMock()
|
||||
conv_session_1.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
|
||||
conv_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_1.commit.return_value = None
|
||||
|
||||
conv_session_2 = MagicMock()
|
||||
conv_session_2.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
|
||||
conv_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_2.commit.return_value = None
|
||||
|
||||
wal_session_1 = MagicMock()
|
||||
wal_session_1.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
wal_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_1.commit.return_value = None
|
||||
|
||||
wal_session_2 = MagicMock()
|
||||
wal_session_2.query.side_effect = (
|
||||
lambda model: make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
wal_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_2.commit.return_value = None
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
@ -22,7 +23,7 @@ def _make_service() -> WorkflowService:
|
||||
return WorkflowService(session_maker=sessionmaker())
|
||||
|
||||
|
||||
def _build_node_config(delivery_methods):
|
||||
def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict:
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
delivery_methods=delivery_methods,
|
||||
@ -31,7 +32,7 @@ def _build_node_config(delivery_methods):
|
||||
user_actions=[],
|
||||
).model_dump(mode="json")
|
||||
node_data["type"] = NodeType.HUMAN_INPUT.value
|
||||
return {"id": "node-1", "data": node_data}
|
||||
return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data})
|
||||
|
||||
|
||||
def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod:
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
|
||||
from dify_graph.nodes.human_input.enums import FormInputType
|
||||
@ -187,7 +188,10 @@ class TestWorkflowService:
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
node_config = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
)
|
||||
workflow.get_node_config_by_id.return_value = node_config
|
||||
workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
@ -232,7 +236,7 @@ class TestWorkflowService:
|
||||
service._build_human_input_variable_pool.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}},
|
||||
node_config=node_config,
|
||||
manual_inputs={"#node-0.result#": "LLM output"},
|
||||
)
|
||||
|
||||
@ -267,7 +271,9 @@ class TestWorkflowService:
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
)
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
659
api/uv.lock
generated
659
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@ -47,7 +47,7 @@ const SettingRow = memo(({
|
||||
: 'bg-workflow-block-parma-bg',
|
||||
)}
|
||||
>
|
||||
<div className={cn('mr-2 shrink-0 system-xs-medium-uppercase', warning ? 'text-text-warning' : 'text-text-tertiary')}>
|
||||
<div className="mr-2 shrink-0 text-text-tertiary system-xs-medium-uppercase">
|
||||
{label}
|
||||
</div>
|
||||
<div
|
||||
@ -141,13 +141,16 @@ const Node: FC<NodeProps<KnowledgeBaseNodeType>> = ({ data }) => {
|
||||
if (data.indexing_technique !== IndexMethodEnum.QUALIFIED)
|
||||
return '-'
|
||||
|
||||
if (isKnowledgeBaseEmbeddingIssue(validationIssue))
|
||||
if (isKnowledgeBaseEmbeddingIssue(validationIssue)) {
|
||||
if (validationIssue?.code === KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)
|
||||
return t('nodes.knowledgeBase.notConfigured', { ns: 'workflow' })
|
||||
return validationIssueMessage
|
||||
}
|
||||
|
||||
const currentEmbeddingModelProvider = embeddingModelList.find(provider => provider.provider === data.embedding_model_provider)
|
||||
const currentEmbeddingModel = currentEmbeddingModelProvider?.models.find(model => model.model === data.embedding_model)
|
||||
return currentEmbeddingModel?.label[language] || currentEmbeddingModel?.label.en_US || data.embedding_model || '-'
|
||||
}, [data.embedding_model, data.embedding_model_provider, data.indexing_technique, embeddingModelList, language, validationIssue, validationIssueMessage])
|
||||
}, [data.embedding_model, data.embedding_model_provider, data.indexing_technique, embeddingModelList, language, validationIssue, validationIssueMessage, t])
|
||||
|
||||
const indexMethodDisplay = settingsDisplay[data.indexing_technique as keyof typeof settingsDisplay] || '-'
|
||||
const retrievalMethodDisplay = settingsDisplay[data.retrieval_model?.search_method as keyof typeof settingsDisplay] || '-'
|
||||
|
||||
@ -687,6 +687,7 @@
|
||||
"nodes.knowledgeBase.embeddingModelIsRequired": "Embedding model is required",
|
||||
"nodes.knowledgeBase.embeddingModelNotConfigured": "Embedding model not configured",
|
||||
"nodes.knowledgeBase.indexMethodIsRequired": "Index method is required",
|
||||
"nodes.knowledgeBase.notConfigured": "Not configured",
|
||||
"nodes.knowledgeBase.rerankingModelIsInvalid": "Reranking model is invalid",
|
||||
"nodes.knowledgeBase.rerankingModelIsRequired": "Reranking model is required",
|
||||
"nodes.knowledgeBase.retrievalSettingIsRequired": "Retrieval setting is required",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user