feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
takatost
2024-09-10 15:23:16 +08:00
committed by GitHub
parent 5da0182800
commit dabfd74622
156 changed files with 11158 additions and 5605 deletions

View File

@ -4,12 +4,10 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -60,13 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -154,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True):
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -171,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
@ -191,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
conversation_id=None,
inputs={},
query='',
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={
"auto_generate_conversation_name": False
},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@ -211,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=conversation,
conversation=None,
stream=stream
)
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True):
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
@ -236,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
# db.session.refresh(conversation)
db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -248,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
})
@ -334,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context) -> None:
"""
@ -349,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var.set(val)
with flask_app.app_context():
try:
runner = AdvancedChatAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
# get message
message = self._get_message(message_id)
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
)
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:

View File

@ -1,49 +1,67 @@
import logging
import os
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import App, Message, Workflow
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(AppRunner):
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
AdvancedChat Application Runner
"""
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message,
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner):
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
query = application_generate_entity.query
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
self.conversation.dialogue_count += 1
conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)
def handle_input_moderation(
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner):
message_id=message_id,
)
except ModerationException as e:
self._stream_output(
queue_manager=queue_manager,
self._complete_with_stream_output(
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True
return False
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param queue_manager: application queue manager
:param app_generate_entity: application generate entity
"""
# annotation reply
@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner):
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
self._publish_event(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
)
self._stream_output(
queue_manager=queue_manager,
self._complete_with_stream_output(
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
def _complete_with_stream_output(self,
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:param stream: stream
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
self._publish_event(
QueueTextChunkEvent(
text=text
)
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
self._publish_event(
QueueStopEvent(stopped_by=stopped_by)
)

View File

@ -2,9 +2,8 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union
import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: AdvancedChatTaskState
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariableKey.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._task_state = WorkflowTaskState()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self):
@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if (message.event
and getattr(message.event, 'metadata', None)
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata
and message.event.execution_metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
event = message.event
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
# init workflow run
workflow_run = self._handle_workflow_run_start()
self._refetch_message()
self._message.workflow_run_id = workflow_run.id
db.session.commit()
@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception('Workflow run not initialized.')
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
# stream outputs when node finished
generator = self._generate_stream_outputs_when_node_finished()
if generator:
yield from generator
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if workflow_run:
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message()
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
else:
continue
if publisher:
publisher.publish(None)
# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self) -> None:
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._refetch_message()
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
**extras
)
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
"""
Generate stream outputs.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
self._task_state.current_stream_generate_state.current_route_position += 1
continue
route_chunk_node_id = value_selector[0]
if route_chunk_node_id == 'sys':
# system variable
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector
) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key) if outputs else None
else:
value = value.get(key)
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text:
self._task_state.answer += text
yield self._message_to_stream_response(text, self._message.id)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
return False
return True
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._output_moderation_handler.append_new_token(text)
return False
def _refetch_message(self) -> None:
"""
Refetch message.
:return:
"""
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if message:
self._message = message

View File

@ -1,203 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self._queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom):
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -347,7 +347,7 @@ class AppRunner:
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:

View File

