Compare commits

...

21 Commits

Author SHA1 Message Date
b7a5ed6c0b test(api): cover remaining workflow typing branches 2026-03-25 19:47:41 +08:00
e819a9a5f7 Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-25 19:45:36 +08:00
bc82676d93 Update api/dify_graph/nodes/loop/loop_node.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-25 19:29:05 +08:00
7b76fdc1d3 test(api): cover workflow typing paths 2026-03-25 19:06:23 +08:00
82acddddb4 Merge remote-tracking branch 'origin/main' into yanli/phase3-code-scope 2026-03-25 18:13:16 +08:00
710ac3b90a fix(api): preserve typed loop array constants 2026-03-18 22:05:16 +08:00
8548498f25 fix(api): restore advanced chat refresh_model contract 2026-03-18 19:41:00 +08:00
d014f0b91a fix(api): address typing review feedback 2026-03-18 19:16:48 +08:00
cc5aac268a fix(api): support tool typed dicts on py311 2026-03-18 18:59:49 +08:00
4c1d27431b fix(api): restore workflow node compatibility 2026-03-18 18:43:35 +08:00
9a86f280eb fix(api): avoid recursive loop type adapters 2026-03-18 18:20:43 +08:00
c5920fb28a Merge remote-tracking branch 'origin/main' into yanli/phase3-code-scope 2026-03-18 17:52:03 +08:00
2f81d5dfdf fix(api): restore typedict py311 compatibility 2026-03-17 20:30:18 +08:00
7639d8e43f fix(api): reuse advanced chat refresh session 2026-03-17 20:18:21 +08:00
1dce81c604 refactor(api): type single node workflow helpers 2026-03-17 20:16:14 +08:00
f874ca183e chore(api): remove phase 3 pyrefly excludes 2026-03-17 20:04:55 +08:00
0d805e624e Type phase 3 loop values 2026-03-17 19:39:54 +08:00
61196180b8 Type phase 3 tool inputs 2026-03-17 19:31:00 +08:00
79433b0091 Refine phase 3 typing boundaries 2026-03-17 19:13:12 +08:00
c4aeaa35d4 Type phase 3 schema contracts 2026-03-17 18:56:22 +08:00
9f0d79b8b0 Tighten phase 3 runtime typing 2026-03-17 18:49:14 +08:00
42 changed files with 1480 additions and 375 deletions

View File

