mirror of
https://github.com/langgenius/dify.git
synced 2026-03-16 04:17:43 +08:00
fix: nested node single-step run
This commit is contained in:
@ -1402,20 +1402,19 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LLMNodeData.model_validate(node_data)
|
||||
|
||||
prompt_template: (Sequence[LLMNodeChatModelMessage | PromptMessageContext] |
|
||||
LLMNodeCompletionModelPromptTemplate) = typed_node_data.prompt_template
|
||||
prompt_template = typed_node_data.prompt_template
|
||||
variable_selectors = []
|
||||
prompt_context_selectors: list[Sequence[str]] = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
for item in prompt_template:
|
||||
# Check PromptMessageContext first (same order as _parse_prompt_template)
|
||||
# This extracts value_selector which is used by variable_pool.get(ctx_ref.value_selector)
|
||||
if isinstance(item, PromptMessageContext):
|
||||
if len(item.value_selector) >= 2:
|
||||
prompt_context_selectors.append(item.value_selector)
|
||||
elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=item.text)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
continue
|
||||
if isinstance(prompt, PromptMessageContext):
|
||||
if len(prompt.value_selector) < 2:
|
||||
continue
|
||||
prompt_context_selectors.append(prompt.value_selector)
|
||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
if prompt_template.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
@ -1452,14 +1451,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2":
|
||||
for item in prompt_template:
|
||||
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
break
|
||||
if isinstance(prompt, PromptMessageContext):
|
||||
prompt_context_selectors.append(prompt.value_selector)
|
||||
else:
|
||||
prompt_template: LLMNodeCompletionModelPromptTemplate
|
||||
enable_jinja = True
|
||||
|
||||
if enable_jinja:
|
||||
|
||||
@ -495,16 +495,22 @@ class ToolNode(Node[ToolNodeData]):
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
Extract variable selector to variable mapping.
|
||||
|
||||
This method extracts:
|
||||
1. Variable references from tool parameters (mixed, variable types)
|
||||
2. Output selector from nested_node_config
|
||||
3. Variable references from nested nodes (nodes with parent_node_id == node_id)
|
||||
|
||||
:param graph_config: graph config containing all nodes
|
||||
:param node_id: current node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
:return: mapping of variable key to variable selector
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||
|
||||
result = {}
|
||||
result: dict[str, Sequence[str]] = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
@ -517,13 +523,74 @@ class ToolNode(Node[ToolNodeData]):
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "nested_node":
|
||||
# Nested node type: value is handled by extractor node, no direct variable reference
|
||||
pass
|
||||
# Nested node type: extract variable selector from nested_node_config
|
||||
# The full selector is extractor_node_id + output_selector
|
||||
if input.nested_node_config is not None:
|
||||
config = input.nested_node_config
|
||||
full_selector = [config.extractor_node_id] + list(config.output_selector)
|
||||
selector_key = ".".join(full_selector)
|
||||
result[f"#{selector_key}#"] = full_selector
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
# Extract variable references from nested nodes (nodes with parent_node_id == node_id)
|
||||
nested_node_mapping = cls._extract_nested_node_variable_mapping(
|
||||
graph_config=graph_config, parent_node_id=node_id
|
||||
)
|
||||
result.update(nested_node_mapping)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
def _extract_nested_node_variable_mapping(
    cls,
    *,
    graph_config: Mapping[str, Any],
    parent_node_id: str,
) -> dict[str, Sequence[str]]:
    """
    Collect variable selector mappings contributed by nested nodes.

    A nested node is any node whose ``parent_node_id`` equals the given
    parent id. These are typically extractor LLM nodes that pull values
    out of list[PromptMessage].

    :param graph_config: graph config containing all nodes
    :param parent_node_id: the parent node id to find nested nodes for
    :return: mapping of variable key to variable selector
    """
    from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

    mapping: dict[str, Sequence[str]] = {}

    for candidate in graph_config.get("nodes", []):
        data = candidate.get("data", {})
        # Only nodes nested directly under the requested parent contribute.
        if data.get("parent_node_id") != parent_node_id:
            continue

        # A node without an id cannot be referenced; skip it.
        if not candidate.get("id"):
            continue

        # Resolve the concrete node class for this nested node and delegate
        # variable extraction to it.
        try:
            nested_type = NodeType(data.get("type"))
            if nested_type not in NODE_TYPE_CLASSES_MAPPING:
                continue
            version = data.get("version", "1")
            nested_cls = NODE_TYPE_CLASSES_MAPPING[nested_type][version]

            mapping.update(
                nested_cls.extract_variable_selector_to_variable_mapping(
                    graph_config=graph_config, config=candidate
                )
            )
        except (NotImplementedError, ValueError, KeyError):
            # Unknown node type, missing version, or failed extraction:
            # skip this nested node rather than abort the whole mapping.
            continue

    return mapping
