mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
fix: nested node single step run
This commit is contained in:
@ -1402,20 +1402,19 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
# Create typed NodeData from dict
|
# Create typed NodeData from dict
|
||||||
typed_node_data = LLMNodeData.model_validate(node_data)
|
typed_node_data = LLMNodeData.model_validate(node_data)
|
||||||
|
|
||||||
prompt_template: (Sequence[LLMNodeChatModelMessage | PromptMessageContext] |
|
prompt_template = typed_node_data.prompt_template
|
||||||
LLMNodeCompletionModelPromptTemplate) = typed_node_data.prompt_template
|
|
||||||
variable_selectors = []
|
variable_selectors = []
|
||||||
prompt_context_selectors: list[Sequence[str]] = []
|
prompt_context_selectors: list[Sequence[str]] = []
|
||||||
if isinstance(prompt_template, list):
|
if isinstance(prompt_template, list):
|
||||||
for prompt in prompt_template:
|
for item in prompt_template:
|
||||||
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2":
|
# Check PromptMessageContext first (same order as _parse_prompt_template)
|
||||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
# This extracts value_selector which is used by variable_pool.get(ctx_ref.value_selector)
|
||||||
|
if isinstance(item, PromptMessageContext):
|
||||||
|
if len(item.value_selector) >= 2:
|
||||||
|
prompt_context_selectors.append(item.value_selector)
|
||||||
|
elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||||
|
variable_template_parser = VariableTemplateParser(template=item.text)
|
||||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||||
continue
|
|
||||||
if isinstance(prompt, PromptMessageContext):
|
|
||||||
if len(prompt.value_selector) < 2:
|
|
||||||
continue
|
|
||||||
prompt_context_selectors.append(prompt.value_selector)
|
|
||||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||||
if prompt_template.edition_type != "jinja2":
|
if prompt_template.edition_type != "jinja2":
|
||||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||||
@ -1452,14 +1451,11 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
enable_jinja = False
|
enable_jinja = False
|
||||||
|
|
||||||
if isinstance(prompt_template, list):
|
if isinstance(prompt_template, list):
|
||||||
for prompt in prompt_template:
|
for item in prompt_template:
|
||||||
if isinstance(prompt, LLMNodeChatModelMessage) and prompt.edition_type == "jinja2":
|
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||||
enable_jinja = True
|
enable_jinja = True
|
||||||
break
|
break
|
||||||
if isinstance(prompt, PromptMessageContext):
|
|
||||||
prompt_context_selectors.append(prompt.value_selector)
|
|
||||||
else:
|
else:
|
||||||
prompt_template: LLMNodeCompletionModelPromptTemplate
|
|
||||||
enable_jinja = True
|
enable_jinja = True
|
||||||
|
|
||||||
if enable_jinja:
|
if enable_jinja:
|
||||||
|
|||||||
@ -495,16 +495,22 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
node_data: Mapping[str, Any],
|
node_data: Mapping[str, Any],
|
||||||
) -> Mapping[str, Sequence[str]]:
|
) -> Mapping[str, Sequence[str]]:
|
||||||
"""
|
"""
|
||||||
Extract variable selector to variable mapping
|
Extract variable selector to variable mapping.
|
||||||
:param graph_config: graph config
|
|
||||||
:param node_id: node id
|
This method extracts:
|
||||||
|
1. Variable references from tool parameters (mixed, variable types)
|
||||||
|
2. Output selector from nested_node_config
|
||||||
|
3. Variable references from nested nodes (nodes with parent_node_id == node_id)
|
||||||
|
|
||||||
|
:param graph_config: graph config containing all nodes
|
||||||
|
:param node_id: current node id
|
||||||
:param node_data: node data
|
:param node_data: node data
|
||||||
:return:
|
:return: mapping of variable key to variable selector
|
||||||
"""
|
"""
|
||||||
# Create typed NodeData from dict
|
# Create typed NodeData from dict
|
||||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||||
|
|
||||||
result = {}
|
result: dict[str, Sequence[str]] = {}
|
||||||
for parameter_name in typed_node_data.tool_parameters:
|
for parameter_name in typed_node_data.tool_parameters:
|
||||||
input = typed_node_data.tool_parameters[parameter_name]
|
input = typed_node_data.tool_parameters[parameter_name]
|
||||||
if input.type == "mixed":
|
if input.type == "mixed":
|
||||||
@ -517,13 +523,74 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
selector_key = ".".join(input.value)
|
selector_key = ".".join(input.value)
|
||||||
result[f"#{selector_key}#"] = input.value
|
result[f"#{selector_key}#"] = input.value
|
||||||
elif input.type == "nested_node":
|
elif input.type == "nested_node":
|
||||||
# Nested node type: value is handled by extractor node, no direct variable reference
|
# Nested node type: extract variable selector from nested_node_config
|
||||||
pass
|
# The full selector is extractor_node_id + output_selector
|
||||||
|
if input.nested_node_config is not None:
|
||||||
|
config = input.nested_node_config
|
||||||
|
full_selector = [config.extractor_node_id] + list(config.output_selector)
|
||||||
|
selector_key = ".".join(full_selector)
|
||||||
|
result[f"#{selector_key}#"] = full_selector
|
||||||
elif input.type == "constant":
|
elif input.type == "constant":
|
||||||
pass
|
pass
|
||||||
|
|
||||||
result = {node_id + "." + key: value for key, value in result.items()}
|
result = {node_id + "." + key: value for key, value in result.items()}
|
||||||
|
|
||||||
|
# Extract variable references from nested nodes (nodes with parent_node_id == node_id)
|
||||||
|
nested_node_mapping = cls._extract_nested_node_variable_mapping(
|
||||||
|
graph_config=graph_config, parent_node_id=node_id
|
||||||
|
)
|
||||||
|
result.update(nested_node_mapping)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_nested_node_variable_mapping(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
graph_config: Mapping[str, Any],
|
||||||
|
parent_node_id: str,
|
||||||
|
) -> dict[str, Sequence[str]]:
|
||||||
|
"""
|
||||||
|
Extract variable references from nested nodes.
|
||||||
|
|
||||||
|
Nested nodes are nodes with parent_node_id pointing to the current node.
|
||||||
|
They are typically extractor LLM nodes that extract values from list[PromptMessage].
|
||||||
|
|
||||||
|
:param graph_config: graph config containing all nodes
|
||||||
|
:param parent_node_id: the parent node id to find nested nodes for
|
||||||
|
:return: mapping of variable key to variable selector
|
||||||
|
"""
|
||||||
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
|
|
||||||
|
result: dict[str, Sequence[str]] = {}
|
||||||
|
nodes = graph_config.get("nodes", [])
|
||||||
|
|
||||||
|
for node_config in nodes:
|
||||||
|
node_data = node_config.get("data", {})
|
||||||
|
# Find nodes that are nested under the parent node
|
||||||
|
if node_data.get("parent_node_id") != parent_node_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
nested_node_id = node_config.get("id")
|
||||||
|
if not nested_node_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get nested node class and extract its variable references
|
||||||
|
try:
|
||||||
|
node_type = NodeType(node_data.get("type"))
|
||||||
|
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||||
|
continue
|
||||||
|
node_version = node_data.get("version", "1")
|
||||||
|
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||||
|
|
||||||
|
nested_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||||
|
graph_config=graph_config, config=node_config
|
||||||
|
)
|
||||||
|
result.update(nested_variable_mapping)
|
||||||
|
except (NotImplementedError, ValueError, KeyError):
|
||||||
|
# Skip if node type is not found or extraction fails
|
||||||
|
continue
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -5,6 +5,13 @@ from uuid import uuid4
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.model_runtime.entities import PromptMessage
|
from core.model_runtime.entities import PromptMessage
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessageRole,
|
||||||
|
SystemPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
from core.variables.exc import VariableError
|
from core.variables.exc import VariableError
|
||||||
from core.variables.segments import (
|
from core.variables.segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
@ -214,6 +221,30 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_prompt_message_list(value: list[dict]) -> list[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Deserialize a list of dicts to list[PromptMessage].
|
||||||
|
|
||||||
|
This is used when loading ARRAY_PROMPT_MESSAGE from database,
|
||||||
|
where PromptMessage objects are serialized as dicts.
|
||||||
|
"""
|
||||||
|
result: list[PromptMessage] = []
|
||||||
|
for msg_dict in value:
|
||||||
|
role = msg_dict.get("role")
|
||||||
|
if role in (PromptMessageRole.USER, "user"):
|
||||||
|
result.append(UserPromptMessage.model_validate(msg_dict))
|
||||||
|
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||||
|
result.append(AssistantPromptMessage.model_validate(msg_dict))
|
||||||
|
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||||
|
result.append(SystemPromptMessage.model_validate(msg_dict))
|
||||||
|
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||||
|
result.append(ToolPromptMessage.model_validate(msg_dict))
|
||||||
|
else:
|
||||||
|
# Fallback to UserPromptMessage for unknown roles
|
||||||
|
result.append(UserPromptMessage.model_validate(msg_dict))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||||
"""
|
"""
|
||||||
Build a segment with explicit type checking.
|
Build a segment with explicit type checking.
|
||||||
@ -287,8 +318,10 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
|||||||
return segment_class(value_type=inferred_type, value=value)
|
return segment_class(value_type=inferred_type, value=value)
|
||||||
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
|
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
|
||||||
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
|
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
|
||||||
|
# Need to deserialize dict list back to PromptMessage objects
|
||||||
|
deserialized_messages = _deserialize_prompt_message_list(value)
|
||||||
segment_class = _segment_factory[segment_type]
|
segment_class = _segment_factory[segment_type]
|
||||||
return segment_class(value_type=segment_type, value=value)
|
return segment_class(value_type=segment_type, value=deserialized_messages)
|
||||||
else:
|
else:
|
||||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
|
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
|
||||||
|
|
||||||
|
|||||||
@ -907,14 +907,19 @@ class WorkflowService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
node, node_events = invoke_node_fn()
|
node, node_events = invoke_node_fn()
|
||||||
node_run_result = next(
|
# Collect all events to find the appropriate result:
|
||||||
(
|
# - For failure: take the first NodeRunFailedEvent (fail fast)
|
||||||
event.node_run_result
|
# - For success: take the last NodeRunSucceededEvent (parent node result after nested nodes)
|
||||||
for event in node_events
|
events_list = list(node_events)
|
||||||
if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent))
|
node_run_result = None
|
||||||
),
|
for event in events_list:
|
||||||
None,
|
if isinstance(event, NodeRunFailedEvent):
|
||||||
)
|
# Take first failure and stop
|
||||||
|
node_run_result = event.node_run_result
|
||||||
|
break
|
||||||
|
elif isinstance(event, NodeRunSucceededEvent):
|
||||||
|
# Keep updating to get the last success
|
||||||
|
node_run_result = event.node_run_result
|
||||||
|
|
||||||
if not node_run_result:
|
if not node_run_result:
|
||||||
raise ValueError("Node execution failed - no result returned")
|
raise ValueError("Node execution failed - no result returned")
|
||||||
|
|||||||
Reference in New Issue
Block a user