@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> Generator[str, None, None]: ...
@overload
@ -50,16 +52,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> dict: ...
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
):
"""
Generate App response.
@ -71,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
@ -118,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id
)
def _generate(
self, app_model: App,
self, *,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
workflow_thread_pool_id: Optional[str] = None
) -> dict[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@ -137,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@ -148,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
'workflow_thread_pool_id': workflow_thread_pool_id
})
worker_thread.start()
@ -175,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True):
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -192,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
@ -211,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={
"auto_generate_conversation_name": False
},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context) -> None:
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
@ -244,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
with flask_app.app_context():
try:
# workflow app
runner = WorkflowAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
@ -271,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,

View File

@ -4,46 +4,61 @@ from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import Workflow
from models.workflow import WorkflowType
logger = logging.getLogger(__name__)
class WorkflowAppRunner:
class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def run(self) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
@ -53,80 +68,64 @@ class WorkflowAppRunner:
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
if not app_record.workflow_id:
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
generator = workflow_entry.run(
callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)

View File

@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
TextReplaceStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStreamGenerateNodes,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -52,8 +52,8 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState(
iteration_nested_node_ids=[]
)
self._stream_generate_nodes = self._get_stream_generate_nodes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._task_state = WorkflowTaskState()
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp())
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at)
)
)
@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(
workflow_run_id=self._task_state.workflow_run_id,
workflow_run_id=workflow_run_id,
stream_response=stream_response
)
@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
# init workflow run
workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception('Workflow run not initialized.')
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, trace_manager=trace_manager
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._text_replace_to_stream_response(event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# not save log for debugging
return
workflow_app_log = WorkflowAppLog(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
created_from=created_from.value,
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
created_by=self._user.id,
)
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
workflow_app_log.created_by = self._user.id
db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
return response
def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
"""
Text replace to stream response.
:param text: text
:return:
"""
return TextReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
"""
Get stream generate nodes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
end_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.END.value
]
# parse stream output node value selectors of end nodes
stream_generate_routes = {}
for node_config in end_node_configs:
# get generate route for stream output
end_node_id = node_config['id']
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
end_node_id=end_node_id,
stream_node_ids=generate_nodes
)
return stream_generate_routes
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get end start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
if node_id not in stream_node_ids:
continue
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
# get chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
if not route_chunk_node_execution:
continue
outputs = route_chunk_node_execution.outputs_dict
if not outputs:
continue
# get value from outputs
text = outputs.get('text')
if text:
self._task_state.answer += text
yield self._text_chunk_to_stream_response(text)
db.session.close()
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return False
if 'node_id' not in event.metadata:
return False
node_id = event.metadata.get('node_id')
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
return True
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@ -1,200 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager.publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager.publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
pass

View File

