add few workflow run codes

This commit is contained in:
takatost
2024-03-04 23:34:23 +08:00
parent 836376c6c8
commit d51d456d80
13 changed files with 254 additions and 183 deletions

View File

@ -112,6 +112,7 @@ class VariableEntity(BaseModel):
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
class ExternalDataVariableEntity(BaseModel):

View File

@ -10,12 +10,14 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import QueueStopEvent
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.moderation.base import ModerationException
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@ -83,13 +85,16 @@ class AdvancedChatAppRunner(AppRunner):
result_generator = workflow_engine_manager.run_workflow(
app_model=app_record,
workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
user=user,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION: conversation.id,
}
},
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
)
for result in result_generator:

View File

View File

@ -1,157 +0,0 @@
import os
import sys
from typing import Any, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult
class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
self.color = color
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
for sub_messages in messages:
for sub_message in sub_messages:
print_text(str(sub_message) + "\n", color='blue')
def on_llm_start(
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
print_text(prompts[0] + "\n", color='blue')
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str(
response.llm_output) + "\n", color='blue')
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
def on_chain_start(
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink')
def on_tool_start(
self,
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
print_text("\n[on_tool_start] " + str(serialized), color='yellow')
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
try:
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
except ValueError:
thought = ''
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
print_text("\n[on_agent_action]\n" + log + "\n", color='green')
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color='yellow')
if observation_prefix:
print_text(f"\n{observation_prefix}")
print_text(output, color='yellow')
if llm_prefix:
print_text(f"\n{llm_prefix}")
print_text("\n")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow')
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run when agent ends."""
print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()

View File

@ -0,0 +1,45 @@
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from models.workflow import WorkflowRun, WorkflowNodeExecution
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager):
self._queue_manager = queue_manager
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
"""
Workflow run started
"""
self._queue_manager.publish_workflow_started(
workflow_run_id=workflow_run.id,
pub_from=PublishFrom.TASK_PIPELINE
)
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
"""
Workflow run finished
"""
self._queue_manager.publish_workflow_finished(
workflow_run_id=workflow_run.id,
pub_from=PublishFrom.TASK_PIPELINE
)
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish_node_started(
workflow_node_execution_id=workflow_node_execution.id,
pub_from=PublishFrom.TASK_PIPELINE
)
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
"""
Workflow node execute finished
"""
self._queue_manager.publish_node_finished(
workflow_node_execution_id=workflow_node_execution.id,
pub_from=PublishFrom.TASK_PIPELINE
)

View File

View File

@ -0,0 +1,33 @@
from abc import abstractmethod
from models.workflow import WorkflowRun, WorkflowNodeExecution
class BaseWorkflowCallback:
@abstractmethod
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
"""
Workflow run started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
"""
Workflow run finished
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
"""
Workflow node execute started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
"""
Workflow node execute finished
"""
raise NotImplementedError

View File

@ -0,0 +1,7 @@
from abc import ABC
from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
pass

View File

@ -1,32 +1,21 @@
from abc import abstractmethod
from typing import Optional
from typing import Optional, Type
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
class BaseNode:
_node_type: NodeType
_node_data_cls: Type[BaseNodeData]
def __int__(self, node_config: dict) -> None:
self._node_config = node_config
def __init__(self, config: dict) -> None:
self._node_id = config.get("id")
if not self._node_id:
raise ValueError("Node ID is required.")
@abstractmethod
def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
return self._run(
variable_pool=variable_pool,
run_args=run_args
)
self._node_data = self._node_data_cls(**config.get("data", {}))
@abstractmethod
def _run(self, variable_pool: Optional[VariablePool] = None,
@ -39,6 +28,22 @@ class BaseNode:
"""
raise NotImplementedError
def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
"""
Run node entry
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
return self._run(
variable_pool=variable_pool,
run_args=run_args
)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""

View File

@ -0,0 +1,27 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class StartNodeData(BaseNodeData):
"""
- title (string) 节点标题
- desc (string) optional 节点描述
- type (string) 节点类型,固定为 start
- variables (array[object]) 表单变量列表
- type (string) 表单变量类型text-input, paragraph, select, number, files文件暂不支持自定义
- label (string) 控件展示标签名
- variable (string) 变量 key
- max_length (int) 最大长度,适用于 text-input 和 paragraph
- default (string) optional 默认值
- required (bool) optional是否必填默认 false
- hint (string) optional 提示信息
- options (array[string]) 选项值(仅 select 可用)
"""
type: str = NodeType.START.value
title: str
desc: Optional[str] = None
variables: list[VariableEntity] = []

View File

@ -1,5 +1,22 @@
from typing import Type, Optional
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode):
pass
_node_type = NodeType.START
_node_data_cls = StartNodeData
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
pass

View File

@ -1,6 +1,8 @@
import json
from collections.abc import Generator
from typing import Optional, Union
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode
@ -17,7 +19,7 @@ from core.workflow.nodes.variable_assigner.variable_assigner_node import Variabl
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowRunTriggeredFrom, WorkflowRun, WorkflowRunStatus, CreatedByRole
node_classes = {
NodeType.START: StartNode,
@ -108,17 +110,103 @@ class WorkflowEngineManager:
def run_workflow(self, app_model: App,
workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> Generator:
system_inputs: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Generator:
"""
Run workflow
:param app_model: App instance
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
:return:
"""
# fetch workflow graph
graph = workflow.graph_dict
if not graph:
raise ValueError('workflow graph not found')
# init workflow run
workflow_run = self._init_workflow_run(
workflow=workflow,
triggered_from=triggered_from,
user=user,
user_inputs=user_inputs,
system_inputs=system_inputs
)
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started(workflow_run)
pass
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:return:
"""
# TODO
pass
try:
db.session.begin()
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
.for_update() \
.scalar() or 0
new_sequence_number = max_sequence + 1
# init workflow run
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
sequence_number=new_sequence_number,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=triggered_from.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps({**user_inputs, **system_inputs}),
status=WorkflowRunStatus.RUNNING.value,
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by_id=user.id
)
db.session.add(workflow_run)
db.session.commit()
except:
db.session.rollback()
raise
return workflow_run
def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
"""
Get entry node
:param graph: workflow graph
:return:
"""
nodes = graph.get('nodes')
if not nodes:
return None
for node_config in nodes.items():
if node_config.get('type') == NodeType.START.value:
return StartNode(config=node_config)
return None