@@ -5,7 +5,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
@@ -47,7 +47,6 @@ from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import ( from services.workflow_draft_variable_service import (
@@ -522,8 +521,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# release database connection, because the following new thread operations may take a long time # release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session: with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow) workflow = _refresh_model(session=session, model=workflow)
message = _refresh_model(session, message) message = _refresh_model(session=session, model=message)
if message is None:
raise RuntimeError("Failed to refresh Message; _refresh_model returned None.")
# workflow_ = session.get(Workflow, workflow.id) # workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None # assert workflow_ is not None
# workflow = workflow_ # workflow = workflow_
@@ -690,11 +691,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
raise e raise e
_T = TypeVar("_T", bound=Base) @overload
def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ...
def _refresh_model(session, model: _T) -> _T: @overload
with Session(bind=db.engine, expire_on_commit=False) as session: def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ...
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model def _refresh_model(*, session: Session | None = None, model: Any) -> Any:
if session is not None:
detached_model = session.get(type(model), model.id)
assert detached_model is not None
return detached_model
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
detached_model = refresh_session.get(type(model), model.id)
assert detached_model is not None
return detached_model

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator from collections.abc import Generator, Iterator
from typing import Any, cast from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -56,8 +56,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, Any, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@@ -87,8 +87,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, Any, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator from collections.abc import Generator, Iterator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@@ -1,7 +1,7 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping from collections.abc import Generator, Iterator, Mapping
from typing import Any, Union from typing import Any
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
def convert( def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(response)
else: else:
stream_response = response
def _generate_full_response() -> Generator[dict | str, Any, None]: def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_full_response(response) yield from cls.convert_stream_full_response(stream_response)
return _generate_full_response() return _generate_full_response()
else: else:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response) return cls.convert_blocking_simple_response(response)
else: else:
stream_response = response
def _generate_simple_response() -> Generator[dict | str, Any, None]: def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_simple_response(response) yield from cls.convert_stream_simple_response(stream_response)
return _generate_simple_response() return _generate_simple_response()
@@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
raise NotImplementedError raise NotImplementedError

View File

@@ -224,6 +224,7 @@ class BaseAppGenerator:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory: def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER: if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account) assert isinstance(account, Account)
debug_account = account
def draft_var_saver_factory( def draft_var_saver_factory(
session: Session, session: Session,
@@ -240,7 +241,7 @@
node_type=node_type, node_type=node_type,
node_execution_id=node_execution_id, node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id, enclosing_node_id=enclosing_node_id,
user=account, user=debug_account,
) )
else: else:

View File

@@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# init generate records # init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation) (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if conversation is None or message is None:
raise RuntimeError("_init_generate_records() returned None for conversation or message")
generated_conversation_id = str(conversation.id)
generated_message_id = str(message.id)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=generated_conversation_id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id, message_id=generated_message_id,
) )
# new thread with request context # new thread with request context
@@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
conversation_id=conversation.id, conversation_id=generated_conversation_id,
message_id=message.id, message_id=generated_message_id,
) )
worker_thread = threading.Thread(target=worker_with_context) worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator from collections.abc import Generator, Iterator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@@ -149,6 +149,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records # init generate records
(conversation, message) = self._init_generate_records(application_generate_entity) (conversation, message) = self._init_generate_records(application_generate_entity)
if conversation is None or message is None:
raise RuntimeError(
"_init_generate_records() returned None for conversation or message, "
"which is required to proceed with generation."
)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@@ -312,15 +317,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records # init generate records
(conversation, message) = self._init_generate_records(application_generate_entity) (conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
conversation_id = str(conversation.id)
message_id = str(message.id)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation_id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id, message_id=message_id,
) )
# new thread with request context # new thread with request context
@@ -330,7 +339,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
message_id=message.id, message_id=message_id,
) )
worker_thread = threading.Thread(target=worker_with_context) worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator from collections.abc import Generator, Iterator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator from collections.abc import Generator, Iterator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator from collections.abc import Generator, Iterator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@@ -1,13 +1,17 @@
import logging import logging
import time import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, cast from typing import Protocol, TypeAlias
from pydantic import ValidationError from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.entities.app_invoke_entities import (
InvokeFrom,
UserFrom,
build_dify_run_context,
)
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
QueueAgentLogEvent, QueueAgentLogEvent,
@@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_engine.layers.base import GraphEngineLayer
@@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GraphConfigObject: TypeAlias = dict[str, object]
GraphConfigMapping: TypeAlias = Mapping[str, object]
class SingleNodeRunEntity(Protocol):
node_id: str
inputs: Mapping[str, object]
class WorkflowBasedAppRunner: class WorkflowBasedAppRunner:
def __init__( def __init__(
@@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
def _init_graph( def _init_graph(
self, self,
graph_config: Mapping[str, Any], graph_config: GraphConfigMapping,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
user_from: UserFrom, user_from: UserFrom,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@@ -154,8 +166,8 @@
def _prepare_single_node_execution( def _prepare_single_node_execution(
self, self,
workflow: Workflow, workflow: Workflow,
single_iteration_run: Any | None = None, single_iteration_run: SingleNodeRunEntity | None = None,
single_loop_run: Any | None = None, single_loop_run: SingleNodeRunEntity | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]: ) -> tuple[Graph, VariablePool, GraphRuntimeState]:
""" """
Prepare graph, variable pool, and runtime state for single node execution Prepare graph, variable pool, and runtime state for single node execution
@@ -208,11 +220,88 @@
# This ensures all nodes in the graph reference the same GraphRuntimeState instance # This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state return graph, variable_pool, graph_runtime_state
@staticmethod
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
nodes = graph_config.get("nodes")
edges = graph_config.get("edges")
if not isinstance(nodes, list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(edges, list):
raise ValueError("edges in workflow graph must be a list")
validated_nodes: list[GraphConfigMapping] = []
for node in nodes:
if not isinstance(node, Mapping):
raise ValueError("nodes in workflow graph must be mappings")
validated_nodes.append(node)
validated_edges: list[GraphConfigMapping] = []
for edge in edges:
if not isinstance(edge, Mapping):
raise ValueError("edges in workflow graph must be mappings")
validated_edges.append(edge)
return validated_nodes, validated_edges
@staticmethod
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
if node_config is None:
return None
node_data = node_config.get("data")
if not isinstance(node_data, Mapping):
return None
start_node_id = node_data.get("start_node_id")
return start_node_id if isinstance(start_node_id, str) else None
@classmethod
def _build_single_node_graph_config(
cls,
*,
graph_config: GraphConfigMapping,
node_id: str,
node_type_filter_key: str,
) -> tuple[GraphConfigObject, NodeConfigDict]:
node_configs, edge_configs = cls._get_graph_items(graph_config)
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
start_node_id = cls._extract_start_node_id(main_node_config)
filtered_node_configs = [
dict(node)
for node in node_configs
if node.get("id") == node_id
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
or (start_node_id and node.get("id") == start_node_id)
]
if not filtered_node_configs:
raise ValueError(f"node id {node_id} not found in workflow graph")
filtered_node_ids = {
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
}
filtered_edge_configs = [
dict(edge)
for edge in edge_configs
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
]
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
if target_node_config is None:
raise ValueError(f"node id {node_id} not found in workflow graph")
return (
{
"nodes": filtered_node_configs,
"edges": filtered_edge_configs,
},
NodeConfigDictAdapter.validate_python(target_node_config),
)
def _get_graph_and_variable_pool_for_single_node_run( def _get_graph_and_variable_pool_for_single_node_run(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict[str, Any], user_inputs: Mapping[str, object],
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id' node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages node_type_label: str = "node", # 'iteration' or 'loop' for error messages
@@ -236,41 +325,14 @@
if not graph_config: if not graph_config:
raise ValueError("workflow graph not found") raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config: if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph") raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list): graph_config, target_node_config = self._build_single_node_graph_config(
raise ValueError("nodes in workflow graph must be a list") graph_config=graph_config,
node_id=node_id,
if not isinstance(graph_config.get("edges"), list): node_type_filter_key=node_type_filter_key,
raise ValueError("edges in workflow graph must be a list") )
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init # Create required parameters for Graph.init
graph_init_params = GraphInitParams( graph_init_params = GraphInitParams(
@@ -299,18 +361,6 @@
if not graph: if not graph:
raise ValueError("graph not found in workflow") raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class # Get node class
node_type = target_node_config["data"].type node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version) node_version = str(target_node_config["data"].version)

View File

@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
""" """
node_id: str node_id: str
inputs: Mapping inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None single_iteration_run: SingleIterationRunEntity | None = None
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
""" """
node_id: str node_id: str
inputs: Mapping inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None single_loop_run: SingleLoopRunEntity | None = None
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
""" """
node_id: str node_id: str
inputs: dict inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None single_iteration_run: SingleIterationRunEntity | None = None
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
""" """
node_id: str node_id: str
inputs: dict inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None single_loop_run: SingleLoopRunEntity | None = None

View File

@@ -1045,9 +1045,10 @@ class ToolManager:
continue continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable": if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value) variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None: if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist") raise ToolParameterError(f"Variable {variable_selector} does not exist")
parameter_value = variable.value parameter_value = variable.value
elif tool_input.type == "constant": elif tool_input.type == "constant":
parameter_value = tool_input.value parameter_value = tool_input.value

View File

@ -1,13 +1,24 @@
from enum import IntEnum, StrEnum, auto from __future__ import annotations
from typing import Any, Literal, Union
from pydantic import BaseModel from enum import IntEnum, StrEnum, auto
from typing import Literal, TypeAlias
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.enums import BuiltinNodeTypes, NodeType
AgentInputConstantValue: TypeAlias = (
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
)
VariableSelector: TypeAlias = list[str]
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class AgentNodeData(BaseNodeData): class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT type: NodeType = BuiltinNodeTypes.AGENT
@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
tool_node_version: str | None = None tool_node_version: str | None = None
class AgentInput(BaseModel): class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"] type: Literal["mixed", "variable", "constant"]
value: AgentInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> AgentInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "variable":
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type in {"mixed", "constant"}:
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown agent input type: {input_type}")
agent_parameters: dict[str, AgentInput] agent_parameters: dict[str, AgentInput]

View File

@ -1,16 +1,17 @@
from __future__ import annotations from __future__ import annotations
import json import json
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from typing import Any, cast from typing import TypeAlias
from packaging.version import Version from packaging.version import Version
from pydantic import ValidationError from pydantic import TypeAdapter, ValidationError
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials from core.plugin.entities.request import InvokeCredentials
@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy from .strategy_protocols import ResolvedAgentStrategy
JsonObject: TypeAlias = dict[str, object]
JsonObjectList: TypeAlias = list[JsonObject]
VariableSelector: TypeAlias = list[str]
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
class AgentRuntimeSupport: class AgentRuntimeSupport:
def build_parameters( def build_parameters(
@ -39,12 +48,12 @@ class AgentRuntimeSupport:
strategy: ResolvedAgentStrategy, strategy: ResolvedAgentStrategy,
tenant_id: str, tenant_id: str,
app_id: str, app_id: str,
invoke_from: Any, invoke_from: InvokeFrom,
for_log: bool = False, for_log: bool = False,
) -> dict[str, Any]: ) -> dict[str, object]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {} result: dict[str, object] = {}
for parameter_name in node_data.agent_parameters: for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name) parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter: if not parameter:
@ -54,9 +63,10 @@ class AgentRuntimeSupport:
agent_input = node_data.agent_parameters[parameter_name] agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type: match agent_input.type:
case "variable": case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type] variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
variable = variable_pool.get(variable_selector)
if variable is None: if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value)) raise AgentVariableNotFoundError(str(variable_selector))
parameter_value = variable.value parameter_value = variable.value
case "mixed" | "constant": case "mixed" | "constant":
try: try:
@ -79,60 +89,38 @@ class AgentRuntimeSupport:
value = parameter_value value = parameter_value
if parameter.type == "array[tools]": if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value) tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
value = [tool for tool in value if tool.get("enabled", False)] value = self._normalize_tool_payloads(
value = self._filter_mcp_type_tool(strategy, value) strategy=strategy,
for tool in value: tools=tool_payloads,
if "schemas" in tool: variable_pool=variable_pool,
tool.pop("schemas") )
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log: if not for_log:
if parameter.type == "array[tools]": if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value) value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
tool_value = [] tool_value = []
for tool in value: for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) provider_type = self._coerce_tool_provider_type(tool.get("type"))
setting_params = tool.get("settings", {}) setting_params = self._coerce_json_object(tool.get("settings")) or {}
parameters = tool.get("parameters", {}) parameters = self._coerce_json_object(tool.get("parameters")) or {}
manual_input_params = [key for key, value in parameters.items() if value is not None] manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params} parameters = {**parameters, **setting_params}
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
credential_id = self._coerce_optional_string(tool.get("credential_id"))
entity = AgentToolEntity( entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""), provider_id=provider_id,
provider_type=provider_type, provider_type=provider_type,
tool_name=tool.get("tool_name", ""), tool_name=tool_name,
tool_parameters=parameters, tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None), plugin_unique_identifier=plugin_unique_identifier,
credential_id=tool.get("credential_id", None), credential_id=credential_id,
) )
extra = tool.get("extra", {}) extra = self._coerce_json_object(tool.get("extra")) or {}
runtime_variable_pool: VariablePool | None = None runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None: if node_data.version != "1" or node_data.tool_node_version is not None:
@ -145,8 +133,9 @@ class AgentRuntimeSupport:
runtime_variable_pool, runtime_variable_pool,
) )
if tool_runtime.entity.description: if tool_runtime.entity.description:
description_override = self._coerce_optional_string(extra.get("description"))
tool_runtime.entity.description.llm = ( tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm description_override or tool_runtime.entity.description.llm
) )
for tool_runtime_params in tool_runtime.entity.parameters: for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = ( tool_runtime_params.form = (
@ -167,13 +156,13 @@ class AgentRuntimeSupport:
{ {
**tool_runtime.entity.model_dump(mode="json"), **tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters, "runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None), "credential_id": credential_id,
"provider_type": provider_type.value, "provider_type": provider_type.value,
} }
) )
value = tool_value value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value) value = _JSON_OBJECT_ADAPTER.validate_python(value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = [] history_prompt_messages = []
if node_data.memory: if node_data.memory:
@ -199,17 +188,27 @@ class AgentRuntimeSupport:
return result return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials: def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
credentials = InvokeCredentials() credentials = InvokeCredentials()
credentials.tool_credentials = {} credentials.tool_credentials = {}
for tool in parameters.get("tools", []): tools = parameters.get("tools")
if not isinstance(tools, list):
return credentials
for raw_tool in tools:
tool = self._coerce_json_object(raw_tool)
if tool is None:
continue
if not tool.get("credential_id"): if not tool.get("credential_id"):
continue continue
try: try:
identity = ToolIdentity.model_validate(tool.get("identity", {})) identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError: except ValidationError:
continue continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) credential_id = self._coerce_optional_string(tool.get("credential_id"))
if credential_id is None:
continue
credentials.tool_credentials[identity.provider] = credential_id
return credentials return credentials
def fetch_memory( def fetch_memory(
@ -232,14 +231,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance) return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager() provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle( provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=value.get("provider", ""), provider=str(value.get("provider", "")),
model_type=ModelType.LLM, model_type=ModelType.LLM,
) )
model_name = value.get("model", "") model_name = str(value.get("model", ""))
model_credentials = provider_model_bundle.configuration.get_current_credentials( model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model_type=ModelType.LLM,
model=model_name, model=model_name,
@ -249,7 +248,7 @@ class AgentRuntimeSupport:
model_instance = ModelManager().get_model_instance( model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider_name, provider=provider_name,
model_type=ModelType(value.get("model_type", "")), model_type=ModelType(str(value.get("model_type", ""))),
model=model_name, model=model_name,
) )
model_schema = model_type_instance.get_model_schema(model_name, model_credentials) model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
@ -268,9 +267,88 @@ class AgentRuntimeSupport:
@staticmethod @staticmethod
def _filter_mcp_type_tool( def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy, strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]], tools: JsonObjectList,
) -> list[dict[str, Any]]: ) -> JsonObjectList:
meta_version = strategy.meta_version meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"): if meta_version and Version(meta_version) > Version("0.0.1"):
return tools return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _normalize_tool_payloads(
self,
*,
strategy: ResolvedAgentStrategy,
tools: JsonObjectList,
variable_pool: VariablePool,
) -> JsonObjectList:
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
for tool in normalized_tools:
tool.pop("schemas", None)
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
tool["settings"] = self._resolve_tool_settings(tool)
return normalized_tools
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
if parameter_configs is None:
raw_parameters = self._coerce_json_object(tool.get("parameters"))
return raw_parameters or {}
resolved_parameters: JsonObject = {}
for key, parameter_config in parameter_configs.items():
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
value_param = self._coerce_json_object(parameter_config.get("value"))
if value_param and value_param.get("type") == "variable":
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
resolved_parameters[key] = variable.value
else:
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
else:
resolved_parameters[key] = None
return resolved_parameters
@staticmethod
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
if settings is None:
return {}
return {key: setting.get("value") for key, setting in settings.items()}
@staticmethod
def _coerce_json_object(value: object) -> JsonObject | None:
try:
return _JSON_OBJECT_ADAPTER.validate_python(value)
except ValidationError:
return None
@staticmethod
def _coerce_optional_string(value: object) -> str | None:
return value if isinstance(value, str) else None
@staticmethod
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
if isinstance(value, ToolProviderType):
return value
if isinstance(value, str):
return ToolProviderType(value)
return ToolProviderType.BUILT_IN
@classmethod
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
if not isinstance(value, dict):
return None
coerced: dict[str, JsonObject] = {}
for key, item in value.items():
if not isinstance(key, str):
return None
json_object = cls._coerce_json_object(item)
if json_object is None:
return None
coerced[key] = json_object
return coerced

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast from typing import Any, TypeAlias, cast
from configs import dify_config from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.exc import GenerateTaskStoppedError
@ -32,6 +32,13 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SpecialValueScalar: TypeAlias = str | int | float | bool | None
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
SerializedSpecialValue: TypeAlias = (
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
)
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
class _WorkflowChildEngineBuilder: class _WorkflowChildEngineBuilder:
@staticmethod @staticmethod
@ -276,10 +283,10 @@ class WorkflowEntry:
@staticmethod @staticmethod
def _create_single_node_graph( def _create_single_node_graph(
node_id: str, node_id: str,
node_data: dict[str, Any], node_data: Mapping[str, object],
node_width: int = 114, node_width: int = 114,
node_height: int = 514, node_height: int = 514,
) -> dict[str, Any]: ) -> SingleNodeGraphConfig:
""" """
Create a minimal graph structure for testing a single node in isolation. Create a minimal graph structure for testing a single node in isolation.
@ -289,14 +296,14 @@ class WorkflowEntry:
:param node_height: height for UI layout (default: 100) :param node_height: height for UI layout (default: 100)
:return: graph dictionary with start node and target node :return: graph dictionary with start node and target node
""" """
node_config = { node_config: dict[str, object] = {
"id": node_id, "id": node_id,
"width": node_width, "width": node_width,
"height": node_height, "height": node_height,
"type": "custom", "type": "custom",
"data": node_data, "data": dict(node_data),
} }
start_node_config = { start_node_config: dict[str, object] = {
"id": "start", "id": "start",
"width": node_width, "width": node_width,
"height": node_height, "height": node_height,
@ -321,7 +328,12 @@ class WorkflowEntry:
@classmethod @classmethod
def run_free_node( def run_free_node(
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] cls,
node_data: Mapping[str, object],
node_id: str,
tenant_id: str,
user_id: str,
user_inputs: Mapping[str, object],
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
""" """
Run free node Run free node
@ -339,6 +351,8 @@ class WorkflowEntry:
graph_dict = cls._create_single_node_graph(node_id, node_data) graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "") node_type = node_data.get("type", "")
if not isinstance(node_type, str):
raise ValueError("Node type must be a string")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}: if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported") raise ValueError(f"Node type {node_type} not supported")
@ -369,7 +383,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state # init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
node_factory = DifyNodeFactory( node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
@ -405,30 +419,34 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod @staticmethod
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
# NOTE(QuantumGhost): Avoid using this function in new code. # NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict # Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain # immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information. # data integrity and type information.
result = WorkflowEntry._handle_special_values(value) result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result) if result is None:
return None
if isinstance(result, dict):
return result
raise TypeError("handle_special_values expects a mapping input")
@staticmethod @staticmethod
def _handle_special_values(value: Any): def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
if value is None: if value is None:
return value return value
if isinstance(value, dict): if isinstance(value, Mapping):
res = {} res: dict[str, SerializedSpecialValue] = {}
for k, v in value.items(): for k, v in value.items():
res[k] = WorkflowEntry._handle_special_values(v) res[k] = WorkflowEntry._handle_special_values(v)
return res return res
if isinstance(value, list): if isinstance(value, list):
res_list = [] res_list: list[SerializedSpecialValue] = []
for item in value: for item in value:
res_list.append(WorkflowEntry._handle_special_values(item)) res_list.append(WorkflowEntry._handle_special_values(item))
return res_list return res_list
if isinstance(value, File): if isinstance(value, File):
return value.to_dict() return dict(value.to_dict())
return value return value
@classmethod @classmethod

View File

@ -112,6 +112,8 @@ def _get_encoded_string(f: File, /) -> str:
data = _download_file_content(f.storage_key) data = _download_file_content(f.storage_key)
case FileTransferMethod.DATASOURCE_FILE: case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key) data = _download_file_content(f.storage_key)
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
return base64.b64encode(data).decode("utf-8") return base64.b64encode(data).decode("utf-8")