@ -0,0 +1,379 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager):
self.queue_manager = queue_manager
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init graph
graph = Graph.init(
graph_config=graph_config
)
if not graph:
raise ValueError('graph not found in workflow')
return graph
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')
graph_config = cast(dict[str, Any], graph_config)
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# filter nodes only in iteration
node_configs = [
node for node in graph_config.get('nodes', [])
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
]
graph_config['nodes'] = node_configs
node_ids = [node.get('id') for node in node_configs]
# filter edges only in iteration
edge_configs = [
edge for edge in graph_config.get('edges', [])
if (edge.get('source') is None or edge.get('source') in node_ids)
and (edge.get('target') is None or edge.get('target') in node_ids)
]
graph_config['edges'] = edge_configs
# init graph
graph = Graph.init(
graph_config=graph_config,
root_node_id=node_id
)
if not graph:
raise ValueError('graph not found in workflow')
# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get('id') == node_id:
iteration_node_config = node
break
if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph')
# Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
"""
Handle event
:param workflow_entry: workflow entry
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
)
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(
QueueWorkflowSucceededEvent(outputs=event.outputs)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(
QueueWorkflowFailedEvent(error=event.error)
)
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunSucceededEvent):
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {},
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error
else "Unknown error",
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error
)
)
elif isinstance(event, IterationRunStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata
)
)
elif isinstance(event, IterationRunNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@ -1,10 +1,24 @@
from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self.print_text("\n[on_workflow_run_started]", color='pink')
def on_event(
self,
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color='pink')
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color='green')
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
event=event
)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(
event=event
)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(
event=event
)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(
event=event
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(
event=event
)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(
event=event
)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(
event=event
)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(
event=event
)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self.print_text("\n[on_workflow_run_succeeded]", color='green')
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self.print_text("\n[on_workflow_run_failed]", color='red')
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
def on_workflow_node_execute_started(
self,
event: NodeRunStartedEvent
) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
self.print_text(f"Node ID: {node_id}", color='yellow')
self.print_text(f"Type: {node_type.value}", color='yellow')
self.print_text(f"Index: {node_run_index}", color='yellow')
if predecessor_node_id:
self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {event.node_id}", color='yellow')
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
self.print_text(f"Type: {event.node_type.value}", color='yellow')
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
def on_workflow_node_execute_succeeded(
self,
event: NodeRunSucceededEvent
) -> None:
"""
Workflow node execute succeeded
"""
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
self.print_text(f"Node ID: {node_id}", color='green')
self.print_text(f"Type: {node_type.value}", color='green')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
color='green')
route_node_state = event.route_node_state
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {event.node_id}", color='green')
self.print_text(f"Node Title: {event.node_data.title}", color='green')
self.print_text(f"Type: {event.node_type.value}", color='green')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='green')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='green')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='green')
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color='green')
def on_workflow_node_execute_failed(
self,
event: NodeRunFailedEvent
) -> None:
"""
Workflow node execute failed
"""
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
self.print_text(f"Node ID: {node_id}", color='red')
self.print_text(f"Type: {node_type.value}", color='red')
self.print_text(f"Error: {error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
route_node_state = event.route_node_state
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {event.node_id}", color='red')
self.print_text(f"Node Title: {event.node_data.title}", color='red')
self.print_text(f"Type: {event.node_type.value}", color='red')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='red')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='red')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='red')
def on_node_text_chunk(
self,
event: NodeRunStreamChunkEvent
) -> None:
"""
Publish text chunk
"""
if not self.current_node_id or self.current_node_id != node_id:
self.current_node_id = node_id
self.print_text('\n[on_node_text_chunk]')
self.print_text(f"Node ID: {node_id}")
self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text('\n[NodeRunStreamChunkEvent]')
self.print_text(f"Node ID: {route_node_state.node_id}")
self.print_text(text, color="pink", end="")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(
self,
event: ParallelBranchRunStartedEvent
) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
def on_workflow_parallel_completed(
self,
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = 'blue'
elif isinstance(event, ParallelBranchRunFailedEvent):
color = 'red'
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(
self,
event: IterationRunStartedEvent
) -> None:
"""
Publish iteration started
"""
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {node_id}", color='blue')
self.print_text("\n[IterationRunStartedEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[dict]) -> None:
def on_workflow_iteration_next(
self,
event: IterationRunNextEvent
) -> None:
"""
Publish iteration next
"""
self.print_text("\n[on_workflow_iteration_next]", color='blue')
self.print_text("\n[IterationRunNextEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text(f"Iteration Index: {event.index}", color='blue')
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
def on_workflow_iteration_completed(
self,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self.print_text("\n[on_workflow_event]", color='blue')
self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
def print_text(
self, text: str, color: Optional[str] = None, end: str = "\n"

View File

@ -1,3 +1,4 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional
@ -5,7 +6,8 @@ from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class QueueEvent(str, Enum):
@ -31,6 +33,9 @@ class QueueEvent(str, Enum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
ERROR = "error"
PING = "ping"
STOP = "stop"
@ -38,7 +43,7 @@ class QueueEvent(str, Enum):
class AppQueueEvent(BaseModel):
"""
QueueEvent entity
QueueEvent abstract entity
"""
event: QueueEvent
@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel):
class QueueLLMChunkEvent(AppQueueEvent):
"""
QueueLLMChunkEvent entity
Only for basic mode apps
"""
event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent):
QueueIterationStartEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_START
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: dict = None
inputs: Optional[dict[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None
metadata: Optional[dict[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
outputs: dict
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
error: Optional[str] = None
class QueueTextChunkEvent(AppQueueEvent):
"""
@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
metadata: Optional[dict] = None
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueAgentMessageEvent(AppQueueEvent):
@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueAnnotationReplyEvent(AppQueueEvent):
@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
QueueWorkflowStartedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSucceededEvent(AppQueueEvent):
@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
QueueWorkflowSucceededEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
class QueueWorkflowFailedEvent(AppQueueEvent):
@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_STARTED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
class QueueNodeSucceededEvent(AppQueueEvent):
@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
execution_metadata: Optional[dict] = None
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict] = None
outputs: Optional[dict] = None
process_data: Optional[dict] = None
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
error: str
@ -274,10 +360,23 @@ class QueueStopEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
def get_stop_reason(self) -> str:
"""
To stop reason
"""
reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
}
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
class QueueMessage(BaseModel):
"""
QueueMessage entity
QueueMessage abstract entity
"""
task_id: str
app_mode: str
@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage):
WorkflowQueueMessage entity
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
error: str

