fix: nested node single step run

This commit is contained in:
Novice
2026-01-28 10:18:10 +08:00
parent a571b3abb2
commit cd688a0d8f
4 changed files with 132 additions and 31 deletions

View File

@ -1402,20 +1402,19 @@ class LLMNode(Node[LLMNodeData]):
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data) typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template: (Sequence[LLMNodeChatModelMessage | PromptMessageContext] | prompt_template = typed_node_data.prompt_template
LLMNodeCompletionModelPromptTemplate) = typed_node_data.prompt_template
variable_selectors = [] variable_selectors = []
prompt_context_selectors: list[Sequence[str]] = [] prompt_context_selectors: list[Sequence[str]] = []
if isinstance(prompt_template, list): if isinstance(prompt_template, list):
for prompt in prompt_template: for item in prompt_template:
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2": # Check PromptMessageContext first (same order as _parse_prompt_template)
variable_template_parser = VariableTemplateParser(template=prompt.text) # This extracts value_selector which is used by variable_pool.get(ctx_ref.value_selector)
if isinstance(item, PromptMessageContext):
if len(item.value_selector) >= 2:
prompt_context_selectors.append(item.value_selector)
elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
variable_template_parser = VariableTemplateParser(template=item.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors()) variable_selectors.extend(variable_template_parser.extract_variable_selectors())
continue
if isinstance(prompt, PromptMessageContext):
if len(prompt.value_selector) < 2:
continue
prompt_context_selectors.append(prompt.value_selector)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2": if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_template_parser = VariableTemplateParser(template=prompt_template.text)
@ -1452,14 +1451,11 @@ class LLMNode(Node[LLMNodeData]):
enable_jinja = False enable_jinja = False
if isinstance(prompt_template, list): if isinstance(prompt_template, list):
for prompt in prompt_template: for item in prompt_template:
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2": if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
enable_jinja = True enable_jinja = True
break break
if isinstance(prompt, PromptMessageContext):
prompt_context_selectors.append(prompt.value_selector)
else: else:
prompt_template: LLMNodeCompletionModelPromptTemplate
enable_jinja = True enable_jinja = True
if enable_jinja: if enable_jinja:

View File

@ -495,16 +495,22 @@ class ToolNode(Node[ToolNodeData]):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" """
Extract variable selector to variable mapping Extract variable selector to variable mapping.
:param graph_config: graph config
:param node_id: node id This method extracts:
1. Variable references from tool parameters (mixed, variable types)
2. Output selector from nested_node_config
3. Variable references from nested nodes (nodes with parent_node_id == node_id)
:param graph_config: graph config containing all nodes
:param node_id: current node id
:param node_data: node data :param node_data: node data
:return: :return: mapping of variable key to variable selector
""" """
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data) typed_node_data = ToolNodeData.model_validate(node_data)
result = {} result: dict[str, Sequence[str]] = {}
for parameter_name in typed_node_data.tool_parameters: for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name] input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed": if input.type == "mixed":
@ -517,13 +523,74 @@ class ToolNode(Node[ToolNodeData]):
selector_key = ".".join(input.value) selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value result[f"#{selector_key}#"] = input.value
elif input.type == "nested_node": elif input.type == "nested_node":
# Nested node type: value is handled by extractor node, no direct variable reference # Nested node type: extract variable selector from nested_node_config
pass # The full selector is extractor_node_id + output_selector
if input.nested_node_config is not None:
config = input.nested_node_config
full_selector = [config.extractor_node_id] + list(config.output_selector)
selector_key = ".".join(full_selector)
result[f"#{selector_key}#"] = full_selector
elif input.type == "constant": elif input.type == "constant":
pass pass
result = {node_id + "." + key: value for key, value in result.items()} result = {node_id + "." + key: value for key, value in result.items()}
# Extract variable references from nested nodes (nodes with parent_node_id == node_id)
nested_node_mapping = cls._extract_nested_node_variable_mapping(
graph_config=graph_config, parent_node_id=node_id
)
result.update(nested_node_mapping)
return result
@classmethod
def _extract_nested_node_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
parent_node_id: str,
) -> dict[str, Sequence[str]]:
"""
Extract variable references from nested nodes.
Nested nodes are nodes with parent_node_id pointing to the current node.
They are typically extractor LLM nodes that extract values from list[PromptMessage].
:param graph_config: graph config containing all nodes
:param parent_node_id: the parent node id to find nested nodes for
:return: mapping of variable key to variable selector
"""
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
result: dict[str, Sequence[str]] = {}
nodes = graph_config.get("nodes", [])
for node_config in nodes:
node_data = node_config.get("data", {})
# Find nodes that are nested under the parent node
if node_data.get("parent_node_id") != parent_node_id:
continue
nested_node_id = node_config.get("id")
if not nested_node_id:
continue
# Get nested node class and extract its variable references
try:
node_type = NodeType(node_data.get("type"))
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = node_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
nested_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=node_config
)
result.update(nested_variable_mapping)
except (NotImplementedError, ValueError, KeyError):
# Skip if node type is not found or extraction fails
continue
return result return result
@property @property