View File

@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
elif limit_type == LimitType.TIME_LIMIT: elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0 elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s" reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
else:
return
self.logger.warning("Execution limit exceeded: %s", reason) self.logger.warning("Execution limit exceeded: %s", reason)

View File

@ -336,12 +336,7 @@ class Node(Generic[NodeDataT]):
def _restore_execution_id_from_runtime_state(self) -> str | None: def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution graph_execution = self.graph_runtime_state.graph_execution
try: node_executions = graph_execution.node_executions
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_execution = node_executions.get(self._node_id) node_execution = node_executions.get(self._node_id)
if node_execution is None: if node_execution is None:
return None return None
@ -395,8 +390,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event) yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self.execution_id yield event.model_copy(update={"id": self.execution_id})
yield event
else: else:
yield event yield event
except Exception as e: except Exception as e:

View File

@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
# Keep track of paragraph and table positions # Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = [] content_items: list[tuple[int, str, Table | Paragraph]] = []
it = iter(doc.element.body) doc_body = getattr(doc.element, "body", None)
if doc_body is None:
raise TextExtractionError("DOCX body not found")
it = iter(doc_body)
part = next(it, None) part = next(it, None)
i = 0 i = 0
while part is not None: while part is not None:

View File

@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Literal, NotRequired
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.base_node_data import BaseNodeData
@ -10,11 +11,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.entities import VariableSelector
class StructuredOutputConfig(TypedDict):
schema: Mapping[str, object]
name: NotRequired[str]
description: NotRequired[str]
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
provider: str provider: str
name: str name: str
mode: LLMMode mode: LLMMode
completion_params: dict[str, Any] = Field(default_factory=dict) completion_params: dict[str, object] = Field(default_factory=dict)
class ContextConfig(BaseModel): class ContextConfig(BaseModel):
@ -33,7 +40,7 @@ class VisionConfig(BaseModel):
@field_validator("configs", mode="before") @field_validator("configs", mode="before")
@classmethod @classmethod
def convert_none_configs(cls, v: Any): def convert_none_configs(cls, v: object):
if v is None: if v is None:
return VisionConfigOptions() return VisionConfigOptions()
return v return v
@ -44,7 +51,7 @@ class PromptConfig(BaseModel):
@field_validator("jinja2_variables", mode="before") @field_validator("jinja2_variables", mode="before")
@classmethod @classmethod
def convert_none_jinja2_variables(cls, v: Any): def convert_none_jinja2_variables(cls, v: object):
if v is None: if v is None:
return [] return []
return v return v
@ -67,7 +74,7 @@ class LLMNodeData(BaseNodeData):
memory: MemoryConfig | None = None memory: MemoryConfig | None = None
context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None structured_output: StructuredOutputConfig | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name. # We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field( reasoning_format: Literal["separated", "tagged"] = Field(
@ -90,11 +97,30 @@ class LLMNodeData(BaseNodeData):
@field_validator("prompt_config", mode="before") @field_validator("prompt_config", mode="before")
@classmethod @classmethod
def convert_none_prompt_config(cls, v: Any): def convert_none_prompt_config(cls, v: object):
if v is None: if v is None:
return PromptConfig() return PromptConfig()
return v return v
@field_validator("structured_output", mode="before")
@classmethod
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
if not isinstance(v, Mapping):
return v
schema = v.get("schema")
if schema is None:
return None
normalized: StructuredOutputConfig = {"schema": schema}
name = v.get("name")
description = v.get("description")
if isinstance(name, str):
normalized["name"] = name
if isinstance(description, str):
normalized["description"] = description
return normalized
@property @property
def structured_output_enabled(self) -> bool: def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None return self.structured_output_switch_on and self.structured_output is not None

View File

