mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
fix workflow api return
This commit is contained in:
@ -30,3 +30,12 @@ class NodeType(Enum):
|
||||
if node_type.value == value:
|
||||
return node_type
|
||||
raise ValueError(f'invalid node type value {value}')
|
||||
|
||||
|
||||
class SystemVariable(Enum):
|
||||
"""
|
||||
System Variables.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
FILES = 'files'
|
||||
CONVERSATION = 'conversation'
|
||||
|
||||
82
api/core/workflow/entities/variable_pool.py
Normal file
82
api/core/workflow/entities/variable_pool.py
Normal file
@ -0,0 +1,82 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list]
|
||||
|
||||
|
||||
class ValueType(Enum):
|
||||
"""
|
||||
Value Type Enum
|
||||
"""
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY = "array"
|
||||
FILE = "file"
|
||||
|
||||
|
||||
class VariablePool:
|
||||
variables_mapping = {}
|
||||
|
||||
def __init__(self, system_variables: dict[SystemVariable, Any]) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
# {
|
||||
# 'query': 'abc',
|
||||
# 'files': []
|
||||
# }
|
||||
for system_variable, value in system_variables.items():
|
||||
self.append_variable('sys', [system_variable.value], value)
|
||||
|
||||
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
|
||||
"""
|
||||
Append variable
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list, like: ['result', 'text']
|
||||
:param value: value
|
||||
:return:
|
||||
"""
|
||||
if node_id not in self.variables_mapping:
|
||||
self.variables_mapping[node_id] = {}
|
||||
|
||||
variable_key_list_hash = hash(tuple(variable_key_list))
|
||||
|
||||
self.variables_mapping[node_id][variable_key_list_hash] = value
|
||||
|
||||
def get_variable_value(self, variable_selector: list[str],
|
||||
target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
|
||||
"""
|
||||
Get variable
|
||||
:param variable_selector: include node_id and variables
|
||||
:param target_value_type: target value type
|
||||
:return:
|
||||
"""
|
||||
if len(variable_selector) < 2:
|
||||
raise ValueError('Invalid value selector')
|
||||
|
||||
node_id = variable_selector[0]
|
||||
if node_id not in self.variables_mapping:
|
||||
return None
|
||||
|
||||
# fetch variable keys, pop node_id
|
||||
variable_key_list = variable_selector[1:]
|
||||
|
||||
variable_key_list_hash = hash(tuple(variable_key_list))
|
||||
|
||||
value = self.variables_mapping[node_id].get(variable_key_list_hash)
|
||||
|
||||
if target_value_type:
|
||||
if target_value_type == ValueType.STRING:
|
||||
return str(value)
|
||||
elif target_value_type == ValueType.NUMBER:
|
||||
return int(value)
|
||||
elif target_value_type == ValueType.OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('Invalid value type: object')
|
||||
elif target_value_type == ValueType.ARRAY:
|
||||
if not isinstance(value, list):
|
||||
raise ValueError('Invalid value type: array')
|
||||
|
||||
return value
|
||||
@ -1,7 +1,44 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class BaseNode:
|
||||
_node_type: NodeType
|
||||
|
||||
def __int__(self, node_config: dict) -> None:
|
||||
self._node_config = node_config
|
||||
|
||||
@abstractmethod
|
||||
def run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:param run_args: run args
|
||||
:return:
|
||||
"""
|
||||
if variable_pool is None and run_args is None:
|
||||
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
|
||||
|
||||
return self._run(
|
||||
variable_pool=variable_pool,
|
||||
run_args=run_args
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:param run_args: run args
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, Generator
|
||||
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode
|
||||
@ -14,7 +15,8 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser, Conversation
|
||||
from models.workflow import Workflow
|
||||
|
||||
node_classes = {
|
||||
@ -56,13 +58,20 @@ class WorkflowEngineManager:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
return self.get_workflow(app_model, app_model.workflow_id)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id
|
||||
Workflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
# return published workflow
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def get_default_configs(self) -> list[dict]:
|
||||
@ -96,3 +105,20 @@ class WorkflowEngineManager:
|
||||
return None
|
||||
|
||||
return default_config
|
||||
|
||||
def run_workflow(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None) -> Generator:
|
||||
"""
|
||||
Run workflow
|
||||
:param app_model: App instance
|
||||
:param workflow: Workflow instance
|
||||
:param user: account or end user
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:return:
|
||||
"""
|
||||
# TODO
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user