View File

@ -5,6 +5,13 @@ from uuid import uuid4
from configs import dify_config from configs import dify_config
from core.file import File from core.file import File
from core.model_runtime.entities import PromptMessage from core.model_runtime.entities import PromptMessage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.variables.exc import VariableError from core.variables.exc import VariableError
from core.variables.segments import ( from core.variables.segments import (
ArrayAnySegment, ArrayAnySegment,
@ -214,6 +221,30 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
} }
def _deserialize_prompt_message_list(value: list[dict]) -> list[PromptMessage]:
"""
Deserialize a list of dicts to list[PromptMessage].
This is used when loading ARRAY_PROMPT_MESSAGE from database,
where PromptMessage objects are serialized as dicts.
"""
result: list[PromptMessage] = []
for msg_dict in value:
role = msg_dict.get("role")
if role in (PromptMessageRole.USER, "user"):
result.append(UserPromptMessage.model_validate(msg_dict))
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
result.append(AssistantPromptMessage.model_validate(msg_dict))
elif role in (PromptMessageRole.SYSTEM, "system"):
result.append(SystemPromptMessage.model_validate(msg_dict))
elif role in (PromptMessageRole.TOOL, "tool"):
result.append(ToolPromptMessage.model_validate(msg_dict))
else:
# Fallback to UserPromptMessage for unknown roles
result.append(UserPromptMessage.model_validate(msg_dict))
return result
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
""" """
Build a segment with explicit type checking. Build a segment with explicit type checking.
@ -287,8 +318,10 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
return segment_class(value_type=inferred_type, value=value) return segment_class(value_type=inferred_type, value=value)
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT: elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE # PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
# Need to deserialize dict list back to PromptMessage objects
deserialized_messages = _deserialize_prompt_message_list(value)
segment_class = _segment_factory[segment_type] segment_class = _segment_factory[segment_type]
return segment_class(value_type=segment_type, value=value) return segment_class(value_type=segment_type, value=deserialized_messages)
else: else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")

View File

@ -907,14 +907,19 @@ class WorkflowService:
""" """
try: try:
node, node_events = invoke_node_fn() node, node_events = invoke_node_fn()
node_run_result = next( # Collect all events to find the appropriate result:
( # - For failure: take the first NodeRunFailedEvent (fail fast)
event.node_run_result # - For success: take the last NodeRunSucceededEvent (parent node result after nested nodes)
for event in node_events events_list = list(node_events)
if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)) node_run_result = None
), for event in events_list:
None, if isinstance(event, NodeRunFailedEvent):
) # Take first failure and stop
node_run_result = event.node_run_result
break
elif isinstance(event, NodeRunSucceededEvent):
# Keep updating to get the last success
node_run_result = event.node_run_result
if not node_run_result: if not node_run_result:
raise ValueError("Node execution failed - no result returned") raise ValueError("Node execution failed - no result returned")