@ -9,6 +9,7 @@ import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.errors import OutputParserError
@ -74,6 +75,7 @@ from .entities import (
LLMNodeChatModelMessage, LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate, LLMNodeCompletionModelPromptTemplate,
LLMNodeData, LLMNodeData,
StructuredOutputConfig,
) )
from .exc import ( from .exc import (
InvalidContextStructureError, InvalidContextStructureError,
@ -88,6 +90,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
class LLMNode(Node[LLMNodeData]): class LLMNode(Node[LLMNodeData]):
@ -358,7 +361,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None = None, stop: Sequence[str] | None = None,
user_id: str, user_id: str,
structured_output_enabled: bool, structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None, structured_output: StructuredOutputConfig | None = None,
file_saver: LLMFileSaver, file_saver: LLMFileSaver,
file_outputs: list[File], file_outputs: list[File],
node_id: str, node_id: str,
@ -371,8 +374,10 @@ class LLMNode(Node[LLMNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled: if structured_output_enabled:
if structured_output is None:
raise LLMNodeError("Please provide a valid structured output schema")
output_schema = LLMNode.fetch_structured_output_schema( output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {}, structured_output=structured_output,
) )
request_start_time = time.perf_counter() request_start_time = time.perf_counter()
@ -924,6 +929,12 @@ class LLMNode(Node[LLMNodeData]):
# Extract clean text and reasoning from <think> tags # Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
structured_output = (
dict(invoke_result.structured_output)
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
else None
)
event = ModelInvokeCompletedEvent( event = ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode # Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text, text=clean_text if reasoning_format == "separated" else full_text,
@ -932,7 +943,7 @@ class LLMNode(Node[LLMNodeData]):
# Reasoning content for workflow variables and downstream nodes # Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
# Pass structured output if enabled # Pass structured output if enabled
structured_output=getattr(invoke_result, "structured_output", None), structured_output=structured_output,
) )
if request_latency is not None: if request_latency is not None:
event.usage.latency = round(request_latency, 3) event.usage.latency = round(request_latency, 3)
@ -966,27 +977,18 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod @staticmethod
def fetch_structured_output_schema( def fetch_structured_output_schema(
*, *,
structured_output: Mapping[str, Any], structured_output: StructuredOutputConfig,
) -> dict[str, Any]: ) -> dict[str, object]:
""" """
Fetch the structured output schema from the node data. Fetch the structured output schema from the node data.
Returns: Returns:
dict[str, Any]: The structured output schema dict[str, object]: The structured output schema
""" """
if not structured_output: schema = structured_output.get("schema")
if not schema:
raise LLMNodeError("Please provide a valid structured output schema") raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) return _JSON_OBJECT_ADAPTER.validate_python(schema)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod @staticmethod
def _save_multimodal_output_and_convert_result_to_markdown( def _save_multimodal_output_and_convert_result_to_markdown(

View File

@ -1,7 +1,10 @@
from enum import StrEnum from __future__ import annotations
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator from enum import StrEnum
from typing import Annotated, Any, Literal, TypeAlias, cast
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.enums import BuiltinNodeTypes, NodeType
@ -9,6 +12,12 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType from dify_graph.variables.types import SegmentType
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
LoopValueMapping: TypeAlias = dict[str, LoopValue]
VariableSelector: TypeAlias = list[str]
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
_VALID_VAR_TYPE = frozenset( _VALID_VAR_TYPE = frozenset(
[ [
SegmentType.STRING, SegmentType.STRING,
@ -29,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
return seg_type return seg_type
def _validate_loop_value(value: object) -> LoopValue:
if value is None or isinstance(value, (str, int, float, bool)):
return cast(LoopValue, value)
if isinstance(value, list):
return [_validate_loop_value(item) for item in value]
if isinstance(value, dict):
normalized: dict[str, LoopValue] = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop values only support string object keys")
normalized[key] = _validate_loop_value(item)
return normalized
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
if not isinstance(value, dict):
raise TypeError("Loop outputs must be an object")
normalized: LoopValueMapping = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop output keys must be strings")
normalized[key] = _validate_loop_value(item)
return normalized
class LoopVariableData(BaseModel): class LoopVariableData(BaseModel):
""" """
Loop Variable Data. Loop Variable Data.
@ -37,7 +76,29 @@ class LoopVariableData(BaseModel):
label: str label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"] value_type: Literal["variable", "constant"]
value: Any | list[str] | None = None value: LoopValue | VariableSelector | None = None
@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
value_type = validation_info.data.get("value_type")
if value_type == "variable":
if value is None:
raise ValueError("Variable loop inputs require a selector")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if value_type == "constant":
return _validate_loop_value(value)
raise ValueError(f"Unknown loop variable value type: {value_type}")
def require_variable_selector(self) -> VariableSelector:
if self.value_type != "variable":
raise ValueError(f"Expected variable loop input, got {self.value_type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def require_constant_value(self) -> LoopValue:
if self.value_type != "constant":
raise ValueError(f"Expected constant loop input, got {self.value_type}")
return _validate_loop_value(self.value)
class LoopNodeData(BaseLoopNodeData): class LoopNodeData(BaseLoopNodeData):
@ -46,14 +107,14 @@ class LoopNodeData(BaseLoopNodeData):
break_conditions: list[Condition] # Conditions to break the loop break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"] logical_operator: Literal["and", "or"]
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict) outputs: LoopValueMapping = Field(default_factory=dict)
@field_validator("outputs", mode="before") @field_validator("outputs", mode="before")
@classmethod @classmethod
def validate_outputs(cls, v): def validate_outputs(cls, value: object) -> LoopValueMapping:
if v is None: if value is None:
return {} return {}
return v return _validate_loop_value_mapping(value)
class LoopStartNodeData(BaseNodeData): class LoopStartNodeData(BaseNodeData):
@ -77,8 +138,8 @@ class LoopState(BaseLoopState):
Loop State. Loop State.
""" """
outputs: list[Any] = Field(default_factory=list) outputs: list[LoopValue] = Field(default_factory=list)
current_output: Any = None current_output: LoopValue | None = None
class MetaData(BaseLoopState.MetaData): class MetaData(BaseLoopState.MetaData):
""" """
@ -87,7 +148,7 @@ class LoopState(BaseLoopState):
loop_length: int loop_length: int
def get_last_output(self) -> Any: def get_last_output(self) -> LoopValue | None:
""" """
Get last output. Get last output.
""" """
@ -95,7 +156,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1] return self.outputs[-1]
return None return None
def get_current_output(self) -> Any: def get_current_output(self) -> LoopValue | None:
""" """
Get current output. Get current output.
""" """

View File

@ -3,7 +3,7 @@ import json
import logging import logging
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import ( from dify_graph.enums import (
@ -29,7 +29,7 @@ from dify_graph.node_events import (
) )
from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
from dify_graph.utils.condition.processor import ConditionProcessor from dify_graph.utils.condition.processor import ConditionProcessor
from dify_graph.variables import Segment, SegmentType from dify_graph.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
break_conditions = self.node_data.break_conditions break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count} inputs: dict[str, object] = {"loop_count": loop_count}
if not self.node_data.start_node_id: if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found") raise ValueError(f"field start_node_id in loop {self._node_id} not found")
@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
root_node_id = self.node_data.start_node_id root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool # Initialize loop variables in the original variable pool
loop_variable_selectors = {} loop_variable_selectors: dict[str, list[str]] = {}
if self.node_data.loop_variables: if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), "constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
"variable": lambda var: ( "variable": lambda var: (
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
if var.value is not None
else None
), ),
} }
for loop_variable in self.node_data.loop_variables: for loop_variable in self.node_data.loop_variables:
@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
condition_processor = ConditionProcessor() condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {} loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage() loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id) loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration # Collect loop variable values after iteration
single_loop_variable = {} single_loop_variable: dict[str, LoopValue] = {}
for key, selector in loop_variable_selectors.items(): for key, selector in loop_variable_selectors.items():
segment = self.graph_runtime_state.variable_pool.get(selector) segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None single_loop_variable[key] = segment.value if segment else None
@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, cls,
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, object],
node_id: str, node_id: str,
node_data: LoopNodeData, node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
variable_mapping = {} variable_mapping: dict[str, Sequence[str]] = {}
# Extract loop node IDs statically from graph_config # Extract loop node IDs statically from graph_config
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Get node configs from graph_config # Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} raw_nodes = graph_config.get("nodes")
node_configs: dict[str, Mapping[str, object]] = {}
if isinstance(raw_nodes, list):
for raw_node in raw_nodes:
if not isinstance(raw_node, dict):
continue
raw_node_id = raw_node.get("id")
if isinstance(raw_node_id, str):
node_configs[raw_node_id] = raw_node
for sub_node_id, sub_node_config in node_configs.items(): for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id: sub_node_data = sub_node_config.get("data")
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
continue continue
# variable selector to variable mapping # variable selector to variable mapping
@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
for loop_variable in node_data.loop_variables or []: for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable": if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping # add loop variable to variable mapping
selector = loop_variable.value selector = loop_variable.require_variable_selector()
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop # remove variable out from loop
@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return variable_mapping return variable_mapping
@classmethod @classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
""" """
Extract node IDs that belong to a specific loop from graph configuration. Extract node IDs that belong to a specific loop from graph configuration.
@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
:param loop_node_id: the ID of the loop node :param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop :return: set of node IDs that belong to the loop
""" """
loop_node_ids = set() loop_node_ids: set[str] = set()
# Find all nodes that belong to this loop # Find all nodes that belong to this loop
nodes = graph_config.get("nodes", []) raw_nodes = graph_config.get("nodes")
for node in nodes: if not isinstance(raw_nodes, list):
node_data = node.get("data", {}) return loop_node_ids
for node in raw_nodes:
if not isinstance(node, dict):
continue
node_data = node.get("data")
if not isinstance(node_data, dict):
continue
if node_data.get("loop_id") == loop_node_id: if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id") node_id = node.get("id")
if node_id: if node_id:
@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return loop_node_ids return loop_node_ids
@staticmethod @staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
"""Get the appropriate segment type for a constant value.""" """Get the appropriate segment type for a constant value."""
# TODO: Refactor for maintainability: # TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py) # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
@ -389,11 +406,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
SegmentType.ARRAY_OBJECT, SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING, SegmentType.ARRAY_STRING,
]: ]:
if original_value and isinstance(original_value, str): # New typed payloads may already provide native lists, while legacy
value = json.loads(original_value) # configs still serialize array constants as JSON strings.
else: if isinstance(original_value, str):
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type) value = json.loads(original_value) if original_value else []
elif original_value is None:
# Preserve legacy behavior: treat missing/empty array constants as [].
value = [] value = []
else:
value = original_value
else: else:
raise AssertionError("this statement should be unreachable.") raise AssertionError("this statement should be unreachable.")
try: try:

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal from typing import Annotated, Literal
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
@ -6,6 +6,7 @@ from pydantic import (
Field, Field,
field_validator, field_validator,
) )
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.base_node_data import BaseNodeData
@ -55,7 +56,7 @@ class ParameterConfig(BaseModel):
@field_validator("name", mode="before") @field_validator("name", mode="before")
@classmethod @classmethod
def validate_name(cls, value) -> str: def validate_name(cls, value: object) -> str:
if not value: if not value:
raise ValueError("Parameter name is required") raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}: if value in {"__reason", "__is_success"}:
@ -79,6 +80,23 @@ class ParameterConfig(BaseModel):
return element_type return element_type
class JsonSchemaArrayItems(TypedDict):
type: str
class ParameterJsonSchemaProperty(TypedDict, total=False):
description: str
type: str
items: JsonSchemaArrayItems
enum: list[str]
class ParameterJsonSchema(TypedDict):
type: Literal["object"]
properties: dict[str, ParameterJsonSchemaProperty]
required: list[str]
class ParameterExtractorNodeData(BaseNodeData): class ParameterExtractorNodeData(BaseNodeData):
""" """
Parameter Extractor Node Data. Parameter Extractor Node Data.
@ -95,19 +113,19 @@ class ParameterExtractorNodeData(BaseNodeData):
@field_validator("reasoning_mode", mode="before") @field_validator("reasoning_mode", mode="before")
@classmethod @classmethod
def set_reasoning_mode(cls, v) -> str: def set_reasoning_mode(cls, v: object) -> str:
return v or "function_call" return str(v) if v else "function_call"
def get_parameter_json_schema(self): def get_parameter_json_schema(self) -> ParameterJsonSchema:
""" """
Get parameter json schema. Get parameter json schema.
:return: parameter json schema :return: parameter json schema
""" """
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters: for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description} parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
if parameter.type == SegmentType.STRING: if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string" parameter_schema["type"] = "string"
@ -118,7 +136,7 @@ class ParameterExtractorNodeData(BaseNodeData):
raise AssertionError("element type should not be None.") raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value} parameter_schema["items"] = {"type": element_type.value}
else: else:
parameter_schema["type"] = parameter.type parameter_schema["type"] = parameter.type.value
if parameter.options: if parameter.options:
parameter_schema["enum"] = parameter.options parameter_schema["enum"] = parameter.options

