add answer output parse

This commit is contained in:
takatost
2024-03-13 23:00:28 +08:00
parent fd8fe15d28
commit fcd470fcac
9 changed files with 120 additions and 138 deletions

View File

@ -1,4 +1,3 @@
import time
from typing import cast
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@ -32,14 +31,49 @@ class AnswerNode(BaseNode):
variable_values[variable_selector.variable] = value
variable_keys = list(variable_values.keys())
# format answer template
template_parser = PromptTemplateParser(node_data.answer)
answer = template_parser.format(variable_values)
template_variable_keys = template_parser.variable_keys
# publish answer as stream
for word in answer:
self.publish_text_chunk(word)
time.sleep(10) # TODO for debug
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
split_template = [
{
"type": "var" if self._is_variable(part, variable_keys) else "text",
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
}
for part in template.split('Ω') if part
]
answer = []
for part in split_template:
if part["type"] == "var":
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
answer_part = {
"type": "text",
"text": value
}
# TODO File
else:
answer_part = {
"type": "text",
"text": part["value"]
}
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
answer[-1]["text"] += answer_part["text"]
else:
answer.append(answer_part)
if len(answer) == 1 and answer[0]["type"] == "text":
answer = answer[0]["text"]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -49,6 +83,10 @@ class AnswerNode(BaseNode):
}
)
def _is_variable(self, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""

View File

@ -6,7 +6,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class UserFrom(Enum):
@ -80,16 +79,9 @@ class BaseNode(ABC):
:param variable_pool: variable pool
:return:
"""
try:
result = self._run(
variable_pool=variable_pool
)
except Exception as e:
# process unhandled exception
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
result = self._run(
variable_pool=variable_pool
)
self.node_run_result = result
return result

View File

@ -2,9 +2,9 @@ from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
from core.workflow.nodes.end.entities import EndNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -20,34 +20,14 @@ class EndNode(BaseNode):
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
outputs_config = node_data.outputs
output_variables = node_data.outputs
outputs = None
if outputs_config:
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
plain_text_selector = outputs_config.plain_text_selector
if plain_text_selector:
outputs = {
'text': variable_pool.get_variable_value(
variable_selector=plain_text_selector,
target_value_type=ValueType.STRING
)
}
else:
outputs = {
'text': ''
}
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
structured_variables = outputs_config.structured_variables
if structured_variables:
outputs = {}
for variable_selector in structured_variables:
variable_value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = variable_value
else:
outputs = {}
outputs = {}
for variable_selector in output_variables:
variable_value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = variable_value
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@ -1,68 +1,9 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class EndNodeOutputType(Enum):
"""
END Node Output Types.
none, plain-text, structured
"""
NONE = 'none'
PLAIN_TEXT = 'plain-text'
STRUCTURED = 'structured'
@classmethod
def value_of(cls, value: str) -> 'OutputType':
"""
Get value of given output type.
:param value: output type value
:return: output type
"""
for output_type in cls:
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
class EndNodeDataOutputs(BaseModel):
"""
END Node Data Outputs.
"""
class OutputType(Enum):
"""
Output Types.
"""
NONE = 'none'
PLAIN_TEXT = 'plain-text'
STRUCTURED = 'structured'
@classmethod
def value_of(cls, value: str) -> 'OutputType':
"""
Get value of given output type.
:param value: output type value
:return: output type
"""
for output_type in cls:
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
type: OutputType = OutputType.NONE
plain_text_selector: Optional[list[str]] = None
structured_variables: Optional[list[VariableSelector]] = None
class EndNodeData(BaseNodeData):
"""
END Node Data.
"""
outputs: Optional[EndNodeDataOutputs] = None
outputs: list[VariableSelector]