fix workflow api return

This commit is contained in:
takatost
2024-03-04 17:23:27 +08:00
parent c3eac450ce
commit 0cc0065f8c
12 changed files with 434 additions and 85 deletions

View File

@ -30,3 +30,12 @@ class NodeType(Enum):
if node_type.value == value:
return node_type
raise ValueError(f'invalid node type value {value}')
class SystemVariable(Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION = 'conversation'

View File

@ -0,0 +1,82 @@
from enum import Enum
from typing import Optional, Union, Any
from core.workflow.entities.node_entities import SystemVariable
VariableValue = Union[str, int, float, dict, list]
class ValueType(Enum):
"""
Value Type Enum
"""
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY = "array"
FILE = "file"
class VariablePool:
variables_mapping = {}
def __init__(self, system_variables: dict[SystemVariable, Any]) -> None:
# system variables
# for example:
# {
# 'query': 'abc',
# 'files': []
# }
for system_variable, value in system_variables.items():
self.append_variable('sys', [system_variable.value], value)
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
"""
Append variable
:param node_id: node id
:param variable_key_list: variable key list, like: ['result', 'text']
:param value: value
:return:
"""
if node_id not in self.variables_mapping:
self.variables_mapping[node_id] = {}
variable_key_list_hash = hash(tuple(variable_key_list))
self.variables_mapping[node_id][variable_key_list_hash] = value
def get_variable_value(self, variable_selector: list[str],
target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
"""
Get variable
:param variable_selector: include node_id and variables
:param target_value_type: target value type
:return:
"""
if len(variable_selector) < 2:
raise ValueError('Invalid value selector')
node_id = variable_selector[0]
if node_id not in self.variables_mapping:
return None
# fetch variable keys, pop node_id
variable_key_list = variable_selector[1:]
variable_key_list_hash = hash(tuple(variable_key_list))
value = self.variables_mapping[node_id].get(variable_key_list_hash)
if target_value_type:
if target_value_type == ValueType.STRING:
return str(value)
elif target_value_type == ValueType.NUMBER:
return int(value)
elif target_value_type == ValueType.OBJECT:
if not isinstance(value, dict):
raise ValueError('Invalid value type: object')
elif target_value_type == ValueType.ARRAY:
if not isinstance(value, list):
raise ValueError('Invalid value type: array')
return value

View File

@ -1,7 +1,44 @@
from abc import abstractmethod
from typing import Optional
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
class BaseNode:
_node_type: NodeType
def __int__(self, node_config: dict) -> None:
self._node_config = node_config
@abstractmethod
def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
return self._run(
variable_pool=variable_pool,
run_args=run_args
)
@abstractmethod
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
raise NotImplementedError
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""

View File

@ -1,5 +1,6 @@
from typing import Optional
from typing import Optional, Union, Generator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode
@ -14,7 +15,8 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
from extensions.ext_database import db
from models.model import App
from models.account import Account
from models.model import App, EndUser, Conversation
from models.workflow import Workflow
node_classes = {
@ -56,13 +58,20 @@ class WorkflowEngineManager:
return None
# fetch published workflow by workflow_id
return self.get_workflow(app_model, app_model.workflow_id)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id
Workflow.id == workflow_id
).first()
# return published workflow
# return workflow
return workflow
def get_default_configs(self) -> list[dict]:
@ -96,3 +105,20 @@ class WorkflowEngineManager:
return None
return default_config
def run_workflow(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> Generator:
"""
Run workflow
:param app_model: App instance
:param workflow: Workflow instance
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:return:
"""
# TODO
pass