View File

@ -5,6 +5,8 @@ import uuid
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from pydantic import TypeAdapter
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@ -63,6 +65,7 @@ from .prompts import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
if TYPE_CHECKING: if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams from dify_graph.entities import GraphInitParams
@ -70,7 +73,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState from dify_graph.runtime import GraphRuntimeState
def extract_json(text): def extract_json(text: str) -> str | None:
""" """
From a given JSON started from '{' or '[' extract the complete JSON object. From a given JSON started from '{' or '[' extract the complete JSON object.
""" """
@ -396,10 +399,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
) )
# generate tool # generate tool
parameter_schema = node_data.get_parameter_json_schema()
tool = PromptMessageTool( tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME, name=FUNCTION_CALLING_EXTRACTOR_NAME,
description="Extract parameters from the natural language text", description="Extract parameters from the natural language text",
parameters=node_data.get_parameter_json_schema(), parameters={
"type": parameter_schema["type"],
"properties": dict(parameter_schema["properties"]),
"required": list(parameter_schema["required"]),
},
) )
return prompt_messages, [tool] return prompt_messages, [tool]
@ -606,19 +614,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else: else:
return None return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict): def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
""" """
Transform result into standard format. Transform result into standard format.
""" """
transformed_result: dict[str, Any] = {} transformed_result: dict[str, object] = {}
for parameter in data.parameters: for parameter in data.parameters:
if parameter.name in result: if parameter.name in result:
param_value = result[parameter.name] param_value = result[parameter.name]
# transform value # transform value
if parameter.type == SegmentType.NUMBER: if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value) if isinstance(param_value, (bool, int, float, str)):
if transformed is not None: numeric_value: bool | int | float | str = param_value
transformed_result[parameter.name] = transformed transformed = self._transform_number(numeric_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN: elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)): if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name]) transformed_result[parameter.name] = bool(result[parameter.name])
@ -665,7 +675,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return transformed_result return transformed_result
def _extract_complete_json_response(self, result: str) -> dict | None: def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
""" """
Extract complete json response. Extract complete json response.
""" """
@ -676,11 +686,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:]) json_str = extract_json(result[idx:])
if json_str: if json_str:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str)) return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result) logger.info("extra error: %s", result)
return None return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
""" """
Extract json from tool call. Extract json from tool call.
""" """
@ -694,16 +704,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:]) json_str = extract_json(result[idx:])
if json_str: if json_str:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str)) return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result) logger.info("extra error: %s", result)
return None return None
def _generate_default_result(self, data: ParameterExtractorNodeData): def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
""" """
Generate default result. Generate default result.
""" """
result: dict[str, Any] = {} result: dict[str, object] = {}
for parameter in data.parameters: for parameter in data.parameters:
if parameter.type == "number": if parameter.type == "number":
result[parameter.name] = 0 result[parameter.name] = 0

View File

@ -1,12 +1,66 @@
from typing import Any, Literal, Union from __future__ import annotations
from pydantic import BaseModel, field_validator from typing import Literal, TypeAlias, cast
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from typing_extensions import TypedDict
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.enums import BuiltinNodeTypes, NodeType
ToolConfigurationValue: TypeAlias = str | int | float | bool
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
VariableSelector: TypeAlias = list[str]
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class WorkflowToolInputValue(TypedDict):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
class ToolInputPayload(BaseModel):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
if isinstance(value, (str, int, float, bool)):
return cast(ToolConfigurationEntry, value)
if isinstance(value, dict):
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
class ToolEntity(BaseModel): class ToolEntity(BaseModel):
provider_id: str provider_id: str
@ -14,52 +68,29 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy provider_name: str # redundancy
tool_name: str tool_name: str
tool_label: str # redundancy tool_label: str # redundancy
tool_configurations: dict[str, Any] tool_configurations: ToolConfigurations
credential_id: str | None = None credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before") @field_validator("tool_configurations", mode="before")
@classmethod @classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo): def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
if not isinstance(value, dict): if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary") raise TypeError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}): normalized: ToolConfigurations = {}
value = values.data.get("tool_configurations", {}).get(key) for key, item in value.items():
if not isinstance(value, str | int | float | bool): if not isinstance(key, str):
raise ValueError(f"{key} must be a string") raise TypeError("tool_configurations keys must be strings")
normalized[key] = _validate_tool_configuration_entry(item)
return value return normalized
class ToolNodeData(BaseNodeData, ToolEntity): class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(BaseModel): class ToolInput(ToolInputPayload):
# TODO: check this type pass
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if value is None:
return typ
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ
tool_parameters: dict[str, ToolInput] tool_parameters: dict[str, ToolInput]
# The version of the tool parameter. # The version of the tool parameter.
@ -69,7 +100,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
@field_validator("tool_parameters", mode="before") @field_validator("tool_parameters", mode="before")
@classmethod @classmethod
def filter_none_tool_inputs(cls, value): def filter_none_tool_inputs(cls, value: object) -> object:
if not isinstance(value, dict): if not isinstance(value, dict):
return value return value
@ -80,8 +111,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
} }
@staticmethod @staticmethod
def _has_valid_value(tool_input): def _has_valid_value(tool_input: object) -> bool:
"""Check if the value is valid""" """Check if the value is valid"""
if isinstance(tool_input, dict): if isinstance(tool_input, dict):
return tool_input.get("value") is not None return tool_input.get("value") is not None
return getattr(tool_input, "value", None) is not None if isinstance(tool_input, ToolNodeData.ToolInput):
return tool_input.value is not None
return False

View File

@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
continue continue
tool_input = node_data.tool_parameters[parameter_name] tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable": if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value) variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None: if variable is None:
if parameter.required: if parameter.required:
raise ToolParameterError(f"Variable {tool_input.value} does not exist") raise ToolParameterError(f"Variable {variable_selector} does not exist")
continue continue
parameter_value = variable.value parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}: elif tool_input.type in {"mixed", "constant"}:
@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors: for selector in selectors:
result[selector.variable] = selector.value_selector result[selector.variable] = selector.value_selector
case "variable": case "variable":
selector_key = ".".join(input.value) variable_selector = input.require_variable_selector()
result[f"#{selector_key}#"] = input.value selector_key = ".".join(variable_selector)
result[f"#{selector_key}#"] = variable_selector
case "constant": case "constant":
pass pass

View File