View File

@ -3,40 +3,11 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus
class WorkflowStreamGenerateNodes(BaseModel):
"""
WorkflowStreamGenerateNodes entity
"""
end_node_id: str
stream_node_ids: list[str]
class ChatflowStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
current_route_position: int = 0
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
node_type: NodeType
start_at: float
class TaskState(BaseModel):
"""
TaskState entity
@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState):
"""
answer: str = ""
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
AdvancedChatTaskState entity
"""
usage: LLMUsage
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
class StreamEvent(Enum):
"""
@ -97,6 +47,8 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -267,6 +219,11 @@ class NodeStartStreamResponse(StreamResponse):
inputs: Optional[dict] = None
created_at: int
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -286,7 +243,12 @@ class NodeStartStreamResponse(StreamResponse):
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"created_at": self.data.created_at,
"extras": {}
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
}
}
@ -316,6 +278,11 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Optional[list[dict]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -342,9 +309,58 @@ class NodeFinishStreamResponse(StreamResponse):
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": []
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
}
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
class IterationNodeStartStreamResponse(StreamResponse):
@ -364,6 +380,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -387,6 +405,8 @@ class IterationNodeNextStreamResponse(StreamResponse):
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -408,8 +428,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
title: str
outputs: Optional[dict] = None
created_at: int
extras: dict = None
inputs: dict = None
extras: Optional[dict] = None
inputs: Optional[dict] = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
@ -417,6 +437,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
"""
WorkflowAppStreamResponse entity
"""
workflow_run_id: str
workflow_run_id: Optional[str] = None
class AppBlockingResponse(BaseModel):
@ -562,25 +584,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity
"""
parent_iteration_id: Optional[str] = None
iteration_id: str
current_index: int
iteration_steps_boundary: list[int] = None
node_execution_id: str
started_at: float
inputs: dict = None
total_tokens: int = 0
node_data: BaseNodeData
current_iterations: dict[str, Data] = None

View File

@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline:
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
if message:
message = db.session.query(Message).filter(Message.id == message.id).first()
err_desc = self._error_to_desc(err)
message.status = 'error'
message.error = err_desc
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
db.session.commit()
if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = 'error'
refetch_message.error = err_desc
db.session.commit()
return err
def _error_to_desc(cls, e: Exception) -> str:
def _error_to_desc(self, e: Exception) -> str:
"""
Error to desc.
:param e: exception

View File