|
||||
|
||||
@property
|
||||
|
||||
@ -5,6 +5,13 @@ from uuid import uuid4
|
||||
from configs import dify_config
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
@ -214,6 +221,30 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_prompt_message_list(value: list[dict]) -> list[PromptMessage]:
    """
    Rebuild PromptMessage objects from their serialized dict form.

    Used when loading ARRAY_PROMPT_MESSAGE values from the database, where
    each PromptMessage was stored as a plain dict. Entries whose role is not
    recognized fall back to UserPromptMessage.
    """
    # Each entry pairs the accepted role spellings (enum member or raw
    # string) with the concrete message class used to validate the dict.
    role_dispatch = (
        ((PromptMessageRole.USER, "user"), UserPromptMessage),
        ((PromptMessageRole.ASSISTANT, "assistant"), AssistantPromptMessage),
        ((PromptMessageRole.SYSTEM, "system"), SystemPromptMessage),
        ((PromptMessageRole.TOOL, "tool"), ToolPromptMessage),
    )

    messages: list[PromptMessage] = []
    for raw in value:
        role = raw.get("role")
        message_cls = UserPromptMessage  # fallback for unknown roles
        for accepted_roles, cls in role_dispatch:
            if role in accepted_roles:
                message_cls = cls
                break
        messages.append(message_cls.model_validate(raw))
    return messages
|
||||
|
||||
|
||||
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
"""
|
||||
Build a segment with explicit type checking.
|
||||
@ -287,8 +318,10 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
return segment_class(value_type=inferred_type, value=value)
|
||||
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
|
||||
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
|
||||
# Need to deserialize dict list back to PromptMessage objects
|
||||
deserialized_messages = _deserialize_prompt_message_list(value)
|
||||
segment_class = _segment_factory[segment_type]
|
||||
return segment_class(value_type=segment_type, value=value)
|
||||
return segment_class(value_type=segment_type, value=deserialized_messages)
|
||||
else:
|
||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
|
||||
|
||||
|
||||
@ -907,14 +907,19 @@ class WorkflowService:
|
||||
"""
|
||||
try:
|
||||
node, node_events = invoke_node_fn()
|
||||
node_run_result = next(
|
||||
(
|
||||
event.node_run_result
|
||||
for event in node_events
|
||||
if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent))
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Collect all events to find the appropriate result:
|
||||
# - For failure: take the first NodeRunFailedEvent (fail fast)
|
||||
# - For success: take the last NodeRunSucceededEvent (parent node result after nested nodes)
|
||||
events_list = list(node_events)
|
||||
node_run_result = None
|
||||
for event in events_list:
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
# Take first failure and stop
|
||||
node_run_result = event.node_run_result
|
||||
break
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
# Keep updating to get the last success
|
||||
node_run_result = event.node_run_result
|
||||
|
||||
if not node_run_result:
|
||||
raise ValueError("Node execution failed - no result returned")
|
||||
|
||||
Reference in New Issue
Block a user