@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from dify_graph.variables import SegmentType, VariableBase from dify_graph.variables import Segment, SegmentType, VariableBase
from .node_data import VariableAssignerData, WriteMode from .node_data import VariableAssignerData, WriteMode
@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
if not isinstance(original_variable, VariableBase): if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
income_value: Segment
updated_variable: VariableBase
match self.node_data.write_mode: match self.node_data.write_mode:
case WriteMode.OVER_WRITE: case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value: if input_value is None:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_variable = original_variable.model_copy(update={"value": income_value.value}) updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND: case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value: if input_value is None:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_value = original_variable.value + [income_value.value] updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value}) updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR: case WriteMode.CLEAR:
income_value = SegmentType.get_zero_value(original_variable.value_type) income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable. # Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@ -66,6 +66,11 @@ class GraphExecutionProtocol(Protocol):
exceptions_count: int exceptions_count: int
pause_reasons: list[PauseReason] pause_reasons: list[PauseReason]
@property
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
"""Return node execution state keyed by node id for resume support."""
...
def start(self) -> None: def start(self) -> None:
"""Transition execution into the running state.""" """Transition execution into the running state."""
... ...
@ -91,6 +96,12 @@ class GraphExecutionProtocol(Protocol):
... ...
class NodeExecutionProtocol(Protocol):
"""Structural interface for per-node execution state used during resume."""
execution_id: str | None
class ResponseStreamCoordinatorProtocol(Protocol): class ResponseStreamCoordinatorProtocol(Protocol):
"""Structural interface for response stream coordinator.""" """Structural interface for response stream coordinator."""

View File

@ -13,21 +13,6 @@ controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py controllers/service_api/app/annotation.py
controllers/web/workflow_events.py controllers/web/workflow_events.py
core/agent/fc_agent_runner.py core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py core/datasource/datasource_manager.py
core/external_data_tool/api/api.py core/external_data_tool/api/api.py
@ -108,35 +93,6 @@ core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py core/trigger/entities/entities.py
core/trigger/provider.py core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py extensions/otel/instrumentation.py
extensions/otel/runtime.py extensions/otel/runtime.py

View File

@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals:
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) refreshed = _refresh_model(session=None, model=source_model)
assert refreshed is detached_model assert refreshed is detached_model

View File

@ -0,0 +1,110 @@
from collections.abc import Iterator
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse
from core.errors.error import QuotaExceededError
class DummyResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = AppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
return {"mode": "blocking-full", "task_id": blocking_response.task_id}
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
return {"mode": "blocking-simple", "task_id": blocking_response.task_id}
@classmethod
def convert_stream_full_response(cls, stream_response: Iterator[object]):
for _ in stream_response:
yield {"mode": "stream-full"}
@classmethod
def convert_stream_simple_response(cls, stream_response: Iterator[object]):
for _ in stream_response:
yield {"mode": "stream-simple"}
def test_convert_routes_to_full_or_simple_modes() -> None:
blocking = AppBlockingResponse(task_id="task-1")
assert DummyResponseConverter.convert(blocking, InvokeFrom.DEBUGGER) == {
"mode": "blocking-full",
"task_id": "task-1",
}
assert DummyResponseConverter.convert(blocking, InvokeFrom.WEB_APP) == {
"mode": "blocking-simple",
"task_id": "task-1",
}
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.SERVICE_API)) == [{"mode": "stream-full"}]
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.WEB_APP)) == [{"mode": "stream-simple"}]
def test_get_simple_metadata_preserves_new_retriever_fields() -> None:
metadata = {
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset",
"document_id": "document-1",
"segment_id": "segment-1",
"position": 1,
"data_source_type": "upload_file",
"document_name": "Document",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "hash",
"content": "content",
"page": 5,
"title": "Title",
"files": [{"id": "file-1"}],
"summary": "summary",
}
],
"annotation_reply": "hidden",
"usage": {"latency": 0.1},
}
result = DummyResponseConverter._get_simple_metadata(metadata)
assert result == {
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset",
"document_id": "document-1",
"segment_id": "segment-1",
"position": 1,
"data_source_type": "upload_file",
"document_name": "Document",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "hash",
"content": "content",
"page": 5,
"title": "Title",
"files": [{"id": "file-1"}],
"summary": "summary",
}
]
}
def test_error_to_stream_response_uses_specific_and_fallback_mappings() -> None:
quota_response = DummyResponseConverter._error_to_stream_response(QuotaExceededError())
fallback_response = DummyResponseConverter._error_to_stream_response(RuntimeError("boom"))
assert quota_response["code"] == "provider_quota_exceeded"
assert quota_response["status"] == 400
assert fallback_response == {
"code": "internal_server_error",
"message": "Internal Server Error, please contact support.",
"status": 500,
}

View File

@ -33,6 +33,79 @@ from dify_graph.system_variable import SystemVariable
class TestWorkflowBasedAppRunner: class TestWorkflowBasedAppRunner:
def test_get_graph_items_rejects_non_mapping_entries(self):
with pytest.raises(ValueError, match="nodes in workflow graph must be mappings"):
WorkflowBasedAppRunner._get_graph_items({"nodes": ["bad"], "edges": []})
with pytest.raises(ValueError, match="edges in workflow graph must be mappings"):
WorkflowBasedAppRunner._get_graph_items({"nodes": [], "edges": ["bad"]})
def test_extract_start_node_id_handles_missing_and_invalid_values(self):
assert WorkflowBasedAppRunner._extract_start_node_id(None) is None
assert WorkflowBasedAppRunner._extract_start_node_id({"data": "invalid"}) is None
assert WorkflowBasedAppRunner._extract_start_node_id({"data": {"start_node_id": 123}}) is None
assert WorkflowBasedAppRunner._extract_start_node_id({"data": {"start_node_id": "start-node"}}) == "start-node"
def test_build_single_node_graph_config_keeps_target_related_and_start_nodes(self):
graph_config, target_node_config = WorkflowBasedAppRunner._build_single_node_graph_config(
graph_config={
"nodes": [
{"id": "start-node", "data": {"type": "start", "version": "1"}},
{
"id": "loop-node",
"data": {"type": "loop", "version": "1", "start_node_id": "start-node"},
},
{
"id": "loop-child",
"data": {"type": "answer", "version": "1", "loop_id": "loop-node"},
},
{"id": "outside-node", "data": {"type": "answer", "version": "1"}},
],
"edges": [
{"source": "start-node", "target": "loop-node"},
{"source": "loop-node", "target": "loop-child"},
{"source": "loop-node", "target": "outside-node"},
],
},
node_id="loop-node",
node_type_filter_key="loop_id",
)
assert [node["id"] for node in graph_config["nodes"]] == ["start-node", "loop-node", "loop-child"]
assert graph_config["edges"] == [
{"source": "start-node", "target": "loop-node"},
{"source": "loop-node", "target": "loop-child"},
]
assert target_node_config["id"] == "loop-node"
def test_build_agent_strategy_info_validates_payload(self):
event = NodeRunStartedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
node_title="Start",
start_at=datetime.utcnow(),
extras={"agent_strategy": {"name": "planner", "icon": "robot"}},
)
strategy = WorkflowBasedAppRunner._build_agent_strategy_info(event)
assert strategy is not None
assert strategy.name == "planner"
assert strategy.icon == "robot"
def test_build_agent_strategy_info_returns_none_for_invalid_payload(self):
event = NodeRunStartedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
node_title="Start",
start_at=datetime.utcnow(),
extras={"agent_strategy": {"name": "planner", "extra": "ignored"}},
)
assert WorkflowBasedAppRunner._build_agent_strategy_info(event) is None
def test_resolve_user_from(self): def test_resolve_user_from(self):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
@ -174,6 +247,34 @@ class TestWorkflowBasedAppRunner:
assert paused_event.paused_nodes == ["node-1"] assert paused_event.paused_nodes == ["node-1"]
assert emails assert emails
def test_enqueue_human_input_notifications_skips_invalid_reasons_and_logs_failures(self, monkeypatch):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
seen_calls: list[tuple[dict[str, object], str]] = []
class _Dispatch:
def apply_async(self, *, kwargs, queue):
seen_calls.append((kwargs, queue))
raise RuntimeError("boom")
logged: list[str] = []
monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", _Dispatch())
monkeypatch.setattr(
"core.app.apps.workflow_app_runner.logger",
SimpleNamespace(exception=lambda message, form_id: logged.append(f"{message}:{form_id}")),
)
runner._enqueue_human_input_notifications(
[
object(),
HumanInputRequired(form_id="", form_content="content", node_id="node", node_title="Node"),
HumanInputRequired(form_id="form-1", form_content="content", node_id="node", node_title="Node"),
]
)
assert seen_calls == [({"form_id": "form-1", "node_title": "Node"}, "mail")]
assert logged == ["Failed to enqueue human input email task for form %s:form-1"]
def test_handle_node_events_publishes_queue_events(self): def test_handle_node_events_publishes_queue_events(self):
published: list[object] = [] published: list[object] = []

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from copy import deepcopy
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -33,8 +33,8 @@ def _make_graph_state():
], ],
) )
def test_run_uses_single_node_execution_branch( def test_run_uses_single_node_execution_branch(
single_iteration_run: Any, single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
single_loop_run: Any, single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
) -> None: ) -> None:
app_config = MagicMock() app_config = MagicMock()
app_config.app_id = "app" app_config.app_id = "app"
@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
"break_conditions": [], "break_conditions": [],
"logical_operator": "and", "logical_operator": "and",
}, },
},
{
"id": "other-node",
"data": {
"type": "answer",
"title": "Answer",
},
},
],
"edges": [
{
"source": "other-node",
"target": "loop-node",
} }
], ],
"edges": [],
} }
original_graph_dict = deepcopy(workflow.graph_dict)
_, _, graph_runtime_state = _make_graph_state() _, _, graph_runtime_state = _make_graph_state()
seen_configs: list[object] = [] seen_configs: list[object] = []
@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
seen_configs.append(value) seen_configs.append(value)
return original_validate_python(value) return original_validate_python(value)
class FakeNodeClass:
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
with ( with (
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"), patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()), patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"), patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"), patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
): ):
runner._get_graph_and_variable_pool_for_single_node_run( runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow, workflow=workflow,
@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
) )
assert seen_configs == [workflow.graph_dict["nodes"][0]] assert seen_configs == [workflow.graph_dict["nodes"][0]]
assert workflow.graph_dict == original_graph_dict
graph_config = graph_init.call_args.kwargs["graph_config"]
assert graph_config is not workflow.graph_dict
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
assert graph_config["edges"] == []