@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
@ -16,11 +15,11 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.tool_file_manager import ToolFileManager
@ -36,7 +35,7 @@ class MessageCycleManage:
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
"""
@ -45,6 +44,9 @@ class MessageCycleManage:
:param query: query
:return: thread
"""
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
@ -52,7 +54,7 @@ class MessageCycleManage:
if auto_generate_conversation_name and is_first_message:
# start generate thread
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'conversation_id': conversation.id,
'query': query
})
@ -75,6 +77,9 @@ class MessageCycleManage:
.first()
)
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
app_model = conversation.app
if not app_model:
@ -121,34 +126,13 @@ class MessageCycleManage:
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources
def _get_response_metadata(self) -> dict:
"""
Get response metadata by invoke from.
:return:
"""
metadata = {}
# show_retrieve_source
if 'retriever_resources' in self._task_state.metadata:
metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
# show annotation reply
if 'annotation_reply' in self._task_state.metadata:
metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
# show usage
if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
metadata['usage'] = self._task_state.metadata['usage']
return metadata
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event
:return:
"""
message_file: MessageFile = (
message_file = (
db.session.query(MessageFile)
.filter(MessageFile.id == event.message_file_id)
.first()

View File

@ -1,33 +1,41 @@
import json
import time
from datetime import datetime, timezone
from typing import Optional, Union, cast
from typing import Any, Optional, Union, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
NodeExecutionInfo,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -41,54 +49,56 @@ from models.workflow import (
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
from services.workflow_service import WorkflowService
class WorkflowCycleManage(WorkflowIterationCycleManage):
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:return:
"""
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
.scalar() or 0
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = (
db.session.query(db.func.max(WorkflowRun.sequence_number))
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
.filter(WorkflowRun.app_id == self._workflow.app_id)
.scalar()
or 0
)
new_sequence_number = max_sequence + 1
inputs = {**user_inputs}
for key, value in (system_inputs or {}).items():
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == 'conversation':
continue
inputs[f'sys.{key.value}'] = value
inputs = WorkflowEngineManager.handle_special_values(inputs)
inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from= (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
)
# init workflow run
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
sequence_number=new_sequence_number,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=triggered_from.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps(inputs),
status=WorkflowRunStatus.RUNNING.value,
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
workflow_run = WorkflowRun()
workflow_run.tenant_id = self._workflow.tenant_id
workflow_run.app_id = self._workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = self._workflow.id
workflow_run.type = self._workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = self._workflow.version
workflow_run.graph = self._workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING.value
workflow_run.created_by_role = (
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
)
workflow_run.created_by = self._user.id
db.session.add(workflow_run)
db.session.commit()
@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_run
def _workflow_run_success(
self, workflow_run: WorkflowRun,
def _handle_workflow_run_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
if trace_manager:
trace_manager.add_trace_task(
@ -135,34 +149,58 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
)
)
db.session.close()
return workflow_run
def _workflow_run_failed(
self, workflow_run: WorkflowRun,
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
).all()
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
@ -178,39 +216,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_run
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
"""
Init workflow node execution from workflow run
:param workflow_run: workflow run
:param node_id: node id
:param node_type: node type
:param node_title: node title
:param node_run_index: run index
:param predecessor_node_id: predecessor node id if exists
:return:
"""
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.add(workflow_node_execution)
db.session.commit()
@ -219,33 +242,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param inputs: inputs
:param process_data: process data
:param outputs: outputs
:param execution_metadata: execution metadata
:param event: queue node succeeded event
:return:
"""
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_node_execution)
@ -253,33 +269,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None
) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:param event: queue node failed event
:return:
"""
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.error = event.error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_node_execution)
@ -287,8 +294,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_start_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowStartStreamResponse:
"""
Workflow start to stream response.
:param task_id: task id
@ -302,13 +314,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict,
created_at=int(workflow_run.created_at.timestamp())
)
inputs=workflow_run.inputs_dict or {},
created_at=int(workflow_run.created_at.timestamp()),
),
)
def _workflow_finish_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
def _workflow_finish_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowFinishStreamResponse:
"""
Workflow finish to stream response.
:param task_id: task id
@ -320,16 +333,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
created_by_account = workflow_run.created_by_account
if created_by_account:
created_by = {
"id": created_by_account.id,
"name": created_by_account.name,
"email": created_by_account.email,
'id': created_by_account.id,
'name': created_by_account.name,
'email': created_by_account.email,
}
else:
created_by_end_user = workflow_run.created_by_end_user
if created_by_end_user:
created_by = {
"id": created_by_end_user.id,
"user": created_by_end_user.session_id,
'id': created_by_end_user.id,
'user': created_by_end_user.session_id,
}
return WorkflowFinishStreamResponse(
@ -348,14 +361,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
)
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
),
)
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution) \
-> NodeStartStreamResponse:
def _workflow_node_start_to_stream_response(
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
) -> Optional[NodeStartStreamResponse]:
"""
Workflow node start to stream response.
:param event: queue node started event
@ -363,6 +375,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@ -374,8 +389,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
created_at=int(workflow_node_execution.created_at.timestamp())
)
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
),
)
# extras logic
@ -384,19 +404,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
response.data.extras['icon'] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id
provider_id=node_data.provider_id,
)
return response
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
-> NodeFinishStreamResponse:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@ -416,181 +444,155 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
"""
Workflow parallel branch start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run started event
:return:
"""
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
created_at=int(time.time()),
)
)
def _workflow_parallel_branch_finished_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
) -> ParallelBranchFinishedStreamResponse:
"""
Workflow parallel branch finished to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run succeeded or failed event
:return:
"""
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
)
)
def _handle_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
workflow=self._workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN,
user=self._user,
user_inputs=self._application_generate_entity.inputs,
system_inputs=self._workflow_system_variables
def _workflow_iteration_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
"""
Workflow iteration start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration start event
:return:
"""
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
predecessor_node_id=event.predecessor_node_id
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
"""
Workflow iteration next to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration next event
:return:
"""
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=event.node_type,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
db.session.close()
return workflow_node_execution
def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
current_iteration_data = None
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if data.parent_iteration_id == None:
current_iteration_data = data
break
if current_iteration_data:
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
"""
Workflow iteration completed to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration completed event
:return:
"""
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
execution_metadata=execution_metadata
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
usage_dict = outputs.get('usage', {})
self._task_state.metadata['usage'] = usage_dict
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=execution_metadata
)
db.session.close()
return workflow_node_execution
def _handle_workflow_finished(
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
if conversation_id is None:
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.',
conversation_id=conversation_id,
trace_manager=trace_manager
)
latest_node_execution_info = self._task_state.latest_node_execution_info
if latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
if (workflow_node_execution
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=latest_node_execution_info.start_at,
error='Workflow stopped.'
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=conversation_id,
trace_manager=trace_manager
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs,
conversation_id=conversation_id,
trace_manager=trace_manager
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
"""
@ -647,3 +649,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == workflow_run_id).first()
if not workflow_run:
raise Exception(f'Workflow run not found: {workflow_run_id}')
return workflow_run
def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
"""
Refetch workflow node execution
:param node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
WorkflowNodeExecution.workflow_id == self._workflow.id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.node_execution_id == node_execution_id,
)
.first()
)
if not workflow_node_execution:
raise Exception(f'Workflow node execution not found: {node_execution_id}')
return workflow_node_execution

View File

@ -1,16 +0,0 @@
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
class WorkflowCycleStateManager:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariableKey, Any]

View File

@ -1,290 +0,0 @@
import json
import time
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
WorkflowIterationState,
)
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
_iteration_state: WorkflowIterationState = None
def _init_iteration_state(self) -> WorkflowIterationState:
if not self._iteration_state:
self._iteration_state = WorkflowIterationState(
current_iterations={}
)
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
"""
Handle iteration to stream response
:param task_id: task id
:param event: iteration event
:return:
"""
if isinstance(event, QueueIterationStartEvent):
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs,
metadata=event.metadata
)
)
elif isinstance(event, QueueIterationNextEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={}
)
)
elif isinstance(event, QueueIterationCompletedEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
def _init_iteration_execution_from_workflow_run(self,
workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
inputs=json.dumps(inputs) if inputs else None,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
execution_metadata=json.dumps({
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
}),
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
if isinstance(event, QueueIterationStartEvent):
return self._handle_iteration_started(event)
elif isinstance(event, QueueIterationNextEvent):
return self._handle_iteration_next(event)
elif isinstance(event, QueueIterationCompletedEvent):
return self._handle_iteration_completed(event)
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
self._init_iteration_state()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=NodeType.ITERATION,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
parent_iteration_id=None,
iteration_id=event.node_id,
current_index=0,
iteration_steps_boundary=[],
node_execution_id=workflow_node_execution.id,
started_at=time.perf_counter(),
inputs=event.inputs,
total_tokens=0,
node_data=event.node_data
)
db.session.close()
return workflow_node_execution
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
current_iteration.current_index = event.index
current_iteration.iteration_steps_boundary.append(event.node_run_index)
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['current_index'] = event.index
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
db.session.close()
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
# remove current iteration
self._iteration_state.current_iterations.pop(event.node_id, None)
# set latest node execution info
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.latest_node_execution_info = latest_node_execution_info
db.session.close()
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
"""
Handle iteration exception
"""
if not self._iteration_state or not self._iteration_state.current_iterations:
return
for node_id, current_iteration in self._iteration_state.current_iterations.items():
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
db.session.commit()
db.session.close()
yield IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=node_id,
node_id=node_id,
node_type=NodeType.ITERATION.value,
title=current_iteration.node_data.title,
outputs={},
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.FAILED,
error=error,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)