mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
refactor workflow runner
This commit is contained in:
@ -11,7 +11,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
@ -123,11 +123,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
return self._handle_response(
|
||||
return self._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
@ -159,7 +161,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
@ -177,33 +179,40 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _handle_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
stream: bool = False) -> Union[dict, Generator]:
|
||||
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False) -> Union[dict, Generator]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param workflow: workflow
|
||||
:param queue_manager: queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process(stream=stream)
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise ConversationTaskStoppedException()
|
||||
raise GenerateTaskStoppedException()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
@ -8,16 +8,14 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
from models.model import App, Conversation, Message
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -46,7 +44,7 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
@ -74,19 +72,10 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
):
|
||||
return
|
||||
|
||||
# fetch user
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]:
|
||||
user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first()
|
||||
else:
|
||||
user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=user,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.QUERY: query,
|
||||
@ -99,6 +88,20 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
)]
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
|
||||
@ -4,9 +4,10 @@ import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
@ -16,25 +17,35 @@ from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus
|
||||
from models.account import Account
|
||||
from models.model import Conversation, EndUser, Message, MessageFile
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -47,41 +58,63 @@ class TaskState(BaseModel):
|
||||
answer: str = ""
|
||||
metadata: dict = {}
|
||||
usage: LLMUsage
|
||||
workflow_run_id: Optional[str] = None
|
||||
|
||||
workflow_run: Optional[WorkflowRun] = None
|
||||
start_at: Optional[float] = None
|
||||
total_tokens: int = 0
|
||||
total_steps: int = 0
|
||||
|
||||
current_node_execution: Optional[WorkflowNodeExecution] = None
|
||||
current_node_execution_start_at: Optional[float] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline:
|
||||
class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param workflow: workflow
|
||||
:param queue_manager: queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow = workflow
|
||||
self._queue_manager = queue_manager
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
self._user = user
|
||||
self._task_state = TaskState(
|
||||
usage=LLMUsage.empty_usage()
|
||||
)
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
|
||||
def process(self, stream: bool) -> Union[dict, Generator]:
|
||||
def process(self) -> Union[dict, Generator]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
if stream:
|
||||
if self._stream:
|
||||
return self._process_stream_response()
|
||||
else:
|
||||
return self._process_blocking_response()
|
||||
@ -112,22 +145,17 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
self._task_state.answer = annotation.content
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
self._task_state.workflow_run_id = event.workflow_run_id
|
||||
elif isinstance(event, QueueNodeFinishedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
|
||||
if workflow_node_execution.node_type == NodeType.LLM.value:
|
||||
outputs = workflow_node_execution.outputs_dict
|
||||
usage_dict = outputs.get('usage', {})
|
||||
self._task_state.metadata['usage'] = usage_dict
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueWorkflowFinishedEvent):
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
|
||||
self._on_workflow_start()
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
self._on_node_start(event)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
self._on_node_finished(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
self._on_workflow_finished(event)
|
||||
workflow_run = self._task_state.workflow_run
|
||||
|
||||
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
|
||||
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
@ -173,8 +201,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
self._task_state.workflow_run_id = workflow_run.id
|
||||
self._on_workflow_start()
|
||||
workflow_run = self._task_state.workflow_run
|
||||
|
||||
response = {
|
||||
'event': 'workflow_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
@ -188,7 +217,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
self._on_node_start(event)
|
||||
workflow_node_execution = self._task_state.current_node_execution
|
||||
|
||||
response = {
|
||||
'event': 'node_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
@ -204,8 +235,10 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeFinishedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
self._on_node_finished(event)
|
||||
workflow_node_execution = self._task_state.current_node_execution
|
||||
|
||||
if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
|
||||
if workflow_node_execution.node_type == NodeType.LLM.value:
|
||||
outputs = workflow_node_execution.outputs_dict
|
||||
@ -234,16 +267,11 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
|
||||
else:
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
self._on_workflow_finished(event)
|
||||
workflow_run = self._task_state.workflow_run
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs_dict
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
data = self._error_to_stream_response_data(self._handle_error(err_event))
|
||||
yield self._yield_response(data)
|
||||
@ -252,7 +280,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
workflow_run_response = {
|
||||
'event': 'workflow_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'workflow_run_id': workflow_run.id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
@ -390,6 +418,102 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
else:
|
||||
continue
|
||||
|
||||
def _on_workflow_start(self) -> None:
|
||||
self._task_state.start_at = time.perf_counter()
|
||||
|
||||
workflow_run = self._init_workflow_run(
|
||||
workflow=self._workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=self._user,
|
||||
user_inputs=self._application_generate_entity.inputs,
|
||||
system_inputs={
|
||||
SystemVariable.QUERY: self._message.query,
|
||||
SystemVariable.FILES: self._application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION: self._conversation.id,
|
||||
}
|
||||
)
|
||||
|
||||
self._task_state.workflow_run = workflow_run
|
||||
|
||||
def _on_node_start(self, event: QueueNodeStartedEvent) -> None:
|
||||
workflow_node_execution = self._init_node_execution_from_workflow_run(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
self._task_state.current_node_execution = workflow_node_execution
|
||||
self._task_state.current_node_execution_start_at = time.perf_counter()
|
||||
self._task_state.total_steps += 1
|
||||
|
||||
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._workflow_node_execution_success(
|
||||
workflow_node_execution=self._task_state.current_node_execution,
|
||||
start_at=self._task_state.current_node_execution_start_at,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs,
|
||||
execution_metadata=event.execution_metadata
|
||||
)
|
||||
|
||||
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
self._task_state.total_tokens += (
|
||||
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
|
||||
|
||||
if workflow_node_execution.node_type == NodeType.LLM.value:
|
||||
outputs = workflow_node_execution.outputs_dict
|
||||
usage_dict = outputs.get('usage', {})
|
||||
self._task_state.metadata['usage'] = usage_dict
|
||||
else:
|
||||
workflow_node_execution = self._workflow_node_execution_failed(
|
||||
workflow_node_execution=self._task_state.current_node_execution,
|
||||
start_at=self._task_state.current_node_execution_start_at,
|
||||
error=event.error
|
||||
)
|
||||
|
||||
self._task_state.current_node_execution = workflow_node_execution
|
||||
|
||||
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
error='Workflow stopped.'
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
error=event.error
|
||||
)
|
||||
else:
|
||||
workflow_run = self._workflow_run_success(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
outputs=self._task_state.current_node_execution.outputs
|
||||
if self._task_state.current_node_execution else None
|
||||
)
|
||||
|
||||
self._task_state.workflow_run = workflow_run
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs_dict
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
|
||||
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Get workflow run.
|
||||
@ -397,11 +521,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
:return:
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
|
||||
if workflow_run:
|
||||
# Because the workflow_run will be modified in the sub-thread,
|
||||
# and the first query in the main thread will cache the entity,
|
||||
# you need to expire the entity after the query
|
||||
db.session.expire(workflow_run)
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
|
||||
@ -412,11 +531,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
"""
|
||||
workflow_node_execution = (db.session.query(WorkflowNodeExecution)
|
||||
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
|
||||
if workflow_node_execution:
|
||||
# Because the workflow_node_execution will be modified in the sub-thread,
|
||||
# and the first query in the main thread will cache the entity,
|
||||
# you need to expire the entity after the query
|
||||
db.session.expire(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _save_message(self) -> None:
|
||||
@ -428,7 +542,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
self._message.answer = self._task_state.answer
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.workflow_run_id = self._task_state.workflow_run_id
|
||||
self._message.workflow_run_id = self._task_state.workflow_run.id
|
||||
|
||||
if self._task_state.metadata and self._task_state.metadata.get('usage'):
|
||||
usage = LLMUsage(**self._task_state.metadata['usage'])
|
||||
|
||||
@ -1,14 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
@ -17,39 +22,91 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
self._queue_manager = queue_manager
|
||||
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
|
||||
|
||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run finished
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute finished
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelCo
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
@ -177,7 +177,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
||||
@ -11,11 +11,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessage,
|
||||
QueueMessageEndEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -103,22 +100,16 @@ class AppQueueManager:
|
||||
:return:
|
||||
"""
|
||||
self._check_for_sqlalchemy_models(event.dict())
|
||||
|
||||
message = self.construct_queue_message(event)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowFinishedEvent):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise ConversationTaskStoppedException()
|
||||
self._publish(event, pub_from)
|
||||
|
||||
@abstractmethod
|
||||
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@ -182,5 +173,5 @@ class AppQueueManager:
|
||||
"that cause thread safety issues is not allowed.")
|
||||
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
class GenerateTaskStoppedException(Exception):
|
||||
pass
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
@ -177,7 +177,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
@ -166,7 +166,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager=queue_manager,
|
||||
message=message
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
||||
@ -7,7 +7,7 @@ from sqlalchemy import and_
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
|
||||
from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
@ -60,7 +60,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process(stream=stream)
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise ConversationTaskStoppedException()
|
||||
raise GenerateTaskStoppedException()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
MessageQueueMessage,
|
||||
QueueErrorEvent,
|
||||
QueueMessage,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
|
||||
|
||||
@ -28,3 +33,31 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
message = MessageQueueMessage(
|
||||
task_id=self._task_id,
|
||||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
@ -95,7 +95,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# return response or stream generator
|
||||
return self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
@ -117,7 +119,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
@ -136,19 +138,25 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
db.session.remove()
|
||||
|
||||
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False) -> Union[dict, Generator]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param workflow: workflow
|
||||
:param queue_manager: queue manager
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
@ -156,7 +164,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise ConversationTaskStoppedException()
|
||||
raise GenerateTaskStoppedException()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueMessage,
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
|
||||
@ -16,9 +20,27 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
|
||||
return WorkflowQueueMessage(
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
message = WorkflowQueueMessage(
|
||||
task_id=self._task_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
|
||||
@ -16,9 +15,8 @@ from core.moderation.input_moderation import InputModeration
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -43,7 +41,7 @@ class WorkflowAppRunner:
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
@ -59,19 +57,10 @@ class WorkflowAppRunner:
|
||||
):
|
||||
return
|
||||
|
||||
# fetch user
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]:
|
||||
user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first()
|
||||
else:
|
||||
user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=user,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files
|
||||
@ -82,6 +71,20 @@ class WorkflowAppRunner:
|
||||
)]
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: WorkflowAppGenerateEntity,
|
||||
|
||||
@ -4,28 +4,35 @@ import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -36,24 +43,44 @@ class TaskState(BaseModel):
|
||||
"""
|
||||
answer: str = ""
|
||||
metadata: dict = {}
|
||||
workflow_run_id: Optional[str] = None
|
||||
|
||||
workflow_run: Optional[WorkflowRun] = None
|
||||
start_at: Optional[float] = None
|
||||
total_tokens: int = 0
|
||||
total_steps: int = 0
|
||||
|
||||
current_node_execution: Optional[WorkflowNodeExecution] = None
|
||||
current_node_execution_start_at: Optional[float] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param workflow: workflow
|
||||
:param queue_manager: queue manager
|
||||
:param user: user
|
||||
:param stream: is stream
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow = workflow
|
||||
self._queue_manager = queue_manager
|
||||
self._user = user
|
||||
self._task_state = TaskState()
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
@ -79,17 +106,15 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
raise self._handle_error(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
|
||||
else:
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs_dict
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
self._on_workflow_start()
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
self._on_node_start(event)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
self._on_node_finished(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
self._on_workflow_finished(event)
|
||||
workflow_run = self._task_state.workflow_run
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
@ -100,10 +125,12 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
public_event=False
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log()
|
||||
|
||||
response = {
|
||||
'event': 'workflow_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'workflow_run_id': workflow_run.id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
@ -135,8 +162,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
self._task_state.workflow_run_id = event.workflow_run_id
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
self._on_workflow_start()
|
||||
workflow_run = self._task_state.workflow_run
|
||||
|
||||
response = {
|
||||
'event': 'workflow_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
@ -150,7 +178,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
self._on_node_start(event)
|
||||
workflow_node_execution = self._task_state.current_node_execution
|
||||
|
||||
response = {
|
||||
'event': 'node_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
@ -166,8 +196,10 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeFinishedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
self._on_node_finished(event)
|
||||
workflow_node_execution = self._task_state.current_node_execution
|
||||
|
||||
response = {
|
||||
'event': 'node_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
@ -190,20 +222,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
|
||||
else:
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs_dict
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
data = self._error_to_stream_response_data(self._handle_error(err_event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
self._on_workflow_finished(event)
|
||||
workflow_run = self._task_state.workflow_run
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
@ -219,7 +240,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
replace_response = {
|
||||
'event': 'text_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'workflow_run_id': self._task_state.workflow_run.id,
|
||||
'data': {
|
||||
'text': self._task_state.answer
|
||||
}
|
||||
@ -233,7 +254,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
workflow_run_response = {
|
||||
'event': 'workflow_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'workflow_run_id': workflow_run.id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
@ -244,7 +265,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
'total_tokens': workflow_run.total_tokens,
|
||||
'total_steps': workflow_run.total_steps,
|
||||
'created_at': int(workflow_run.created_at.timestamp()),
|
||||
'finished_at': int(workflow_run.finished_at.timestamp())
|
||||
'finished_at': int(workflow_run.finished_at.timestamp()) if workflow_run.finished_at else None
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,7 +300,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
response = {
|
||||
'event': 'text_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'workflow_run_id': self._task_state.workflow_run.id,
|
||||
'data': {
|
||||
'text': event.text
|
||||
}
|
||||
@ -291,6 +312,95 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
else:
|
||||
continue
|
||||
|
||||
def _on_workflow_start(self) -> None:
|
||||
self._task_state.start_at = time.perf_counter()
|
||||
|
||||
workflow_run = self._init_workflow_run(
|
||||
workflow=self._workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=self._user,
|
||||
user_inputs=self._application_generate_entity.inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: self._application_generate_entity.files
|
||||
}
|
||||
)
|
||||
|
||||
self._task_state.workflow_run = workflow_run
|
||||
|
||||
def _on_node_start(self, event: QueueNodeStartedEvent) -> None:
|
||||
workflow_node_execution = self._init_node_execution_from_workflow_run(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
self._task_state.current_node_execution = workflow_node_execution
|
||||
self._task_state.current_node_execution_start_at = time.perf_counter()
|
||||
self._task_state.total_steps += 1
|
||||
|
||||
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._workflow_node_execution_success(
|
||||
workflow_node_execution=self._task_state.current_node_execution,
|
||||
start_at=self._task_state.current_node_execution_start_at,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs,
|
||||
execution_metadata=event.execution_metadata
|
||||
)
|
||||
|
||||
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
self._task_state.total_tokens += (
|
||||
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
|
||||
else:
|
||||
workflow_node_execution = self._workflow_node_execution_failed(
|
||||
workflow_node_execution=self._task_state.current_node_execution,
|
||||
start_at=self._task_state.current_node_execution_start_at,
|
||||
error=event.error
|
||||
)
|
||||
|
||||
self._task_state.current_node_execution = workflow_node_execution
|
||||
|
||||
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
error='Workflow stopped.'
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
error=event.error
|
||||
)
|
||||
else:
|
||||
workflow_run = self._workflow_run_success(
|
||||
workflow_run=self._task_state.workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
outputs=self._task_state.current_node_execution.outputs
|
||||
if self._task_state.current_node_execution else None
|
||||
)
|
||||
|
||||
self._task_state.workflow_run = workflow_run
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs_dict
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
|
||||
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Get workflow run.
|
||||
@ -298,11 +408,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
:return:
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
|
||||
if workflow_run:
|
||||
# Because the workflow_run will be modified in the sub-thread,
|
||||
# and the first query in the main thread will cache the entity,
|
||||
# you need to expire the entity after the query
|
||||
db.session.expire(workflow_run)
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
|
||||
@ -313,11 +418,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
workflow_node_execution = (db.session.query(WorkflowNodeExecution)
|
||||
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
|
||||
if workflow_node_execution:
|
||||
# Because the workflow_node_execution will be modified in the sub-thread,
|
||||
# and the first query in the main thread will cache the entity,
|
||||
# you need to expire the entity after the query
|
||||
db.session.expire(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _save_workflow_app_log(self) -> None:
|
||||
@ -335,7 +435,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
response = {
|
||||
'event': 'text_chunk',
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'workflow_run_id': self._task_state.workflow_run.id,
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'data': {
|
||||
'text': text
|
||||
@ -398,7 +498,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
return {
|
||||
'event': 'error',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
**data
|
||||
}
|
||||
|
||||
|
||||
@ -1,14 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
@ -17,39 +22,91 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
self._queue_manager = queue_manager
|
||||
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
|
||||
|
||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run finished
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute finished
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
|
||||
202
api/core/app/apps/workflow_based_generate_task_pipeline.py
Normal file
202
api/core/app/apps/workflow_based_generate_task_pipeline.py
Normal file
@ -0,0 +1,202 @@
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
CreatedByRole,
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowBasedGenerateTaskPipeline:
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None) -> WorkflowRun:
|
||||
"""
|
||||
Init workflow run
|
||||
:param workflow: Workflow instance
|
||||
:param triggered_from: triggered from
|
||||
:param user: account or end user
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:return:
|
||||
"""
|
||||
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
|
||||
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
|
||||
.filter(WorkflowRun.app_id == workflow.app_id) \
|
||||
.scalar() or 0
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
# init workflow run
|
||||
workflow_run = WorkflowRun(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
sequence_number=new_sequence_number,
|
||||
workflow_id=workflow.id,
|
||||
type=workflow.type,
|
||||
triggered_from=triggered_from.value,
|
||||
version=workflow.version,
|
||||
graph=workflow.graph,
|
||||
inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}),
|
||||
status=WorkflowRunStatus.RUNNING.value,
|
||||
created_by_role=(CreatedByRole.ACCOUNT.value
|
||||
if isinstance(user, Account) else CreatedByRole.END_USER.value),
|
||||
created_by=user.id
|
||||
)
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_success(self, workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[dict] = None) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:return:
|
||||
"""
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = outputs
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_failed(self, workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowRunStatus,
|
||||
error: str) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run failed
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param status: status
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_run.status = status.value
|
||||
workflow_run.error = error
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Init workflow node execution from workflow run
|
||||
:param workflow_run: workflow run
|
||||
:param node_id: node id
|
||||
:param node_type: node type
|
||||
:param node_title: node title
|
||||
:param node_run_index: run index
|
||||
:param predecessor_node_id: predecessor node id if exists
|
||||
:return:
|
||||
"""
|
||||
# init workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution success
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param inputs: inputs
|
||||
:param process_data: process data
|
||||
:param outputs: outputs
|
||||
:param execution_metadata: execution metadata
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
|
||||
if execution_metadata else None
|
||||
workflow_node_execution.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
error: str) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
Reference in New Issue
Block a user