mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
add answer output parse
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
@ -32,14 +31,49 @@ class AnswerNode(BaseNode):
|
||||
|
||||
variable_values[variable_selector.variable] = value
|
||||
|
||||
variable_keys = list(variable_values.keys())
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(node_data.answer)
|
||||
answer = template_parser.format(variable_values)
|
||||
template_variable_keys = template_parser.variable_keys
|
||||
|
||||
# publish answer as stream
|
||||
for word in answer:
|
||||
self.publish_text_chunk(word)
|
||||
time.sleep(10) # TODO for debug
|
||||
# Take the intersection of variable_keys and template_variable_keys
|
||||
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
|
||||
split_template = [
|
||||
{
|
||||
"type": "var" if self._is_variable(part, variable_keys) else "text",
|
||||
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
|
||||
}
|
||||
for part in template.split('Ω') if part
|
||||
]
|
||||
|
||||
answer = []
|
||||
for part in split_template:
|
||||
if part["type"] == "var":
|
||||
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": value
|
||||
}
|
||||
# TODO File
|
||||
else:
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": part["value"]
|
||||
}
|
||||
|
||||
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
|
||||
answer[-1]["text"] += answer_part["text"]
|
||||
else:
|
||||
answer.append(answer_part)
|
||||
|
||||
if len(answer) == 1 and answer[0]["type"] == "text":
|
||||
answer = answer[0]["text"]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -49,6 +83,10 @@ class AnswerNode(BaseNode):
|
||||
}
|
||||
)
|
||||
|
||||
def _is_variable(self, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
|
||||
@ -6,7 +6,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
@ -80,16 +79,9 @@ class BaseNode(ABC):
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
except Exception as e:
|
||||
# process unhandled exception
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
|
||||
self.node_run_result = result
|
||||
return result
|
||||
|
||||
@ -2,9 +2,9 @@ from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -20,34 +20,14 @@ class EndNode(BaseNode):
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
outputs_config = node_data.outputs
|
||||
output_variables = node_data.outputs
|
||||
|
||||
outputs = None
|
||||
if outputs_config:
|
||||
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
|
||||
plain_text_selector = outputs_config.plain_text_selector
|
||||
if plain_text_selector:
|
||||
outputs = {
|
||||
'text': variable_pool.get_variable_value(
|
||||
variable_selector=plain_text_selector,
|
||||
target_value_type=ValueType.STRING
|
||||
)
|
||||
}
|
||||
else:
|
||||
outputs = {
|
||||
'text': ''
|
||||
}
|
||||
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
|
||||
structured_variables = outputs_config.structured_variables
|
||||
if structured_variables:
|
||||
outputs = {}
|
||||
for variable_selector in structured_variables:
|
||||
variable_value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
outputs[variable_selector.variable] = variable_value
|
||||
else:
|
||||
outputs = {}
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
variable_value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
outputs[variable_selector.variable] = variable_value
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
||||
@ -1,68 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class EndNodeOutputType(Enum):
|
||||
"""
|
||||
END Node Output Types.
|
||||
|
||||
none, plain-text, structured
|
||||
"""
|
||||
NONE = 'none'
|
||||
PLAIN_TEXT = 'plain-text'
|
||||
STRUCTURED = 'structured'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'OutputType':
|
||||
"""
|
||||
Get value of given output type.
|
||||
|
||||
:param value: output type value
|
||||
:return: output type
|
||||
"""
|
||||
for output_type in cls:
|
||||
if output_type.value == value:
|
||||
return output_type
|
||||
raise ValueError(f'invalid output type value {value}')
|
||||
|
||||
|
||||
class EndNodeDataOutputs(BaseModel):
|
||||
"""
|
||||
END Node Data Outputs.
|
||||
"""
|
||||
class OutputType(Enum):
|
||||
"""
|
||||
Output Types.
|
||||
"""
|
||||
NONE = 'none'
|
||||
PLAIN_TEXT = 'plain-text'
|
||||
STRUCTURED = 'structured'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'OutputType':
|
||||
"""
|
||||
Get value of given output type.
|
||||
|
||||
:param value: output type value
|
||||
:return: output type
|
||||
"""
|
||||
for output_type in cls:
|
||||
if output_type.value == value:
|
||||
return output_type
|
||||
raise ValueError(f'invalid output type value {value}')
|
||||
|
||||
type: OutputType = OutputType.NONE
|
||||
plain_text_selector: Optional[list[str]] = None
|
||||
structured_variables: Optional[list[VariableSelector]] = None
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
"""
|
||||
END Node Data.
|
||||
"""
|
||||
outputs: Optional[EndNodeDataOutputs] = None
|
||||
outputs: list[VariableSelector]
|
||||
|
||||
Reference in New Issue
Block a user