mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 23:36:20 +08:00
73 lines
2.7 KiB
Python
73 lines
2.7 KiB
Python
from collections.abc import Mapping, Sequence
|
|
from typing import Any
|
|
|
|
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
|
from dify_graph.node_events import NodeRunResult
|
|
from dify_graph.nodes.answer.entities import AnswerNodeData
|
|
from dify_graph.nodes.base.node import Node
|
|
from dify_graph.nodes.base.template import Template
|
|
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
|
from dify_graph.variables import ArrayFileSegment, FileSegment, Segment
|
|
|
|
|
|
class AnswerNode(Node[AnswerNodeData]):
    """Workflow node that renders the configured answer template.

    The template text in ``node_data.answer`` is converted through the runtime
    variable pool; the rendered markdown and every file referenced by the
    template are exposed as the node's outputs.
    """

    node_type = NodeType.ANSWER
    execution_type = NodeExecutionType.RESPONSE

    @classmethod
    def version(cls) -> str:
        """Return the version identifier of this node implementation."""
        return "1"

    def _run(self) -> NodeRunResult:
        """Render the answer template and collect the files it references.

        Returns:
            A succeeded ``NodeRunResult`` whose outputs are ``"answer"`` (the
            rendered template as markdown) and ``"files"`` (an
            ``ArrayFileSegment`` holding every file found in the rendered
            segments).
        """
        segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer)
        files = self._extract_files_from_segments(segments.value)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
        )

    def _extract_files_from_segments(self, segments: Sequence[Segment]):
        """Extract all files from segments containing FileSegment or ArrayFileSegment instances.

        FileSegment contains a single file, while ArrayFileSegment contains multiple files.
        This method flattens all files into a single list in segment order; segments of
        any other type are ignored.

        Args:
            segments: The rendered template segments to scan.

        Returns:
            A flat list of the file values found in ``segments``.
        """
        files = []
        for segment in segments:
            if isinstance(segment, FileSegment):
                # Single file: append the value itself.
                files.append(segment.value)
            elif isinstance(segment, ArrayFileSegment):
                # Multiple files: splice each one into the flat list.
                files.extend(segment.value)
        return files

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Map each variable referenced by the answer template to its value selector.

        Args:
            graph_config: The workflow graph configuration (not used by this node).
            node_id: ID of this node; used as the prefix of every mapping key.
            node_data: Raw node data, validated into ``AnswerNodeData``.

        Returns:
            A dict keyed by ``"<node_id>.<variable>"`` whose values are the
            corresponding value selectors. If a variable appears more than
            once in the template, the last occurrence wins.
        """
        # Validate the raw mapping into the typed node data model.
        typed_node_data = AnswerNodeData.model_validate(node_data)

        variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        return {
            node_id + "." + variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

    def get_streaming_template(self) -> Template:
        """
        Get the template for streaming.

        Returns:
            Template instance built from this Answer node's ``answer`` text
        """
        return Template.from_answer_template(self.node_data.answer)
|