add start, end, direct answer node

This commit is contained in:
takatost
2024-03-07 15:43:55 +08:00
parent 3e54cb26be
commit 8684b172d2
14 changed files with 274 additions and 31 deletions

View File

@ -1,4 +1,4 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Optional
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
@ -8,7 +8,7 @@ from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class BaseNode:
class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType

View File

@ -1,5 +1,54 @@
import time
from typing import Optional, cast
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class DirectAnswerNode(BaseNode):
pass
_node_data_cls = DirectAnswerNodeData
node_type = NodeType.DIRECT_ANSWER
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
if variable_pool is None and run_args:
raise ValueError("Not support single step debug.")
variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector,
target_value_type=ValueType.STRING
)
variable_values[variable_selector.variable] = value
# format answer template
template_parser = PromptTemplateParser(node_data.answer)
answer = template_parser.format(variable_values)
# publish answer as stream
for word in answer:
self.publish_text_chunk(word)
time.sleep(0.01)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
output={
"answer": answer
}
)

View File

@ -0,0 +1,10 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class DirectAnswerNodeData(BaseNodeData):
"""
DirectAnswer Node Data.
"""
variables: list[VariableSelector] = []
answer: str

View File

@ -1,5 +1,60 @@
from typing import Optional, cast
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
from models.workflow import WorkflowNodeExecutionStatus
class EndNode(BaseNode):
pass
_node_data_cls = EndNodeData
node_type = NodeType.END
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
outputs_config = node_data.outputs
if variable_pool is not None:
outputs = None
if outputs_config:
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
plain_text_selector = outputs_config.plain_text_selector
if plain_text_selector:
outputs = {
'text': variable_pool.get_variable_value(
variable_selector=plain_text_selector,
target_value_type=ValueType.STRING
)
}
else:
outputs = {
'text': ''
}
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
structured_variables = outputs_config.structured_variables
if structured_variables:
outputs = {}
for variable_selector in structured_variables:
variable_value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = variable_value
else:
outputs = {}
else:
raise ValueError("Not support single step debug.")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=outputs,
outputs=outputs
)

View File

@ -1,4 +1,10 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class EndNodeOutputType(Enum):
@ -23,3 +29,40 @@ class EndNodeOutputType(Enum):
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
class EndNodeDataOutputs(BaseModel):
"""
END Node Data Outputs.
"""
class OutputType(Enum):
"""
Output Types.
"""
NONE = 'none'
PLAIN_TEXT = 'plain-text'
STRUCTURED = 'structured'
@classmethod
def value_of(cls, value: str) -> 'OutputType':
"""
Get value of given output type.
:param value: output type value
:return: output type
"""
for output_type in cls:
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
type: OutputType = OutputType.NONE
plain_text_selector: Optional[list[str]] = None
structured_variables: Optional[list[VariableSelector]] = None
class EndNodeData(BaseNodeData):
"""
END Node Data.
"""
outputs: Optional[EndNodeDataOutputs] = None

View File

@ -0,0 +1,8 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class LLMNodeData(BaseNodeData):
"""
LLM Node Data.
"""
pass

View File

@ -1,9 +1,28 @@
from typing import Optional
from typing import Optional, cast
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
node_type = NodeType.LLM
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
pass
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""

View File

@ -1,23 +1,9 @@
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class StartNodeData(BaseNodeData):
"""
- title (string) 节点标题
- desc (string) optional 节点描述
- type (string) 节点类型,固定为 start
- variables (array[object]) 表单变量列表
- type (string) 表单变量类型text-input, paragraph, select, number, files文件暂不支持自定义
- label (string) 控件展示标签名
- variable (string) 变量 key
- max_length (int) 最大长度,适用于 text-input 和 paragraph
- default (string) optional 默认值
- required (bool) optional是否必填默认 false
- hint (string) optional 提示信息
- options (array[string]) 选项值(仅 select 可用)
Start Node Data
"""
type: str = NodeType.START.value
variables: list[VariableEntity] = []

View File

@ -1,9 +1,11 @@
from typing import Optional
from typing import Optional, cast
from core.workflow.entities.node_entities import NodeType
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode):
@ -11,12 +13,58 @@ class StartNode(BaseNode):
node_type = NodeType.START
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
pass
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
variables = node_data.variables
# Get cleaned inputs
cleaned_inputs = self._get_cleaned_inputs(variables, run_args)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs,
outputs=cleaned_inputs
)
def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
for variable_config in variables:
variable = variable_config.variable
if variable not in user_inputs or not user_inputs[variable]:
if variable_config.required:
raise ValueError(f"Input form variable {variable} is required")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value:
if not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs