mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
fix: nested node single step run
This commit is contained in:
@ -5,6 +5,13 @@ from uuid import uuid4
|
||||
from configs import dify_config
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
@ -214,6 +221,30 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_prompt_message_list(value: list[dict]) -> list[PromptMessage]:
|
||||
"""
|
||||
Deserialize a list of dicts to list[PromptMessage].
|
||||
|
||||
This is used when loading ARRAY_PROMPT_MESSAGE from database,
|
||||
where PromptMessage objects are serialized as dicts.
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
for msg_dict in value:
|
||||
role = msg_dict.get("role")
|
||||
if role in (PromptMessageRole.USER, "user"):
|
||||
result.append(UserPromptMessage.model_validate(msg_dict))
|
||||
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||
result.append(AssistantPromptMessage.model_validate(msg_dict))
|
||||
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||
result.append(SystemPromptMessage.model_validate(msg_dict))
|
||||
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||
result.append(ToolPromptMessage.model_validate(msg_dict))
|
||||
else:
|
||||
# Fallback to UserPromptMessage for unknown roles
|
||||
result.append(UserPromptMessage.model_validate(msg_dict))
|
||||
return result
|
||||
|
||||
|
||||
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
"""
|
||||
Build a segment with explicit type checking.
|
||||
@ -287,8 +318,10 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
return segment_class(value_type=inferred_type, value=value)
|
||||
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
|
||||
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
|
||||
# Need to deserialize dict list back to PromptMessage objects
|
||||
deserialized_messages = _deserialize_prompt_message_list(value)
|
||||
segment_class = _segment_factory[segment_type]
|
||||
return segment_class(value_type=segment_type, value=value)
|
||||
return segment_class(value_type=segment_type, value=deserialized_messages)
|
||||
else:
|
||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user