fix: nested node single step run

This commit is contained in:
Novice
2026-01-28 10:18:10 +08:00
parent a571b3abb2
commit cd688a0d8f
4 changed files with 132 additions and 31 deletions

View File

@ -1402,20 +1402,19 @@ class LLMNode(Node[LLMNodeData]):
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template: (Sequence[LLMNodeChatModelMessage | PromptMessageContext] |
LLMNodeCompletionModelPromptTemplate) = typed_node_data.prompt_template
prompt_template = typed_node_data.prompt_template
variable_selectors = []
prompt_context_selectors: list[Sequence[str]] = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt.text)
for item in prompt_template:
# Check PromptMessageContext first (same order as _parse_prompt_template)
# This extracts value_selector which is used by variable_pool.get(ctx_ref.value_selector)
if isinstance(item, PromptMessageContext):
if len(item.value_selector) >= 2:
prompt_context_selectors.append(item.value_selector)
elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
variable_template_parser = VariableTemplateParser(template=item.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
continue
if isinstance(prompt, PromptMessageContext):
if len(prompt.value_selector) < 2:
continue
prompt_context_selectors.append(prompt.value_selector)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
@ -1452,14 +1451,11 @@ class LLMNode(Node[LLMNodeData]):
enable_jinja = False
if isinstance(prompt_template, list):
for prompt in prompt_template:
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2":
for item in prompt_template:
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
enable_jinja = True
break
if isinstance(prompt, PromptMessageContext):
prompt_context_selectors.append(prompt.value_selector)
else:
prompt_template: LLMNodeCompletionModelPromptTemplate
enable_jinja = True
if enable_jinja:

View File

@ -495,16 +495,22 @@ class ToolNode(Node[ToolNodeData]):
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
Extract variable selector to variable mapping.
This method extracts:
1. Variable references from tool parameters (mixed, variable types)
2. Output selector from nested_node_config
3. Variable references from nested nodes (nodes with parent_node_id == node_id)
:param graph_config: graph config containing all nodes
:param node_id: current node id
:param node_data: node data
:return:
:return: mapping of variable key to variable selector
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
result = {}
result: dict[str, Sequence[str]] = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
@ -517,13 +523,74 @@ class ToolNode(Node[ToolNodeData]):
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "nested_node":
# Nested node type: value is handled by extractor node, no direct variable reference
pass
# Nested node type: extract variable selector from nested_node_config
# The full selector is extractor_node_id + output_selector
if input.nested_node_config is not None:
config = input.nested_node_config
full_selector = [config.extractor_node_id] + list(config.output_selector)
selector_key = ".".join(full_selector)
result[f"#{selector_key}#"] = full_selector
elif input.type == "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}
# Extract variable references from nested nodes (nodes with parent_node_id == node_id)
nested_node_mapping = cls._extract_nested_node_variable_mapping(
graph_config=graph_config, parent_node_id=node_id
)
result.update(nested_node_mapping)
return result
@classmethod
def _extract_nested_node_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
parent_node_id: str,
) -> dict[str, Sequence[str]]:
"""
Extract variable references from nested nodes.
Nested nodes are nodes with parent_node_id pointing to the current node.
They are typically extractor LLM nodes that extract values from list[PromptMessage].
:param graph_config: graph config containing all nodes
:param parent_node_id: the parent node id to find nested nodes for
:return: mapping of variable key to variable selector
"""
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
result: dict[str, Sequence[str]] = {}
nodes = graph_config.get("nodes", [])
for node_config in nodes:
node_data = node_config.get("data", {})
# Find nodes that are nested under the parent node
if node_data.get("parent_node_id") != parent_node_id:
continue
nested_node_id = node_config.get("id")
if not nested_node_id:
continue
# Get nested node class and extract its variable references
try:
node_type = NodeType(node_data.get("type"))
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = node_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
nested_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=node_config
)
result.update(nested_variable_mapping)
except (NotImplementedError, ValueError, KeyError):
# Skip if node type is not found or extraction fails
continue
return result
@property

View File

@ -5,6 +5,13 @@ from uuid import uuid4
from configs import dify_config
from core.file import File
from core.model_runtime.entities import PromptMessage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
@ -214,6 +221,30 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
}
def _deserialize_prompt_message_list(value: list[dict]) -> list[PromptMessage]:
"""
Deserialize a list of dicts to list[PromptMessage].
This is used when loading ARRAY_PROMPT_MESSAGE from database,
where PromptMessage objects are serialized as dicts.
"""
result: list[PromptMessage] = []
for msg_dict in value:
role = msg_dict.get("role")
if role in (PromptMessageRole.USER, "user"):
result.append(UserPromptMessage.model_validate(msg_dict))
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
result.append(AssistantPromptMessage.model_validate(msg_dict))
elif role in (PromptMessageRole.SYSTEM, "system"):
result.append(SystemPromptMessage.model_validate(msg_dict))
elif role in (PromptMessageRole.TOOL, "tool"):
result.append(ToolPromptMessage.model_validate(msg_dict))
else:
# Fallback to UserPromptMessage for unknown roles
result.append(UserPromptMessage.model_validate(msg_dict))
return result
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
"""
Build a segment with explicit type checking.
@ -287,8 +318,10 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
return segment_class(value_type=inferred_type, value=value)
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
# Need to deserialize dict list back to PromptMessage objects
deserialized_messages = _deserialize_prompt_message_list(value)
segment_class = _segment_factory[segment_type]
return segment_class(value_type=segment_type, value=value)
return segment_class(value_type=segment_type, value=deserialized_messages)
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")

View File

@ -907,14 +907,19 @@ class WorkflowService:
"""
try:
node, node_events = invoke_node_fn()
node_run_result = next(
(
event.node_run_result
for event in node_events
if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent))
),
None,
)
# Collect all events to find the appropriate result:
# - For failure: take the first NodeRunFailedEvent (fail fast)
# - For success: take the last NodeRunSucceededEvent (parent node result after nested nodes)
events_list = list(node_events)
node_run_result = None
for event in events_list:
if isinstance(event, NodeRunFailedEvent):
# Take first failure and stop
node_run_result = event.node_run_result
break
elif isinstance(event, NodeRunSucceededEvent):
# Keep updating to get the last success
node_run_result = event.node_run_result
if not node_run_result:
raise ValueError("Node execution failed - no result returned")