mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
answer stream output support
This commit is contained in:
@ -4,7 +4,12 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
GenerateRouteChunk,
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
@ -22,6 +27,40 @@ class AnswerNode(BaseNode):
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
|
||||
# generate routes
|
||||
generate_routes = self.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
answer = []
|
||||
for part in generate_routes:
|
||||
if part.type == "var":
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=value_selector,
|
||||
target_value_type=ValueType.STRING
|
||||
)
|
||||
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": value
|
||||
}
|
||||
# TODO File
|
||||
else:
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": part.text
|
||||
}
|
||||
|
||||
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
|
||||
answer[-1]["text"] += answer_part["text"]
|
||||
else:
|
||||
answer.append(answer_part)
|
||||
|
||||
if len(answer) == 1 and answer[0]["type"] == "text":
|
||||
answer = answer[0]["text"]
|
||||
|
||||
# re-fetch variable values
|
||||
variable_values = {}
|
||||
for variable_selector in node_data.variables:
|
||||
value = variable_pool.get_variable_value(
|
||||
@ -31,7 +70,39 @@ class AnswerNode(BaseNode):
|
||||
|
||||
variable_values[variable_selector.variable] = value
|
||||
|
||||
variable_keys = list(variable_values.keys())
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variable_values,
|
||||
outputs={
|
||||
"answer": answer
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route selectors
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
|
||||
return cls.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route from node data
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
value_selector_mapping = {
|
||||
variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
variable_keys = list(value_selector_mapping.keys())
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(node_data.answer)
|
||||
@ -44,46 +115,24 @@ class AnswerNode(BaseNode):
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
|
||||
split_template = [
|
||||
{
|
||||
"type": "var" if self._is_variable(part, variable_keys) else "text",
|
||||
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
|
||||
}
|
||||
for part in template.split('Ω') if part
|
||||
]
|
||||
generate_routes = []
|
||||
for part in template.split('Ω'):
|
||||
if part:
|
||||
if cls._is_variable(part, variable_keys):
|
||||
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
||||
value_selector = value_selector_mapping[var_key]
|
||||
generate_routes.append(VarGenerateRouteChunk(
|
||||
value_selector=value_selector
|
||||
))
|
||||
else:
|
||||
generate_routes.append(TextGenerateRouteChunk(
|
||||
text=part
|
||||
))
|
||||
|
||||
answer = []
|
||||
for part in split_template:
|
||||
if part["type"] == "var":
|
||||
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": value
|
||||
}
|
||||
# TODO File
|
||||
else:
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": part["value"]
|
||||
}
|
||||
return generate_routes
|
||||
|
||||
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
|
||||
answer[-1]["text"] += answer_part["text"]
|
||||
else:
|
||||
answer.append(answer_part)
|
||||
|
||||
if len(answer) == 1 and answer[0]["type"] == "text":
|
||||
answer = answer[0]["text"]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variable_values,
|
||||
outputs={
|
||||
"answer": answer
|
||||
}
|
||||
)
|
||||
|
||||
def _is_variable(self, part, variable_keys):
|
||||
@classmethod
|
||||
def _is_variable(cls, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
variables: list[VariableSelector] = []
|
||||
answer: str
|
||||
|
||||
|
||||
class GenerateRouteChunk(BaseModel):
|
||||
"""
|
||||
Generate Route Chunk.
|
||||
"""
|
||||
type: str
|
||||
|
||||
|
||||
class VarGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Var Generate Route Chunk.
|
||||
"""
|
||||
type: str = "var"
|
||||
value_selector: list[str]
|
||||
|
||||
|
||||
class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Text Generate Route Chunk.
|
||||
"""
|
||||
type: str = "text"
|
||||
text: str
|
||||
|
||||
@ -86,17 +86,22 @@ class BaseNode(ABC):
|
||||
self.node_run_result = result
|
||||
return result
|
||||
|
||||
def publish_text_chunk(self, text: str) -> None:
|
||||
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
:param text: chunk text
|
||||
:param value_selector: value selector
|
||||
:return:
|
||||
"""
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_node_text_chunk(
|
||||
node_id=self.node_id,
|
||||
text=text
|
||||
text=text,
|
||||
metadata={
|
||||
"node_type": self.node_type,
|
||||
"value_selector": value_selector
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -169,7 +169,7 @@ class LLMNode(BaseNode):
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
self.publish_text_chunk(text=text)
|
||||
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
Reference in New Issue
Block a user