View File

@ -0,0 +1,46 @@
import pytest
from pydantic import ValidationError
from core.workflow.nodes.agent.entities import AgentNodeData
def test_agent_input_accepts_variable_selector_and_mixed_values() -> None:
node_data = AgentNodeData.model_validate(
{
"title": "Agent",
"agent_strategy_provider_name": "provider",
"agent_strategy_name": "strategy",
"agent_strategy_label": "Strategy",
"agent_parameters": {
"query": {"type": "variable", "value": ["start", "query"]},
"tools": {"type": "mixed", "value": [{"provider": "builtin", "name": "search"}]},
},
}
)
assert node_data.agent_parameters["query"].value == ["start", "query"]
assert node_data.agent_parameters["tools"].value == [{"provider": "builtin", "name": "search"}]
def test_agent_input_rejects_invalid_variable_selector_and_unknown_type() -> None:
with pytest.raises(ValidationError):
AgentNodeData.model_validate(
{
"title": "Agent",
"agent_strategy_provider_name": "provider",
"agent_strategy_name": "strategy",
"agent_strategy_label": "Strategy",
"agent_parameters": {"query": {"type": "variable", "value": "start.query"}},
}
)
with pytest.raises(ValidationError, match="Unknown agent input type"):
AgentNodeData.model_validate(
{
"title": "Agent",
"agent_strategy_provider_name": "provider",
"agent_strategy_name": "strategy",
"agent_strategy_label": "Strategy",
"agent_parameters": {"query": {"type": "unsupported", "value": "hello"}},
}
)

View File

@ -0,0 +1,125 @@
from types import SimpleNamespace
import pytest
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.agent.exceptions import AgentVariableNotFoundError
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
def test_filter_mcp_type_tool_depends_on_strategy_meta_version() -> None:
runtime_support = AgentRuntimeSupport()
tools = [
{"type": ToolProviderType.BUILT_IN, "tool_name": "search"},
{"type": ToolProviderType.MCP, "tool_name": "mcp-tool"},
]
filtered_tools = runtime_support._filter_mcp_type_tool(SimpleNamespace(meta_version="0.0.1"), tools)
preserved_tools = runtime_support._filter_mcp_type_tool(SimpleNamespace(meta_version="0.0.2"), tools)
assert filtered_tools == [{"type": ToolProviderType.BUILT_IN, "tool_name": "search"}]
assert preserved_tools == tools
def test_normalize_tool_payloads_keeps_enabled_tools_and_resolves_values() -> None:
runtime_support = AgentRuntimeSupport()
variable_pool = SimpleNamespace(get=lambda selector: SimpleNamespace(value=f"resolved:{'.'.join(selector)}"))
normalized_tools = runtime_support._normalize_tool_payloads(
strategy=SimpleNamespace(meta_version="0.0.2"),
tools=[
{
"enabled": True,
"tool_name": "search",
"schemas": {"ignored": True},
"parameters": {
"query": {
"auto": 0,
"value": {"type": "variable", "value": ["start", "query"]},
},
"top_k": {
"auto": 0,
"value": {"type": "constant", "value": 3},
},
"optional": {"auto": 1, "value": {"type": "constant", "value": "skip"}},
},
"settings": {
"region": {"value": "us"},
"safe": {"value": True},
},
},
{"enabled": False, "tool_name": "disabled"},
],
variable_pool=variable_pool,
)
assert normalized_tools == [
{
"enabled": True,
"tool_name": "search",
"parameters": {"query": "resolved:start.query", "top_k": 3, "optional": None},
"settings": {"region": "us", "safe": True},
}
]
def test_resolve_tool_parameters_raises_for_missing_variable() -> None:
runtime_support = AgentRuntimeSupport()
variable_pool = SimpleNamespace(get=lambda _selector: None)
with pytest.raises(AgentVariableNotFoundError, match=r"\['start', 'query'\]"):
runtime_support._resolve_tool_parameters(
tool={
"parameters": {
"query": {
"auto": 0,
"value": {"type": "variable", "value": ["start", "query"]},
}
}
},
variable_pool=variable_pool,
)
def test_build_credentials_collects_valid_tool_credentials_only() -> None:
runtime_support = AgentRuntimeSupport()
credentials = runtime_support.build_credentials(
parameters={
"tools": [
{
"credential_id": "cred-1",
"identity": {
"author": "author",
"name": "tool",
"label": {"en_US": "Tool"},
"provider": "provider-a",
},
},
{
"credential_id": "cred-2",
"identity": {"author": "author"},
},
{
"credential_id": None,
"identity": {
"author": "author",
"name": "tool",
"label": {"en_US": "Tool"},
"provider": "provider-b",
},
},
"invalid",
]
}
)
assert credentials.tool_credentials == {"provider-a": "cred-1"}
def test_coerce_named_json_objects_requires_string_keys_and_json_object_values() -> None:
runtime_support = AgentRuntimeSupport()
assert runtime_support._coerce_named_json_objects({"valid": {"value": 1}}) == {"valid": {"value": 1}}
assert runtime_support._coerce_named_json_objects({1: {"value": 1}}) is None
assert runtime_support._coerce_named_json_objects({"invalid": object()}) is None

View File

@ -13,7 +13,9 @@ from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities import GraphInitParams from dify_graph.entities import GraphInitParams
from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities import LLMMode
from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultWithStructuredOutput, LLMUsage
from dify_graph.model_runtime.entities.message_entities import ( from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent, ImagePromptMessageContent,
@ -55,6 +57,118 @@ class MockTokenBufferMemory:
return self.history_messages return self.history_messages
def test_llm_node_data_normalizes_optional_configs_and_legacy_structured_output() -> None:
node_data = LLMNodeData.model_validate(
{
"title": "Test LLM",
"model": {"provider": "openai", "name": "gpt-4o-mini", "mode": LLMMode.CHAT, "completion_params": {}},
"prompt_template": [],
"prompt_config": None,
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": True, "configs": None},
"structured_output": {
"schema": {"type": "object"},
"name": "Response",
"description": "Structured",
},
"structured_output_enabled": True,
}
)
assert node_data.prompt_config.jinja2_variables == []
assert node_data.vision.configs.variable_selector == ["sys", "files"]
assert node_data.structured_output == {
"schema": {"type": "object"},
"name": "Response",
"description": "Structured",
}
assert node_data.structured_output_enabled is True
def test_llm_node_data_discards_legacy_structured_output_without_schema() -> None:
node_data = LLMNodeData.model_validate(
{
"title": "Test LLM",
"model": {"provider": "openai", "name": "gpt-4o-mini", "mode": LLMMode.CHAT, "completion_params": {}},
"prompt_template": [],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"structured_output": {"name": "Missing schema"},
"structured_output_enabled": True,
}
)
assert node_data.structured_output is None
assert node_data.structured_output_enabled is False
def test_prompt_config_converts_none_jinja_variables() -> None:
prompt_config = LLMNodeData.model_validate(
{
"title": "Test LLM",
"model": {"provider": "openai", "name": "gpt-4o-mini", "mode": LLMMode.CHAT, "completion_params": {}},
"prompt_template": [],
"prompt_config": None,
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"structured_output_enabled": False,
}
).prompt_config
assert prompt_config.jinja2_variables == []
def test_fetch_structured_output_schema_validates_required_object_shape() -> None:
assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object", "a": 1}}) == {
"type": "object",
"a": 1,
}
with pytest.raises(Exception, match="valid structured output schema"):
LLMNode.fetch_structured_output_schema(structured_output={"schema": None})
def test_handle_blocking_result_separates_reasoning_and_structured_output() -> None:
saver = mock.MagicMock(spec=LLMFileSaver)
event = LLMNode.handle_blocking_result(
invoke_result=LLMResultWithStructuredOutput(
model="gpt",
message=AssistantPromptMessage(content="<think>reasoning</think>answer"),
usage=LLMUsage.empty_usage(),
structured_output={"answer": "done"},
),
saver=saver,
file_outputs=[],
reasoning_format="separated",
request_latency=1.2345,
)
assert event.text == "answer"
assert event.reasoning_content == "reasoning"
assert event.structured_output == {"answer": "done"}
assert event.usage.latency == 1.234
def test_handle_blocking_result_keeps_tagged_text_without_structured_output() -> None:
saver = mock.MagicMock(spec=LLMFileSaver)
event = LLMNode.handle_blocking_result(
invoke_result=LLMResult(
model="gpt",
message=AssistantPromptMessage(content="plain text"),
usage=LLMUsage.empty_usage(),
),
saver=saver,
file_outputs=[],
)
assert event.text == "plain text"
assert event.reasoning_content == ""
assert event.structured_output is None
@pytest.fixture @pytest.fixture
def llm_node_data() -> LLMNodeData: def llm_node_data() -> LLMNodeData:
return LLMNodeData( return LLMNodeData(

View File

@ -1,6 +1,12 @@
from types import SimpleNamespace
import pytest
from pydantic import ValidationError
from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.nodes.loop.entities import LoopNodeData from dify_graph.nodes.loop.entities import LoopNodeData, LoopValue
from dify_graph.nodes.loop.loop_node import LoopNode from dify_graph.nodes.loop.loop_node import LoopNode
from dify_graph.variables.types import SegmentType
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
@ -50,3 +56,104 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf
) )
assert seen_configs == [child_node_config] assert seen_configs == [child_node_config]
@pytest.mark.parametrize(
    ("var_type", "given", "expected"),
    [
        (SegmentType.ARRAY_STRING, ["alpha", "beta"], ["alpha", "beta"]),
        (SegmentType.ARRAY_NUMBER, [1, 2.5], [1, 2.5]),
        (SegmentType.ARRAY_OBJECT, [{"name": "item"}], [{"name": "item"}]),
        # Legacy path: a JSON-encoded string is decoded into the array value.
        (SegmentType.ARRAY_STRING, '["legacy", "json"]', ["legacy", "json"]),
    ],
)
def test_get_segment_for_constant_accepts_native_array_values(
    var_type: SegmentType, given: LoopValue, expected: LoopValue
) -> None:
    """Constant loop values keep their declared segment type and payload."""
    segment = LoopNode._get_segment_for_constant(var_type, given)
    assert (segment.value_type, segment.value) == (var_type, expected)
