answer stream output support

This commit is contained in:
takatost
2024-03-14 20:49:53 +08:00
parent 1cfeb989f7
commit 12eb236364
10 changed files with 413 additions and 90 deletions

View File

@ -4,7 +4,12 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base_node import BaseNode
from models.workflow import WorkflowNodeExecutionStatus
@ -22,6 +27,40 @@ class AnswerNode(BaseNode):
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
# generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data)
answer = []
for part in generate_routes:
if part.type == "var":
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = variable_pool.get_variable_value(
variable_selector=value_selector,
target_value_type=ValueType.STRING
)
answer_part = {
"type": "text",
"text": value
}
# TODO File
else:
part = cast(TextGenerateRouteChunk, part)
answer_part = {
"type": "text",
"text": part.text
}
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
answer[-1]["text"] += answer_part["text"]
else:
answer.append(answer_part)
if len(answer) == 1 and answer[0]["type"] == "text":
answer = answer[0]["text"]
# re-fetch variable values
variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
@ -31,7 +70,39 @@ class AnswerNode(BaseNode):
variable_values[variable_selector.variable] = value
variable_keys = list(variable_values.keys())
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
outputs={
"answer": answer
}
)
@classmethod
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(cls._node_data_cls, node_data)
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(node_data.answer)
@ -44,46 +115,24 @@ class AnswerNode(BaseNode):
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
split_template = [
{
"type": "var" if self._is_variable(part, variable_keys) else "text",
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
}
for part in template.split('Ω') if part
]
generate_routes = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
answer = []
for part in split_template:
if part["type"] == "var":
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
answer_part = {
"type": "text",
"text": value
}
# TODO File
else:
answer_part = {
"type": "text",
"text": part["value"]
}
return generate_routes
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
answer[-1]["text"] += answer_part["text"]
else:
answer.append(answer_part)
if len(answer) == 1 and answer[0]["type"] == "text":
answer = answer[0]["text"]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
outputs={
"answer": answer
}
)
def _is_variable(self, part, variable_keys):
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys

View File

@ -1,3 +1,6 @@
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData):
"""
variables: list[VariableSelector] = []
answer: str
class GenerateRouteChunk(BaseModel):
"""
Generate Route Chunk.
"""
type: str
class VarGenerateRouteChunk(GenerateRouteChunk):
"""
Var Generate Route Chunk.
"""
type: str = "var"
value_selector: list[str]
class TextGenerateRouteChunk(GenerateRouteChunk):
"""
Text Generate Route Chunk.
"""
type: str = "text"
text: str

View File

@ -86,17 +86,22 @@ class BaseNode(ABC):
self.node_run_result = result
return result
def publish_text_chunk(self, text: str) -> None:
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
"""
Publish text chunk
:param text: chunk text
:param value_selector: value selector
:return:
"""
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text
text=text,
metadata={
"node_type": self.node_type,
"value_selector": value_selector
}
)
@classmethod

View File

@ -169,7 +169,7 @@ class LLMNode(BaseNode):
text = result.delta.message.content
full_text += text
self.publish_text_chunk(text=text)
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
if not model:
model = result.model