mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
fix test
This commit is contained in:
@ -2,9 +2,9 @@ import json
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
GenerateRouteChunk,
|
||||
@ -29,11 +29,11 @@ class AnswerNode(BaseNode):
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
# generate routes
|
||||
generate_routes = self.extract_generate_route_from_node_data(node_data)
|
||||
generate_routes = AnswerStreamOutputManager.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
answer = ''
|
||||
for part in generate_routes:
|
||||
if part.type == "var":
|
||||
if part.type == GenerateRouteChunk.ChunkType.VAR:
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = self.graph_runtime_state.variable_pool.get_variable_value(
|
||||
@ -72,67 +72,6 @@ class AnswerNode(BaseNode):
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route selectors
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
|
||||
return cls.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route from node data
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
value_selector_mapping = {
|
||||
variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in variable_selectors
|
||||
}
|
||||
|
||||
variable_keys = list(value_selector_mapping.keys())
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
|
||||
template_variable_keys = template_parser.variable_keys
|
||||
|
||||
# Take the intersection of variable_keys and template_variable_keys
|
||||
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
|
||||
generate_routes = []
|
||||
for part in template.split('Ω'):
|
||||
if part:
|
||||
if cls._is_variable(part, variable_keys):
|
||||
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
||||
value_selector = value_selector_mapping[var_key]
|
||||
generate_routes.append(VarGenerateRouteChunk(
|
||||
value_selector=value_selector
|
||||
))
|
||||
else:
|
||||
generate_routes.append(TextGenerateRouteChunk(
|
||||
text=part
|
||||
))
|
||||
|
||||
return generate_routes
|
||||
|
||||
@classmethod
|
||||
def _is_variable(cls, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
@ -141,7 +80,7 @@ class AnswerNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = node_data
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
160
api/core/workflow/nodes/answer/answer_stream_output_manager.py
Normal file
160
api/core/workflow/nodes/answer/answer_stream_output_manager.py
Normal file
@ -0,0 +1,160 @@
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
AnswerStreamGenerateRoute,
|
||||
GenerateRouteChunk,
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerStreamOutputManager:
|
||||
|
||||
@classmethod
|
||||
def init_stream_generate_routes(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
|
||||
) -> dict[str, AnswerStreamGenerateRoute]:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
"""
|
||||
# parse stream output node value selectors of answer nodes
|
||||
stream_generate_routes = {}
|
||||
for node_id, node_config in node_id_config_mapping.items():
|
||||
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
|
||||
continue
|
||||
|
||||
# get generate route for stream output
|
||||
generate_route = cls._extract_generate_route_selectors(node_config)
|
||||
streaming_node_ids = cls._get_streaming_node_ids(
|
||||
target_node_id=node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping
|
||||
)
|
||||
|
||||
if not streaming_node_ids:
|
||||
continue
|
||||
|
||||
for streaming_node_id in streaming_node_ids:
|
||||
stream_generate_routes[streaming_node_id] = AnswerStreamGenerateRoute(
|
||||
answer_node_id=node_id,
|
||||
generate_route=generate_route
|
||||
)
|
||||
|
||||
return stream_generate_routes
|
||||
|
||||
@classmethod
|
||||
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route from node data
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
value_selector_mapping = {
|
||||
variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in variable_selectors
|
||||
}
|
||||
|
||||
variable_keys = list(value_selector_mapping.keys())
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
|
||||
template_variable_keys = template_parser.variable_keys
|
||||
|
||||
# Take the intersection of variable_keys and template_variable_keys
|
||||
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
|
||||
generate_routes: list[GenerateRouteChunk] = []
|
||||
for part in template.split('Ω'):
|
||||
if part:
|
||||
if cls._is_variable(part, variable_keys):
|
||||
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
||||
value_selector = value_selector_mapping[var_key]
|
||||
generate_routes.append(VarGenerateRouteChunk(
|
||||
value_selector=value_selector
|
||||
))
|
||||
else:
|
||||
generate_routes.append(TextGenerateRouteChunk(
|
||||
text=part
|
||||
))
|
||||
|
||||
return generate_routes
|
||||
|
||||
@classmethod
|
||||
def _get_streaming_node_ids(cls,
|
||||
target_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
edge_mapping: dict[str, list["GraphEdge"]]) -> list[str]: # type: ignore[name-defined]
|
||||
"""
|
||||
Get answer stream node IDs.
|
||||
:param target_node_id: target node ID
|
||||
:return:
|
||||
"""
|
||||
# fetch all ingoing edges from source node
|
||||
ingoing_graph_edges = []
|
||||
for graph_edges in edge_mapping.values():
|
||||
for graph_edge in graph_edges:
|
||||
if graph_edge.target_node_id == target_node_id:
|
||||
ingoing_graph_edges.append(graph_edge)
|
||||
|
||||
if not ingoing_graph_edges:
|
||||
return []
|
||||
|
||||
streaming_node_ids = []
|
||||
for ingoing_graph_edge in ingoing_graph_edges:
|
||||
source_node_id = ingoing_graph_edge.source_node_id
|
||||
source_node = node_id_config_mapping.get(source_node_id)
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
|
||||
if node_type in [
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value,
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value
|
||||
]:
|
||||
# Current node (answer nodes / multi-branch nodes / iteration nodes) cannot be stream node.
|
||||
streaming_node_ids.append(target_node_id)
|
||||
elif node_type == NodeType.START.value:
|
||||
# Current node is START node, can be stream node.
|
||||
streaming_node_ids.append(source_node_id)
|
||||
else:
|
||||
# Find the stream node forward.
|
||||
sub_streaming_node_ids = cls._get_streaming_node_ids(
|
||||
target_node_id=source_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping
|
||||
)
|
||||
|
||||
if sub_streaming_node_ids:
|
||||
streaming_node_ids.extend(sub_streaming_node_ids)
|
||||
|
||||
return streaming_node_ids
|
||||
|
||||
@classmethod
|
||||
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route selectors
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = AnswerNodeData(**config.get("data", {}))
|
||||
return cls.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
@classmethod
|
||||
def _is_variable(cls, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
@ -8,27 +9,44 @@ class AnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
Answer Node Data.
|
||||
"""
|
||||
answer: str
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
class GenerateRouteChunk(BaseModel):
|
||||
"""
|
||||
Generate Route Chunk.
|
||||
"""
|
||||
type: str
|
||||
|
||||
class ChunkType(Enum):
|
||||
VAR = "var"
|
||||
TEXT = "text"
|
||||
|
||||
type: ChunkType = Field(..., description="generate route chunk type")
|
||||
|
||||
|
||||
class VarGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Var Generate Route Chunk.
|
||||
"""
|
||||
type: str = "var"
|
||||
value_selector: list[str]
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
||||
"""generate route chunk type"""
|
||||
value_selector: list[str] = Field(..., description="value selector")
|
||||
|
||||
|
||||
class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Text Generate Route Chunk.
|
||||
"""
|
||||
type: str = "text"
|
||||
text: str
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
|
||||
"""generate route chunk type"""
|
||||
text: str = Field(..., description="text")
|
||||
|
||||
|
||||
class AnswerStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
ChatflowStreamGenerateRoute entity
|
||||
"""
|
||||
answer_node_id: str = Field(..., description="answer node ID")
|
||||
generate_route: list[GenerateRouteChunk] = Field(..., description="answer stream generate route")
|
||||
current_route_position: int = 0
|
||||
"""current generate route position"""
|
||||
|
||||
Reference in New Issue
Block a user