def test_loop_variable_data_validates_variable_selector_and_constant_value() -> None:
    """Variable entries expose selectors; constant entries expose raw values."""
    selector_entry = {
        "label": "question",
        "var_type": SegmentType.STRING,
        "value_type": "variable",
        "value": ["start", "question"],
    }
    constant_entry = {
        "label": "payload",
        "var_type": SegmentType.OBJECT,
        "value_type": "constant",
        "value": {"count": 1, "items": ["a", 2]},
    }
    node_data = LoopNodeData(
        title="Loop",
        loop_count=1,
        break_conditions=[],
        logical_operator="and",
        loop_variables=[selector_entry, constant_entry],
    )
    first, second = node_data.loop_variables
    assert first.require_variable_selector() == ["start", "question"]
    assert second.require_constant_value() == {"count": 1, "items": ["a", 2]}
def test_loop_variable_data_rejects_missing_variable_selector() -> None:
    """A variable-typed loop input with a null selector fails model validation."""
    bad_variable = {
        "label": "question",
        "var_type": SegmentType.STRING,
        "value_type": "variable",
        "value": None,
    }
    with pytest.raises(ValidationError, match="Variable loop inputs require a selector"):
        LoopNodeData(
            title="Loop",
            loop_count=1,
            break_conditions=[],
            logical_operator="and",
            loop_variables=[bad_variable],
        )
def test_loop_node_data_outputs_default_to_empty_mapping_for_none() -> None:
    """Passing outputs=None normalizes to an empty mapping."""
    kwargs = {
        "title": "Loop",
        "loop_count": 1,
        "break_conditions": [],
        "logical_operator": "and",
        "outputs": None,
    }
    assert LoopNodeData(**kwargs).outputs == {}
def test_append_loop_info_to_event_preserves_existing_loop_metadata() -> None:
    """An event that already carries a loop_id keeps its original metadata."""
    loop_node = object.__new__(LoopNode)  # bypass __init__; only _node_id is read
    loop_node._node_id = "loop-node"
    event = SimpleNamespace(
        node_run_result=SimpleNamespace(metadata={"loop_id": "existing-loop", "other": "value"}),
        in_loop_id=None,
    )
    loop_node._append_loop_info_to_event(event=event, loop_run_index=2)
    # in_loop_id is stamped, but the pre-existing loop_id is not overwritten.
    assert event.in_loop_id == "loop-node"
    assert event.node_run_result.metadata == {"loop_id": "existing-loop", "other": "value"}
def test_clear_loop_subgraph_variables_removes_each_loop_node() -> None:
    """Every node id inside the loop has its variables removed from the pool."""
    loop_node = object.__new__(LoopNode)
    removed: list[list[str]] = []
    fake_pool = SimpleNamespace(remove=removed.append)
    loop_node.graph_runtime_state = SimpleNamespace(variable_pool=fake_pool)
    loop_node._clear_loop_subgraph_variables({"child-a", "child-b"})
    # Set iteration order is unstable, so compare sorted selectors.
    assert sorted(removed) == [["child-a"], ["child-b"]]

View File

@ -8,11 +8,13 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
from dify_graph.nodes.tool.entities import ToolEntity as WorkflowToolEntity
from dify_graph.nodes.tool.entities import ToolNodeData
from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable from dify_graph.system_variable import SystemVariable
from dify_graph.variables.segments import ArrayFileSegment from dify_graph.variables.segments import ArrayFileSegment
@ -167,3 +169,119 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):
files_segment = completed_events[0].node_run_result.outputs["files"] files_segment = completed_events[0].node_run_result.outputs["files"]
assert isinstance(files_segment, ArrayFileSegment) assert isinstance(files_segment, ArrayFileSegment)
assert files_segment.value == [] assert files_segment.value == []
def test_workflow_tool_entity_accepts_primitives_and_tool_input_payloads() -> None:
    """Tool configurations may mix plain primitives with typed tool-input dicts."""
    configurations = {
        "timeout": 30,
        "query": {"type": "mixed", "value": "hello {{name}}"},
        "selector": {"type": "variable", "value": ["start", "question"]},
    }
    entity = WorkflowToolEntity(
        provider_id="provider",
        provider_type="builtin",
        provider_name="provider",
        tool_name="search",
        tool_label="Search",
        tool_configurations=configurations,
    )
    # Validation keeps the payloads intact rather than coercing them.
    assert entity.tool_configurations == {
        "timeout": 30,
        "query": {"type": "mixed", "value": "hello {{name}}"},
        "selector": {"type": "variable", "value": ["start", "question"]},
    }
def test_workflow_tool_entity_rejects_invalid_configuration_entries() -> None:
    """Non-primitive configuration values (e.g. arbitrary objects) are rejected."""
    invalid_configurations = {"bad": [object()]}
    with pytest.raises(TypeError, match="Tool configuration values must be primitives"):
        WorkflowToolEntity(
            provider_id="provider",
            provider_type="builtin",
            provider_name="provider",
            tool_name="search",
            tool_label="Search",
            tool_configurations=invalid_configurations,
        )
def test_tool_node_data_filters_missing_tool_parameter_values() -> None:
    """Parameters whose value is absent (None entry or constant None) are dropped."""
    parameters = {
        "query": {"type": "mixed", "value": "hello"},
        "skip_none": None,
        "skip_empty": {"type": "constant", "value": None},
    }
    node_data = ToolNodeData(
        title="Tool",
        provider_id="provider",
        provider_type="builtin",
        provider_name="provider",
        tool_name="search",
        tool_label="Search",
        tool_configurations={},
        tool_parameters=parameters,
    )
    # Only the parameter carrying a real value survives validation.
    assert set(node_data.tool_parameters.keys()) == {"query"}
def test_generate_parameters_reads_variables_and_optional_missing_inputs(tool_node: ToolNode) -> None:
    """Required variables resolve from the pool; missing optional ones are omitted."""
    pool = MagicMock()
    # First lookup (query) resolves; second lookup (optional) yields nothing.
    pool.get.side_effect = [MagicMock(value="from-variable"), None]
    node_data = ToolNodeData.model_validate(
        {
            "title": "Tool",
            "provider_id": "provider",
            "provider_type": "builtin",
            "provider_name": "provider",
            "tool_name": "tool",
            "tool_label": "tool",
            "tool_configurations": {},
            "tool_parameters": {
                "query": {"type": "variable", "value": ["start", "query"]},
                "optional": {"type": "variable", "value": ["start", "optional"]},
            },
        }
    )
    declared_parameters = [
        ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
        ToolParameter.get_simple_instance("optional", "optional", ToolParameter.ToolParameterType.STRING, False),
    ]
    generated = tool_node._generate_parameters(
        tool_parameters=declared_parameters,
        variable_pool=pool,
        node_data=node_data,
    )
    assert generated == {"query": "from-variable"}
def test_generate_parameters_formats_logs_and_unknown_parameters(tool_node: ToolNode) -> None:
    """In log mode, templates render their masked log form and unknown params log as None."""
    pool = MagicMock()
    pool.convert_template.return_value = MagicMock(text="rendered", log="masked")
    node_data = ToolNodeData.model_validate(
        {
            "title": "Tool",
            "provider_id": "provider",
            "provider_type": "builtin",
            "provider_name": "provider",
            "tool_name": "tool",
            "tool_label": "tool",
            "tool_configurations": {},
            "tool_parameters": {
                "query": {"type": "mixed", "value": "{{ question }}"},
                "missing": {"type": "constant", "value": "literal"},
            },
        }
    )
    declared_parameters = [
        ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
    ]
    generated = tool_node._generate_parameters(
        tool_parameters=declared_parameters,
        variable_pool=pool,
        node_data=node_data,
        for_log=True,
    )
    # "query" uses the masked log text; "missing" is not a declared parameter.
    assert generated == {"query": "masked", "missing": None}

View File

@ -97,6 +97,22 @@ class TestWorkflowChildEngineBuilder:
((sentinel.layer_two,), {}), ((sentinel.layer_two,), {}),
] ]
def test_build_child_engine_tolerates_invalid_graph_shape_until_graph_init(self):
    """Graph shape problems surface only when Graph.init runs, not earlier."""
    builder = workflow_entry._WorkflowChildEngineBuilder()
    factory_patch = patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory)
    init_patch = patch.object(workflow_entry.Graph, "init", side_effect=ValueError("invalid graph"))
    with factory_patch, init_patch:
        with pytest.raises(ValueError, match="invalid graph"):
            builder.build_child_engine(
                workflow_id="workflow-id",
                graph_init_params=sentinel.graph_init_params,
                graph_runtime_state=sentinel.graph_runtime_state,
                graph_config={"nodes": "invalid"},
                root_node_id="root",
            )
class TestWorkflowEntryInit: class TestWorkflowEntryInit:
def test_rejects_call_depth_above_limit(self): def test_rejects_call_depth_above_limit(self):