From cd688a0d8fa864db3866dc1f76ebf8865abfbd22 Mon Sep 17 00:00:00 2001 From: Novice Date: Wed, 28 Jan 2026 10:18:10 +0800 Subject: [PATCH] fix: nested node single step run --- api/core/workflow/nodes/llm/node.py | 26 +++----- api/core/workflow/nodes/tool/tool_node.py | 81 +++++++++++++++++++++-- api/factories/variable_factory.py | 35 +++++++++- api/services/workflow_service.py | 21 +++--- 4 files changed, 132 insertions(+), 31 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 924be1a2e2..5e33830ff9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1402,20 +1402,19 @@ class LLMNode(Node[LLMNodeData]): # Create typed NodeData from dict typed_node_data = LLMNodeData.model_validate(node_data) - prompt_template: (Sequence[LLMNodeChatModelMessage | PromptMessageContext] | - LLMNodeCompletionModelPromptTemplate) = typed_node_data.prompt_template + prompt_template = typed_node_data.prompt_template variable_selectors = [] prompt_context_selectors: list[Sequence[str]] = [] if isinstance(prompt_template, list): - for prompt in prompt_template: - if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt.text) + for item in prompt_template: + # Check PromptMessageContext first (same order as _parse_prompt_template) + # This extracts value_selector which is used by variable_pool.get(ctx_ref.value_selector) + if isinstance(item, PromptMessageContext): + if len(item.value_selector) >= 2: + prompt_context_selectors.append(item.value_selector) + elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2": + variable_template_parser = VariableTemplateParser(template=item.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - continue - if isinstance(prompt, PromptMessageContext): - if len(prompt.value_selector) < 2: - continue - prompt_context_selectors.append(prompt.value_selector) elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) @@ -1452,14 +1451,11 @@ class LLMNode(Node[LLMNodeData]): enable_jinja = False if isinstance(prompt_template, list): - for prompt in prompt_template: - if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2": + for item in prompt_template: + if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2": enable_jinja = True break - if isinstance(prompt, PromptMessageContext): - prompt_context_selectors.append(prompt.value_selector) else: - prompt_template: LLMNodeCompletionModelPromptTemplate enable_jinja = True if enable_jinja: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 9036edbb59..1b53af35d6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -495,16 +495,22 @@ class ToolNode(Node[ToolNodeData]): node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id + Extract variable selector to variable mapping. + + This method extracts: + 1. Variable references from tool parameters (mixed, variable types) + 2. Output selector from nested_node_config + 3. Variable references from nested nodes (nodes with parent_node_id == node_id) + + :param graph_config: graph config containing all nodes + :param node_id: current node id :param node_data: node data - :return: + :return: mapping of variable key to variable selector """ # Create typed NodeData from dict typed_node_data = ToolNodeData.model_validate(node_data) - result = {} + result: dict[str, Sequence[str]] = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] if input.type == "mixed": @@ -517,13 +523,74 @@ class ToolNode(Node[ToolNodeData]): selector_key = ".".join(input.value) result[f"#{selector_key}#"] = input.value elif input.type == "nested_node": - # Nested node type: value is handled by extractor node, no direct variable reference - pass + # Nested node type: extract variable selector from nested_node_config + # The full selector is extractor_node_id + output_selector + if input.nested_node_config is not None: + config = input.nested_node_config + full_selector = [config.extractor_node_id] + list(config.output_selector) + selector_key = ".".join(full_selector) + result[f"#{selector_key}#"] = full_selector elif input.type == "constant": pass result = {node_id + "." + key: value for key, value in result.items()} + # Extract variable references from nested nodes (nodes with parent_node_id == node_id) + nested_node_mapping = cls._extract_nested_node_variable_mapping( + graph_config=graph_config, parent_node_id=node_id + ) + result.update(nested_node_mapping) + + return result + + @classmethod + def _extract_nested_node_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + parent_node_id: str, + ) -> dict[str, Sequence[str]]: + """ + Extract variable references from nested nodes. + + Nested nodes are nodes with parent_node_id pointing to the current node. + They are typically extractor LLM nodes that extract values from list[PromptMessage]. + + :param graph_config: graph config containing all nodes + :param parent_node_id: the parent node id to find nested nodes for + :return: mapping of variable key to variable selector + """ + from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + + result: dict[str, Sequence[str]] = {} + nodes = graph_config.get("nodes", []) + + for node_config in nodes: + node_data = node_config.get("data", {}) + # Find nodes that are nested under the parent node + if node_data.get("parent_node_id") != parent_node_id: + continue + + nested_node_id = node_config.get("id") + if not nested_node_id: + continue + + # Get nested node class and extract its variable references + try: + node_type = NodeType(node_data.get("type")) + if node_type not in NODE_TYPE_CLASSES_MAPPING: + continue + node_version = node_data.get("version", "1") + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + + nested_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_config, config=node_config + ) + result.update(nested_variable_mapping) + except (NotImplementedError, ValueError, KeyError): + # Skip if node type is not found or extraction fails + continue + return result @property diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 17cbb9cfdd..82408f81f7 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -5,6 +5,13 @@ from uuid import uuid4 from configs import dify_config from core.file import File from core.model_runtime.entities import PromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageRole, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) from core.variables.exc import VariableError from core.variables.segments import ( ArrayAnySegment, @@ -214,6 +221,30 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = { } +def _deserialize_prompt_message_list(value: list[dict]) -> list[PromptMessage]: + """ + Deserialize a list of dicts to list[PromptMessage]. + + This is used when loading ARRAY_PROMPT_MESSAGE from database, + where PromptMessage objects are serialized as dicts. + """ + result: list[PromptMessage] = [] + for msg_dict in value: + role = msg_dict.get("role") + if role in (PromptMessageRole.USER, "user"): + result.append(UserPromptMessage.model_validate(msg_dict)) + elif role in (PromptMessageRole.ASSISTANT, "assistant"): + result.append(AssistantPromptMessage.model_validate(msg_dict)) + elif role in (PromptMessageRole.SYSTEM, "system"): + result.append(SystemPromptMessage.model_validate(msg_dict)) + elif role in (PromptMessageRole.TOOL, "tool"): + result.append(ToolPromptMessage.model_validate(msg_dict)) + else: + # Fallback to UserPromptMessage for unknown roles + result.append(UserPromptMessage.model_validate(msg_dict)) + return result + + def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. @@ -287,8 +318,10 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: return segment_class(value_type=inferred_type, value=value) elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT: # PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE + # Need to deserialize dict list back to PromptMessage objects + deserialized_messages = _deserialize_prompt_message_list(value) segment_class = _segment_factory[segment_type] - return segment_class(value_type=segment_type, value=value) + return segment_class(value_type=segment_type, value=deserialized_messages) else: raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index fa86772dfc..7e9605f6d3 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -907,14 +907,19 @@ class WorkflowService: """ try: node, node_events = invoke_node_fn() - node_run_result = next( - ( - event.node_run_result - for event in node_events - if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)) - ), - None, - ) + # Collect all events to find the appropriate result: + # - For failure: take the first NodeRunFailedEvent (fail fast) + # - For success: take the last NodeRunSucceededEvent (parent node result after nested nodes) + events_list = list(node_events) + node_run_result = None + for event in events_list: + if isinstance(event, NodeRunFailedEvent): + # Take first failure and stop + node_run_result = event.node_run_result + break + elif isinstance(event, NodeRunSucceededEvent): + # Keep updating to get the last success + node_run_result = event.node_run_result if not node_run_result: raise ValueError("Node execution failed - no result returned")