mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 06:58:05 +08:00
Merge branch 'main' into feat/mcp-06-18
This commit is contained in:
@ -447,6 +447,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"message_id": message.id,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
},
|
||||
)
|
||||
|
||||
@ -466,8 +468,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
|
||||
)
|
||||
@ -483,6 +483,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id: str,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
):
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -538,6 +540,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
app=app,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -570,8 +574,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
@ -584,7 +586,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param message: message
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
@ -596,8 +597,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message=message,
|
||||
user=user,
|
||||
dialogue_count=self._dialogue_count,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
@ -23,8 +23,12 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -55,6 +59,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
app: App,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -68,11 +74,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._workflow = workflow
|
||||
self.system_user_id = system_user_id
|
||||
self._app = app
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
def run(self):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
query=self.application_generate_entity.query,
|
||||
files=self.application_generate_entity.files,
|
||||
conversation_id=self.conversation.id,
|
||||
user_id=self.system_user_id,
|
||||
dialogue_count=self._dialogue_count,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_run_id,
|
||||
)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
|
||||
|
||||
@ -89,7 +108,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
@ -114,17 +132,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
conversation_variables = self._initialize_conversation_variables()
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = SystemVariable(
|
||||
query=query,
|
||||
files=files,
|
||||
conversation_id=self.conversation.id,
|
||||
user_id=self.system_user_id,
|
||||
dialogue_count=self._dialogue_count,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_run_id,
|
||||
)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
@ -172,6 +179,23 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
self._queue_manager.graph_runtime_state = graph_runtime_state
|
||||
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType(self._workflow.type),
|
||||
version=self._workflow.version,
|
||||
graph_data=self._workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
|
||||
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
@ -60,25 +61,21 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
from models.account import Account
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline:
|
||||
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
@ -93,8 +90,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
dialogue_count: int,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
):
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
@ -114,31 +109,20 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
else:
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables=SystemVariable(
|
||||
query=message.query,
|
||||
files=application_generate_entity.files,
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_session_id,
|
||||
dialogue_count=dialogue_count,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_execution_id=application_generate_entity.workflow_run_id,
|
||||
),
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
self._workflow_system_variables = SystemVariable(
|
||||
query=message.query,
|
||||
files=application_generate_entity.files,
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_session_id,
|
||||
dialogue_count=dialogue_count,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_execution_id=application_generate_entity.workflow_run_id,
|
||||
)
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=self._workflow_system_variables,
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
@ -157,6 +141,8 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@ -289,12 +275,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
|
||||
"""Fluent validation for graph runtime state."""
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
return graph_runtime_state
|
||||
|
||||
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
||||
"""Handle ping events."""
|
||||
yield self._base_task_pipeline.ping_stream_response()
|
||||
@ -305,21 +285,28 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||
|
||||
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
def _handle_workflow_started_event(
|
||||
self,
|
||||
event: QueueWorkflowStartedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow started events."""
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
run_id = self._extract_workflow_run_id(runtime_state)
|
||||
self._workflow_run_id = run_id
|
||||
|
||||
with self._database_session() as session:
|
||||
message = self._get_message(session=session)
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
|
||||
message.workflow_run_id = workflow_execution.id_
|
||||
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
message.workflow_run_id = run_id
|
||||
|
||||
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow_id,
|
||||
)
|
||||
|
||||
yield workflow_start_resp
|
||||
|
||||
@ -327,13 +314,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_retry_resp:
|
||||
@ -345,14 +328,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
"""Handle node started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
|
||||
node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_start_resp:
|
||||
@ -368,14 +346,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
self._save_output_for_event(event, event.node_execution_id)
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
@ -386,16 +362,13 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event)
|
||||
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if isinstance(event, QueueNodeExceptionEvent):
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
self._save_output_for_event(event, event.node_execution_id)
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
@ -505,29 +478,19 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: QueueWorkflowSucceededEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow succeeded events."""
|
||||
_ = trace_manager
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=validated_state.total_tokens,
|
||||
total_steps=validated_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_id=self._workflow_id,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
|
||||
yield workflow_finish_resp
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
@ -536,30 +499,20 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: QueueWorkflowPartialSuccessEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow partial success events."""
|
||||
_ = trace_manager
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=validated_state.total_tokens,
|
||||
total_steps=validated_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_id=self._workflow_id,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
graph_runtime_state=validated_state,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
yield workflow_finish_resp
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
@ -568,32 +521,25 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: QueueWorkflowFailedEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow failed events."""
|
||||
_ = trace_manager
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_id=self._workflow_id,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
graph_runtime_state=validated_state,
|
||||
error=event.error,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=validated_state.total_tokens,
|
||||
total_steps=validated_state.node_run_steps,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
error_message=event.error,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
|
||||
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
@ -608,25 +554,23 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle stop events."""
|
||||
if self._workflow_run_id and graph_runtime_state:
|
||||
_ = trace_manager
|
||||
resolved_state = None
|
||||
if self._workflow_run_id:
|
||||
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
|
||||
|
||||
if self._workflow_run_id and resolved_state:
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_id=self._workflow_id,
|
||||
status=WorkflowExecutionStatus.STOPPED,
|
||||
graph_runtime_state=resolved_state,
|
||||
error=event.get_stop_reason(),
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.get_stop_reason(),
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
# Save message
|
||||
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif event.stopped_by in (
|
||||
@ -648,7 +592,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle advanced chat message end events."""
|
||||
self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
resolved_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
|
||||
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
|
||||
self._task_state.answer
|
||||
@ -662,7 +606,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@ -671,10 +615,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle retriever resources events."""
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
|
||||
with self._database_session() as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
return
|
||||
yield # Make this a generator
|
||||
|
||||
@ -683,10 +623,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle annotation reply events."""
|
||||
self._message_cycle_manager.handle_annotation_reply(event)
|
||||
|
||||
with self._database_session() as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
return
|
||||
yield # Make this a generator
|
||||
|
||||
@ -740,7 +676,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: Any,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
tts_publisher: AppGeneratorTTSPublisher | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
|
||||
@ -753,7 +688,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
if handler := handlers.get(event_type):
|
||||
yield from handler(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
@ -770,7 +704,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
):
|
||||
yield from self._handle_node_failed_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
@ -789,15 +722,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
Process stream response using elegant Fluent Python patterns.
|
||||
Maintains exact same functionality as original 57-if-statement version.
|
||||
"""
|
||||
# Initialize graph runtime state
|
||||
graph_runtime_state: GraphRuntimeState | None = None
|
||||
|
||||
for queue_message in self._base_task_pipeline.queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
match event:
|
||||
case QueueWorkflowStartedEvent():
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
self._resolve_graph_runtime_state()
|
||||
yield from self._handle_workflow_started_event(event)
|
||||
|
||||
case QueueErrorEvent():
|
||||
@ -805,15 +735,11 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
break
|
||||
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_event(
|
||||
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
|
||||
)
|
||||
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(
|
||||
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
|
||||
)
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
# Handle all other events through elegant dispatch
|
||||
@ -821,7 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
if responses := list(
|
||||
self._dispatch_event(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
@ -879,6 +804,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
else:
|
||||
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
if candidate is not None:
|
||||
self._graph_runtime_state = candidate
|
||||
|
||||
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||
"""
|
||||
Message end to stream response.
|
||||
|
||||
@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -47,6 +48,7 @@ class AppQueueManager:
|
||||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
@ -109,6 +111,16 @@ class AppQueueManager:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(error=e), pub_from)
|
||||
|
||||
@property
|
||||
def graph_runtime_state(self) -> GraphRuntimeState | None:
|
||||
"""Retrieve the attached graph runtime state, if available."""
|
||||
return self._graph_runtime_state
|
||||
|
||||
@graph_runtime_state.setter
|
||||
def graph_runtime_state(self, graph_runtime_state: GraphRuntimeState | None) -> None:
|
||||
"""Attach the live graph runtime state reference for downstream consumers."""
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from models.model import App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
55
api/core/app/apps/common/graph_runtime_state_support.py
Normal file
55
api/core/app/apps/common/graph_runtime_state_support.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Shared helpers for managing GraphRuntimeState across task pipelines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
|
||||
|
||||
class GraphRuntimeStateSupport:
|
||||
"""
|
||||
Mixin that centralises common GraphRuntimeState access patterns used by task pipelines.
|
||||
|
||||
Subclasses are expected to provide:
|
||||
* `_base_task_pipeline` – exposing the queue manager with an optional cached runtime state.
|
||||
* `_graph_runtime_state` attribute used as the local cache for the runtime state.
|
||||
"""
|
||||
|
||||
_base_task_pipeline: BasedGenerateTaskPipeline
|
||||
_graph_runtime_state: GraphRuntimeState | None = None
|
||||
|
||||
def _ensure_graph_runtime_initialized(
|
||||
self,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> GraphRuntimeState:
|
||||
"""Validate and return the active graph runtime state."""
|
||||
return self._resolve_graph_runtime_state(graph_runtime_state)
|
||||
|
||||
def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
|
||||
system_variables = graph_runtime_state.variable_pool.system_variables
|
||||
if not system_variables or not system_variables.workflow_execution_id:
|
||||
raise ValueError("workflow_execution_id missing from runtime state")
|
||||
return str(system_variables.workflow_execution_id)
|
||||
|
||||
def _resolve_graph_runtime_state(
|
||||
self,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> GraphRuntimeState:
|
||||
"""Return the cached runtime state or bootstrap it from the queue manager."""
|
||||
if graph_runtime_state is not None:
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
return graph_runtime_state
|
||||
|
||||
if self._graph_runtime_state is None:
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
if candidate is not None:
|
||||
self._graph_runtime_state = candidate
|
||||
|
||||
if self._graph_runtime_state is None:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
return self._graph_runtime_state
|
||||
@ -1,9 +1,8 @@
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
@ -39,16 +38,36 @@ from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
Account,
|
||||
EndUser,
|
||||
)
|
||||
from models import Account, EndUser
|
||||
from services.variable_truncator import VariableTruncator
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeSnapshot:
|
||||
"""In-memory cache for node metadata between start and completion events."""
|
||||
|
||||
title: str
|
||||
index: int
|
||||
start_at: datetime
|
||||
iteration_id: str = ""
|
||||
"""Empty string means the node is not executing inside an iteration."""
|
||||
loop_id: str = ""
|
||||
"""Empty string means the node is not executing inside a loop."""
|
||||
|
||||
|
||||
class WorkflowResponseConverter:
|
||||
def __init__(
|
||||
@ -56,37 +75,151 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
user: Union[Account, EndUser],
|
||||
system_variables: SystemVariable,
|
||||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._user = user
|
||||
self._system_variables = system_variables
|
||||
self._workflow_inputs = self._prepare_workflow_inputs()
|
||||
self._truncator = VariableTruncator.default()
|
||||
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
|
||||
self._workflow_execution_id: str | None = None
|
||||
self._workflow_started_at: datetime | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Workflow lifecycle helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
|
||||
inputs = dict(self._application_generate_entity.inputs)
|
||||
for field_name, value in self._system_variables.to_dict().items():
|
||||
# TODO(@future-refactor): store system variables separately from user inputs so we don't
|
||||
# need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
|
||||
if field_name == SystemVariableKey.CONVERSATION_ID:
|
||||
# Conversation IDs are session-scoped; omitting them keeps workflow inputs
|
||||
# reusable without pinning new runs to a prior conversation.
|
||||
continue
|
||||
inputs[f"sys.{field_name}"] = value
|
||||
handled = WorkflowEntry.handle_special_values(inputs)
|
||||
return dict(handled or {})
|
||||
|
||||
def _ensure_workflow_run_id(self, workflow_run_id: str | None = None) -> str:
|
||||
"""Return the memoized workflow run id, optionally seeding it during start events."""
|
||||
if workflow_run_id is not None:
|
||||
self._workflow_execution_id = workflow_run_id
|
||||
if not self._workflow_execution_id:
|
||||
raise ValueError("workflow_run_id missing before streaming workflow events")
|
||||
return self._workflow_execution_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Node snapshot helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _store_snapshot(self, event: QueueNodeStartedEvent) -> _NodeSnapshot:
|
||||
snapshot = _NodeSnapshot(
|
||||
title=event.node_title,
|
||||
index=event.node_run_index,
|
||||
start_at=event.start_at,
|
||||
iteration_id=event.in_iteration_id or "",
|
||||
loop_id=event.in_loop_id or "",
|
||||
)
|
||||
node_execution_id = NodeExecutionId(event.node_execution_id)
|
||||
self._node_snapshots[node_execution_id] = snapshot
|
||||
return snapshot
|
||||
|
||||
def _get_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
|
||||
return self._node_snapshots.get(NodeExecutionId(node_execution_id))
|
||||
|
||||
def _pop_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
|
||||
return self._node_snapshots.pop(NodeExecutionId(node_execution_id), None)
|
||||
|
||||
@staticmethod
|
||||
def _merge_metadata(
|
||||
base_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None,
|
||||
snapshot: _NodeSnapshot | None,
|
||||
) -> Mapping[WorkflowNodeExecutionMetadataKey, Any] | None:
|
||||
if not base_metadata and not snapshot:
|
||||
return base_metadata
|
||||
|
||||
merged: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
if base_metadata:
|
||||
merged.update(base_metadata)
|
||||
|
||||
if snapshot:
|
||||
if snapshot.iteration_id:
|
||||
merged[WorkflowNodeExecutionMetadataKey.ITERATION_ID] = snapshot.iteration_id
|
||||
if snapshot.loop_id:
|
||||
merged[WorkflowNodeExecutionMetadataKey.LOOP_ID] = snapshot.loop_id
|
||||
|
||||
return merged or None
|
||||
|
||||
def _truncate_mapping(
|
||||
self,
|
||||
mapping: Mapping[str, Any] | None,
|
||||
) -> tuple[Mapping[str, Any] | None, bool]:
|
||||
if mapping is None:
|
||||
return None, False
|
||||
if not mapping:
|
||||
return {}, False
|
||||
|
||||
normalized = WorkflowEntry.handle_special_values(dict(mapping))
|
||||
if normalized is None:
|
||||
return None, False
|
||||
|
||||
truncated, is_truncated = self._truncator.truncate_variable_mapping(dict(normalized))
|
||||
return truncated, is_truncated
|
||||
|
||||
@staticmethod
|
||||
def _encode_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
if outputs is None:
|
||||
return None
|
||||
converter = WorkflowRuntimeTypeConverter()
|
||||
return converter.to_json_encodable(outputs)
|
||||
|
||||
def workflow_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution: WorkflowExecution,
|
||||
workflow_run_id: str,
|
||||
workflow_id: str,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
run_id = self._ensure_workflow_run_id(workflow_run_id)
|
||||
started_at = naive_utc_now()
|
||||
self._workflow_started_at = started_at
|
||||
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
inputs=workflow_execution.inputs,
|
||||
created_at=int(workflow_execution.started_at.timestamp()),
|
||||
id=run_id,
|
||||
workflow_id=workflow_id,
|
||||
inputs=self._workflow_inputs,
|
||||
created_at=int(started_at.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_execution: WorkflowExecution,
|
||||
workflow_id: str,
|
||||
status: WorkflowExecutionStatus,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
error: str | None = None,
|
||||
exceptions_count: int = 0,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
started_at = self._workflow_started_at
|
||||
if started_at is None:
|
||||
raise ValueError(
|
||||
"workflow_finish_to_stream_response called before workflow_start_to_stream_response",
|
||||
)
|
||||
|
||||
finished_at = naive_utc_now()
|
||||
elapsed_time = (finished_at - started_at).total_seconds()
|
||||
|
||||
outputs_mapping = graph_runtime_state.outputs or {}
|
||||
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
|
||||
|
||||
created_by: Mapping[str, object] | None
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
@ -94,38 +227,29 @@ class WorkflowResponseConverter:
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
elif isinstance(user, EndUser):
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
# Handle the case where finished_at is None by using current time as default
|
||||
finished_at_timestamp = (
|
||||
int(workflow_execution.finished_at.timestamp())
|
||||
if workflow_execution.finished_at
|
||||
else int(datetime.now(UTC).timestamp())
|
||||
)
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
status=workflow_execution.status,
|
||||
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
|
||||
error=workflow_execution.error_message,
|
||||
elapsed_time=workflow_execution.elapsed_time,
|
||||
total_tokens=workflow_execution.total_tokens,
|
||||
total_steps=workflow_execution.total_steps,
|
||||
id=run_id,
|
||||
workflow_id=workflow_id,
|
||||
status=status.value,
|
||||
outputs=encoded_outputs,
|
||||
error=error,
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_execution.started_at.timestamp()),
|
||||
finished_at=finished_at_timestamp,
|
||||
files=self.fetch_files_from_node_outputs(workflow_execution.outputs),
|
||||
exceptions_count=workflow_execution.exceptions_count,
|
||||
created_at=int(started_at.timestamp()),
|
||||
finished_at=int(finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(outputs_mapping),
|
||||
exceptions_count=exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
@ -134,38 +258,28 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> NodeStartStreamResponse | None:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
snapshot = self._store_snapshot(event)
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
workflow_run_id=run_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
title=workflow_node_execution.title,
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.get_response_inputs(),
|
||||
inputs_truncated=workflow_node_execution.inputs_truncated,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=snapshot.title,
|
||||
index=snapshot.index,
|
||||
created_at=int(snapshot.start_at.timestamp()),
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
),
|
||||
)
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
@ -189,41 +303,54 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> NodeFinishStreamResponse | None:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
snapshot = self._pop_snapshot(event.node_execution_id)
|
||||
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
start_at = snapshot.start_at if snapshot else event.start_at
|
||||
finished_at = naive_utc_now()
|
||||
elapsed_time = (finished_at - start_at).total_seconds()
|
||||
|
||||
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
|
||||
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
|
||||
encoded_outputs = self._encode_outputs(event.outputs)
|
||||
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
|
||||
metadata = self._merge_metadata(event.execution_metadata, snapshot)
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
error_message = event.error
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
error_message = event.error
|
||||
else:
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
error_message = event.error
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
workflow_run_id=run_id,
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.get_response_inputs(),
|
||||
inputs_truncated=workflow_node_execution.inputs_truncated,
|
||||
process_data=workflow_node_execution.get_response_process_data(),
|
||||
process_data_truncated=workflow_node_execution.process_data_truncated,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
|
||||
outputs_truncated=workflow_node_execution.outputs_truncated,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.metadata,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
index=snapshot.index if snapshot else 0,
|
||||
title=snapshot.title if snapshot else "",
|
||||
inputs=inputs,
|
||||
inputs_truncated=inputs_truncated,
|
||||
process_data=process_data,
|
||||
process_data_truncated=process_data_truncated,
|
||||
outputs=outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
status=status,
|
||||
error=error_message,
|
||||
elapsed_time=elapsed_time,
|
||||
execution_metadata=metadata,
|
||||
created_at=int(start_at.timestamp()),
|
||||
finished_at=int(finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(event.outputs or {}),
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
@ -234,44 +361,45 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
) -> NodeRetryStreamResponse | None:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
snapshot = self._get_snapshot(event.node_execution_id)
|
||||
if snapshot is None:
|
||||
raise AssertionError("node retry event arrived without a stored snapshot")
|
||||
finished_at = naive_utc_now()
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
|
||||
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
|
||||
encoded_outputs = self._encode_outputs(event.outputs)
|
||||
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
|
||||
metadata = self._merge_metadata(event.execution_metadata, snapshot)
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
workflow_run_id=run_id,
|
||||
data=NodeRetryStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.get_response_inputs(),
|
||||
inputs_truncated=workflow_node_execution.inputs_truncated,
|
||||
process_data=workflow_node_execution.get_response_process_data(),
|
||||
process_data_truncated=workflow_node_execution.process_data_truncated,
|
||||
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
|
||||
outputs_truncated=workflow_node_execution.outputs_truncated,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.metadata,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
index=snapshot.index,
|
||||
title=snapshot.title,
|
||||
inputs=inputs,
|
||||
inputs_truncated=inputs_truncated,
|
||||
process_data=process_data,
|
||||
process_data_truncated=process_data_truncated,
|
||||
outputs=outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
status=WorkflowNodeExecutionStatus.RETRY.value,
|
||||
error=event.error,
|
||||
elapsed_time=elapsed_time,
|
||||
execution_metadata=metadata,
|
||||
created_at=int(snapshot.start_at.timestamp()),
|
||||
finished_at=int(finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(event.outputs or {}),
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
retry_index=event.retry_index,
|
||||
@ -379,8 +507,6 @@ class WorkflowResponseConverter:
|
||||
inputs=new_inputs,
|
||||
inputs_truncated=truncated,
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -405,9 +531,6 @@ class WorkflowResponseConverter:
|
||||
pre_loop_output={},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -446,8 +569,6 @@ class WorkflowResponseConverter:
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = {}
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@ -207,6 +207,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
app_mode=app_config.app_mode,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
|
||||
@ -352,6 +352,8 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
},
|
||||
)
|
||||
|
||||
@ -367,8 +369,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=streaming,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
@ -573,6 +573,8 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -620,6 +622,8 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
@ -648,8 +652,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
@ -660,7 +662,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
:param queue_manager: queue manager
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
@ -670,8 +671,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
|
||||
@ -11,11 +11,14 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -40,6 +43,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -56,6 +61,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
def _get_app_id(self) -> str:
|
||||
return self.application_generate_entity.app_config.app_id
|
||||
@ -163,6 +170,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
self._queue_manager.graph_runtime_state = graph_runtime_state
|
||||
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
|
||||
@ -231,6 +231,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
},
|
||||
)
|
||||
|
||||
@ -244,8 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=streaming,
|
||||
)
|
||||
@ -424,6 +424,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -465,6 +467,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -493,8 +497,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
@ -514,8 +516,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@ -5,12 +5,13 @@ from typing import cast
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -34,6 +35,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -43,6 +46,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
@ -51,6 +56,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
@ -60,18 +73,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
@ -96,6 +100,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
self._queue_manager.graph_runtime_state = graph_runtime_state
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
@ -115,6 +121,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType(self._workflow.type),
|
||||
version=self._workflow.version,
|
||||
graph_data=self._workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
|
||||
@ -8,11 +8,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
MessageQueueMessage,
|
||||
@ -53,27 +51,20 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
)
|
||||
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
@ -85,8 +76,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
):
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
@ -99,42 +88,30 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
self._user_id = user.id
|
||||
user_session_id = user.session_id
|
||||
self._created_by_role = CreatorUserRole.END_USER
|
||||
elif isinstance(user, Account):
|
||||
else:
|
||||
self._user_id = user.id
|
||||
user_session_id = user.id
|
||||
self._created_by_role = CreatorUserRole.ACCOUNT
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables=SystemVariable(
|
||||
files=application_generate_entity.files,
|
||||
user_id=user_session_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_execution_id=application_generate_entity.workflow_execution_id,
|
||||
),
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_run_id = ""
|
||||
self._workflow_execution_id = ""
|
||||
self._invoke_from = queue_manager.invoke_from
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._workflow = workflow
|
||||
self._workflow_system_variables = SystemVariable(
|
||||
files=application_generate_entity.files,
|
||||
user_id=user_session_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_execution_id=application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=self._workflow_system_variables,
|
||||
)
|
||||
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@ -261,15 +238,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
"""Fluent validation for workflow state."""
|
||||
if not self._workflow_run_id:
|
||||
if not self._workflow_execution_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
|
||||
"""Fluent validation for graph runtime state."""
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
return graph_runtime_state
|
||||
|
||||
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
|
||||
"""Handle ping events."""
|
||||
yield self._base_task_pipeline.ping_stream_response()
|
||||
@ -283,12 +254,14 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
self, event: QueueWorkflowStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow started events."""
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
|
||||
run_id = self._extract_workflow_run_id(runtime_state)
|
||||
self._workflow_execution_id = run_id
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
)
|
||||
yield start_resp
|
||||
|
||||
@ -296,14 +269,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
@ -315,13 +283,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
"""Handle node started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_start_response:
|
||||
@ -331,14 +295,12 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
self, event: QueueNodeSucceededEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle node succeeded events."""
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
self._save_output_for_event(event, event.node_execution_id)
|
||||
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
@ -349,17 +311,13 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
|
||||
event=event,
|
||||
)
|
||||
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if isinstance(event, QueueNodeExceptionEvent):
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
self._save_output_for_event(event, event.node_execution_id)
|
||||
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
@ -372,7 +330,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
event=event,
|
||||
)
|
||||
yield iter_start_resp
|
||||
@ -385,7 +343,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
event=event,
|
||||
)
|
||||
yield iter_next_resp
|
||||
@ -398,7 +356,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
event=event,
|
||||
)
|
||||
yield iter_finish_resp
|
||||
@ -409,7 +367,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
event=event,
|
||||
)
|
||||
yield loop_start_resp
|
||||
@ -420,7 +378,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
event=event,
|
||||
)
|
||||
yield loop_next_resp
|
||||
@ -433,7 +391,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
event=event,
|
||||
)
|
||||
yield loop_finish_resp
|
||||
@ -442,33 +400,22 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: QueueWorkflowSucceededEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow succeeded events."""
|
||||
_ = trace_manager
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=validated_state.total_tokens,
|
||||
total_steps=validated_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
|
||||
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
@ -476,34 +423,23 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: QueueWorkflowPartialSuccessEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow partial success events."""
|
||||
_ = trace_manager
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
graph_runtime_state=validated_state,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=validated_state.total_tokens,
|
||||
total_steps=validated_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
|
||||
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
@ -511,37 +447,33 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow failed and stop events."""
|
||||
_ = trace_manager
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
|
||||
if isinstance(event, QueueWorkflowFailedEvent):
|
||||
status = WorkflowExecutionStatus.FAILED
|
||||
error = event.error
|
||||
exceptions_count = event.exceptions_count
|
||||
else:
|
||||
status = WorkflowExecutionStatus.STOPPED
|
||||
error = event.get_stop_reason()
|
||||
exceptions_count = 0
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=status,
|
||||
graph_runtime_state=validated_state,
|
||||
error=error,
|
||||
exceptions_count=exceptions_count,
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=validated_state.total_tokens,
|
||||
total_steps=validated_state.node_run_steps,
|
||||
status=WorkflowExecutionStatus.FAILED
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
|
||||
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
@ -601,7 +533,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
self,
|
||||
event: AppQueueEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
tts_publisher: AppGeneratorTTSPublisher | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
|
||||
@ -614,7 +545,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if handler := handlers.get(event_type):
|
||||
yield from handler(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
@ -631,7 +561,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
):
|
||||
yield from self._handle_node_failed_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
@ -642,7 +571,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
|
||||
yield from self._handle_workflow_failed_and_stop_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
@ -661,15 +589,12 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
Process stream response using elegant Fluent Python patterns.
|
||||
Maintains exact same functionality as original 44-if-statement version.
|
||||
"""
|
||||
# Initialize graph runtime state
|
||||
graph_runtime_state = None
|
||||
|
||||
for queue_message in self._base_task_pipeline.queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
match event:
|
||||
case QueueWorkflowStartedEvent():
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
self._resolve_graph_runtime_state()
|
||||
yield from self._handle_workflow_started_event(event)
|
||||
|
||||
case QueueTextChunkEvent():
|
||||
@ -681,12 +606,19 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
yield from self._handle_error_event(event)
|
||||
break
|
||||
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
break
|
||||
|
||||
# Handle all other events through elegant dispatch
|
||||
case _:
|
||||
if responses := list(
|
||||
self._dispatch_event(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
@ -697,7 +629,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
@ -709,11 +641,14 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
# not save log for debugging
|
||||
return
|
||||
|
||||
if not workflow_run_id:
|
||||
return
|
||||
|
||||
workflow_app_log = WorkflowAppLog()
|
||||
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
|
||||
workflow_app_log.workflow_id = workflow_execution.workflow_id
|
||||
workflow_app_log.workflow_run_id = workflow_execution.id_
|
||||
workflow_app_log.workflow_id = self._workflow.id
|
||||
workflow_app_log.workflow_run_id = workflow_run_id
|
||||
workflow_app_log.created_from = created_from.value
|
||||
workflow_app_log.created_by_role = self._created_by_role
|
||||
workflow_app_log.created_by = self._user_id
|
||||
|
||||
@ -25,7 +25,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
@ -54,6 +54,7 @@ from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -346,9 +347,7 @@ class WorkflowBasedAppRunner:
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
|
||||
)
|
||||
self._publish_event(QueueWorkflowStartedEvent())
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
@ -372,7 +371,6 @@ class WorkflowBasedAppRunner:
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
inputs=inputs,
|
||||
@ -393,7 +391,6 @@ class WorkflowBasedAppRunner:
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
@ -494,7 +491,6 @@ class WorkflowBasedAppRunner:
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
@ -536,7 +532,6 @@ class WorkflowBasedAppRunner:
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@ -3,11 +3,11 @@ from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
@ -54,6 +54,7 @@ class AppQueueEvent(BaseModel):
|
||||
"""
|
||||
|
||||
event: QueueEvent
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class QueueLLMChunkEvent(AppQueueEvent):
|
||||
@ -80,7 +81,6 @@ class QueueIterationStartEvent(AppQueueEvent):
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@ -132,19 +132,10 @@ class QueueLoopStartEvent(AppQueueEvent):
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@ -160,16 +151,6 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Any = None # output for the current loop
|
||||
|
||||
@ -185,14 +166,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
@ -285,12 +258,9 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
|
||||
|
||||
|
||||
class QueueWorkflowStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowStartedEvent entity
|
||||
"""
|
||||
"""QueueWorkflowStartedEvent entity."""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
graph_runtime_state: GraphRuntimeState
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
@ -334,15 +304,9 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
node_title: str
|
||||
node_type: NodeType
|
||||
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
||||
predecessor_node_id: str | None = None
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
in_iteration_id: str | None = None
|
||||
in_loop_id: str | None = None
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
@ -360,14 +324,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
@ -423,14 +379,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
@ -455,7 +403,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
parallel_id: str | None = None
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
|
||||
@ -257,13 +257,8 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
inputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
parallel_run_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
@ -285,10 +280,6 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
"inputs": None,
|
||||
"created_at": self.data.created_at,
|
||||
"extras": {},
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
},
|
||||
@ -324,10 +315,6 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
|
||||
@ -357,10 +344,6 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
"created_at": self.data.created_at,
|
||||
"finished_at": self.data.finished_at,
|
||||
"files": [],
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
},
|
||||
@ -396,10 +379,6 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
retry_index: int = 0
|
||||
@ -430,10 +409,6 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
"created_at": self.data.created_at,
|
||||
"finished_at": self.data.finished_at,
|
||||
"files": [],
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
"retry_index": self.data.retry_index,
|
||||
@ -541,8 +516,6 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_STARTED
|
||||
workflow_run_id: str
|
||||
@ -567,9 +540,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
|
||||
created_at: int
|
||||
pre_loop_output: Any = None
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parallel_mode_run_id: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||
workflow_run_id: str
|
||||
@ -603,8 +573,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_COMPLETED
|
||||
workflow_run_id: str
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.common_entities import I18nObject
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
@ -18,11 +16,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_datasource_plugin_provider(
|
||||
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||
|
||||
@ -1148,6 +1148,15 @@ class ProviderConfiguration(BaseModel):
|
||||
raise ValueError("Can't add same credential")
|
||||
provider_model_record.credential_id = credential_record.id
|
||||
provider_model_record.updated_at = naive_utc_now()
|
||||
|
||||
# clear cache
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
session.add(provider_model_record)
|
||||
session.commit()
|
||||
|
||||
@ -1181,6 +1190,14 @@ class ProviderConfiguration(BaseModel):
|
||||
session.add(provider_model_record)
|
||||
session.commit()
|
||||
|
||||
# clear cache
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
def delete_custom_model(self, model_type: ModelType, model: str):
|
||||
"""
|
||||
Delete custom model.
|
||||
|
||||
@ -913,4 +913,4 @@ class TraceQueueManager:
|
||||
"file_id": file_id,
|
||||
"app_id": task.app_id,
|
||||
}
|
||||
process_trace_tasks.delay(file_info)
|
||||
process_trace_tasks.delay(file_info) # type: ignore
|
||||
|
||||
@ -14,7 +14,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
@ -250,7 +250,6 @@ class WeaviateVector(BaseVector):
|
||||
)
|
||||
)
|
||||
|
||||
batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100))
|
||||
with col.batch.dynamic() as batch:
|
||||
for obj in objs:
|
||||
batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
|
||||
@ -348,7 +347,10 @@ class WeaviateVector(BaseVector):
|
||||
for obj in res.objects:
|
||||
properties = dict(obj.properties or {})
|
||||
text = properties.pop(Field.TEXT_KEY.value, "")
|
||||
distance = (obj.metadata.distance if obj.metadata else None) or 1.0
|
||||
if obj.metadata and obj.metadata.distance is not None:
|
||||
distance = obj.metadata.distance
|
||||
else:
|
||||
distance = 1.0
|
||||
score = 1.0 - distance
|
||||
|
||||
if score > score_threshold:
|
||||
|
||||
@ -25,7 +25,7 @@ class FirecrawlApp:
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
data = response_data["data"]
|
||||
@ -42,7 +42,7 @@ class FirecrawlApp:
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
# There's also another two fields in the response: "success" (bool) and "url" (str)
|
||||
job_id = response.json().get("id")
|
||||
@ -51,9 +51,25 @@ class FirecrawlApp:
|
||||
self._handle_error(response, "start crawl job")
|
||||
return "" # unreachable
|
||||
|
||||
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
|
||||
headers = self._prepare_headers()
|
||||
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
|
||||
if params:
|
||||
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
return cast(dict[str, Any], response.json())
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "start map job")
|
||||
return {}
|
||||
else:
|
||||
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers)
|
||||
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
@ -135,12 +151,16 @@ class FirecrawlApp:
|
||||
"lang": "en",
|
||||
"country": "us",
|
||||
"timeout": 60000,
|
||||
"ignoreInvalidURLs": False,
|
||||
"ignoreInvalidURLs": True,
|
||||
"scrapeOptions": {},
|
||||
"sources": [
|
||||
{"type": "web"},
|
||||
],
|
||||
"integration": "dify",
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if not response_data.get("success"):
|
||||
|
||||
@ -108,7 +108,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
execution_data = execution.model_dump()
|
||||
|
||||
# Queue the save operation as a Celery task (fire and forget)
|
||||
save_workflow_execution_task.delay(
|
||||
save_workflow_execution_task.delay( # type: ignore
|
||||
execution_data=execution_data,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
|
||||
@ -104,7 +104,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||
|
||||
# Initialize in-memory cache for node executions
|
||||
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
|
||||
|
||||
# Initialize FileService for handling offloaded data
|
||||
@ -332,17 +331,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
Args:
|
||||
execution: The NodeExecution domain entity to persist
|
||||
"""
|
||||
# NOTE: As per the implementation of `WorkflowCycleManager`,
|
||||
# the `save` method is invoked multiple times during the node's execution lifecycle, including:
|
||||
#
|
||||
# - When the node starts execution
|
||||
# - When the node retries execution
|
||||
# - When the node completes execution (either successfully or with failure)
|
||||
#
|
||||
# Only the final invocation will have `inputs` and `outputs` populated.
|
||||
#
|
||||
# This simplifies the logic for saving offloaded variables but introduces a tight coupling
|
||||
# between this module and `WorkflowCycleManager`.
|
||||
# NOTE: The workflow engine triggers `save` multiple times for a single node execution:
|
||||
# when the node starts, any time it retries, and once more when it reaches a terminal state.
|
||||
# Only the final call contains the complete inputs and outputs payloads, so earlier invocations
|
||||
# must tolerate missing data without attempting to offload variables.
|
||||
|
||||
# Convert domain model to database model using tenant context and other attributes
|
||||
db_model = self._to_db_model(execution)
|
||||
|
||||
@ -395,11 +395,13 @@ class ApiTool(Tool):
|
||||
parsed_response = self.validate_and_parse_response(response)
|
||||
|
||||
# assemble invoke message based on response type
|
||||
if parsed_response.is_json and isinstance(parsed_response.content, dict):
|
||||
yield self.create_json_message(parsed_response.content)
|
||||
if parsed_response.is_json:
|
||||
if isinstance(parsed_response.content, dict):
|
||||
yield self.create_json_message(parsed_response.content)
|
||||
|
||||
# FIXES: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
|
||||
# We need never break the original flows
|
||||
# The yield below must be preserved to keep backward compatibility.
|
||||
#
|
||||
# ref: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
|
||||
yield self.create_text_message(response.text)
|
||||
else:
|
||||
# Convert to string if needed and create text message
|
||||
|
||||
@ -189,6 +189,11 @@ class ToolInvokeMessage(BaseModel):
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
|
||||
|
||||
@field_validator("metadata", mode="before")
|
||||
@classmethod
|
||||
def _normalize_metadata(cls, value: Mapping[str, Any] | None) -> Mapping[str, Any]:
|
||||
return value or {}
|
||||
|
||||
class RetrieverResourceMessage(BaseModel):
|
||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
@ -376,6 +381,11 @@ class ToolEntity(BaseModel):
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
||||
@field_validator("output_schema", mode="before")
|
||||
@classmethod
|
||||
def _normalize_output_schema(cls, value: Mapping[str, object] | None) -> Mapping[str, object]:
|
||||
return value or {}
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
|
||||
@ -63,8 +63,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from flask import has_request_context
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs.login import current_user
|
||||
from models.model import App
|
||||
from models import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -79,11 +81,16 @@ class WorkflowTool(Tool):
|
||||
generator = WorkflowAppGenerator()
|
||||
assert self.runtime is not None
|
||||
assert self.runtime.invoke_from is not None
|
||||
assert current_user is not None
|
||||
|
||||
user = self._resolve_user(user_id=user_id)
|
||||
|
||||
if user is None:
|
||||
raise ToolInvokeError("User not found")
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=current_user,
|
||||
user=user,
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
@ -123,6 +130,51 @@ class WorkflowTool(Tool):
|
||||
label=self.label,
|
||||
)
|
||||
|
||||
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
|
||||
"""
|
||||
Resolve user object in both HTTP and worker contexts.
|
||||
|
||||
In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser).
|
||||
In worker context: load Account from database by user_id (only returns Account, never EndUser).
|
||||
|
||||
Returns:
|
||||
Account | EndUser | None: The resolved user object, or None if resolution fails.
|
||||
"""
|
||||
if has_request_context():
|
||||
return self._resolve_user_from_request()
|
||||
else:
|
||||
return self._resolve_user_from_database(user_id=user_id)
|
||||
|
||||
def _resolve_user_from_request(self) -> Account | EndUser | None:
|
||||
"""
|
||||
Resolve user from Flask request context.
|
||||
"""
|
||||
try:
|
||||
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
|
||||
return getattr(current_user, "_get_current_object", lambda: current_user)()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve user from request context: %s", e)
|
||||
return None
|
||||
|
||||
def _resolve_user_from_database(self, user_id: str) -> Account | None:
|
||||
"""
|
||||
Resolve user from database (worker/Celery context).
|
||||
"""
|
||||
|
||||
user_stmt = select(Account).where(Account.id == user_id)
|
||||
user = db.session.scalar(user_stmt)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
|
||||
tenant = db.session.scalar(tenant_stmt)
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
user.current_tenant = tenant
|
||||
|
||||
return user
|
||||
|
||||
def _get_workflow(self, app_id: str, version: str) -> Workflow:
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
|
||||
@ -1,18 +1,11 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .run_condition import RunCondition
|
||||
from .variable_pool import VariablePool, VariableValue
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"RunCondition",
|
||||
"VariablePool",
|
||||
"VariableValue",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
@ -1,160 +0,0 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
from .variable_pool import VariablePool
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
# Private attributes to prevent direct modification
|
||||
_variable_pool: VariablePool = PrivateAttr()
|
||||
_start_at: float = PrivateAttr()
|
||||
_total_tokens: int = PrivateAttr(default=0)
|
||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
|
||||
_node_run_steps: int = PrivateAttr(default=0)
|
||||
_ready_queue_json: str = PrivateAttr()
|
||||
_graph_execution_json: str = PrivateAttr()
|
||||
_response_coordinator_json: str = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
start_at: float,
|
||||
total_tokens: int = 0,
|
||||
llm_usage: LLMUsage | None = None,
|
||||
outputs: dict[str, object] | None = None,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue_json: str = "",
|
||||
graph_execution_json: str = "",
|
||||
response_coordinator_json: str = "",
|
||||
**kwargs: object,
|
||||
):
|
||||
"""Initialize the GraphRuntimeState with validation."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Initialize private attributes with validation
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
self._start_at = start_at
|
||||
|
||||
if total_tokens < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = total_tokens
|
||||
|
||||
if llm_usage is None:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
self._llm_usage = llm_usage
|
||||
|
||||
if outputs is None:
|
||||
outputs = {}
|
||||
self._outputs = deepcopy(outputs)
|
||||
|
||||
if node_run_steps < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = node_run_steps
|
||||
|
||||
self._ready_queue_json = ready_queue_json
|
||||
self._graph_execution_json = graph_execution_json
|
||||
self._response_coordinator_json = response_coordinator_json
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> VariablePool:
|
||||
"""Get the variable pool."""
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time."""
|
||||
return self._start_at
|
||||
|
||||
@start_at.setter
|
||||
def start_at(self, value: float) -> None:
|
||||
"""Set the start time."""
|
||||
self._start_at = value
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count."""
|
||||
return self._total_tokens
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, value: int):
|
||||
"""Set the total tokens count."""
|
||||
if value < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = value
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get the LLM usage info."""
|
||||
# Return a copy to prevent external modification
|
||||
return self._llm_usage.model_copy()
|
||||
|
||||
@llm_usage.setter
|
||||
def llm_usage(self, value: LLMUsage):
|
||||
"""Set the LLM usage info."""
|
||||
self._llm_usage = value.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, object]:
|
||||
"""Get a copy of the outputs dictionary."""
|
||||
return deepcopy(self._outputs)
|
||||
|
||||
@outputs.setter
|
||||
def outputs(self, value: dict[str, object]) -> None:
|
||||
"""Set the outputs dictionary."""
|
||||
self._outputs = deepcopy(value)
|
||||
|
||||
def set_output(self, key: str, value: object) -> None:
|
||||
"""Set a single output value."""
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
def get_output(self, key: str, default: object = None) -> object:
|
||||
"""Get a single output value."""
|
||||
return deepcopy(self._outputs.get(key, default))
|
||||
|
||||
def update_outputs(self, updates: dict[str, object]) -> None:
|
||||
"""Update multiple output values."""
|
||||
for key, value in updates.items():
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count."""
|
||||
return self._node_run_steps
|
||||
|
||||
@node_run_steps.setter
|
||||
def node_run_steps(self, value: int) -> None:
|
||||
"""Set the node run steps count."""
|
||||
if value < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = value
|
||||
|
||||
def increment_node_run_steps(self) -> None:
|
||||
"""Increment the node run steps by 1."""
|
||||
self._node_run_steps += 1
|
||||
|
||||
def add_tokens(self, tokens: int) -> None:
|
||||
"""Add tokens to the total count."""
|
||||
if tokens < 0:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
||||
@property
|
||||
def ready_queue_json(self) -> str:
|
||||
"""Get a copy of the ready queue state."""
|
||||
return self._ready_queue_json
|
||||
|
||||
@property
|
||||
def graph_execution_json(self) -> str:
|
||||
"""Get a copy of the serialized graph execution state."""
|
||||
return self._graph_execution_json
|
||||
|
||||
@property
|
||||
def response_coordinator_json(self) -> str:
|
||||
"""Get a copy of the serialized response coordinator state."""
|
||||
return self._response_coordinator_json
|
||||
@ -1,21 +0,0 @@
|
||||
import hashlib
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
|
||||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: str | None = None
|
||||
"""branch identify like: sourceHandle, required when type is branch_identify"""
|
||||
|
||||
conditions: list[Condition] | None = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
|
||||
@ -58,6 +58,7 @@ class NodeType(StrEnum):
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
HUMAN_INPUT = "human-input"
|
||||
|
||||
|
||||
class NodeExecutionType(StrEnum):
|
||||
@ -96,6 +97,7 @@ class WorkflowExecutionStatus(StrEnum):
|
||||
FAILED = "failed"
|
||||
STOPPED = "stopped"
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
|
||||
@ -1,16 +1,11 @@
|
||||
from .edge import Edge
|
||||
from .graph import Graph, NodeFactory
|
||||
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
|
||||
from .graph import Graph, GraphBuilder, NodeFactory
|
||||
from .graph_template import GraphTemplate
|
||||
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
|
||||
|
||||
__all__ = [
|
||||
"Edge",
|
||||
"Graph",
|
||||
"GraphBuilder",
|
||||
"GraphTemplate",
|
||||
"NodeFactory",
|
||||
"ReadOnlyGraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeStateWrapper",
|
||||
"ReadOnlyVariablePool",
|
||||
"ReadOnlyVariablePoolWrapper",
|
||||
]
|
||||
|
||||
@ -195,6 +195,12 @@ class Graph:
|
||||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def new(cls) -> "GraphBuilder":
|
||||
"""Create a fluent builder for assembling a graph programmatically."""
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
@ -344,3 +350,96 @@ class Graph:
|
||||
"""
|
||||
edge_ids = self.in_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
|
||||
@final
|
||||
class GraphBuilder:
|
||||
"""Fluent helper for constructing simple graphs, primarily for tests."""
|
||||
|
||||
def __init__(self, *, graph_cls: type[Graph]):
|
||||
self._graph_cls = graph_cls
|
||||
self._nodes: list[Node] = []
|
||||
self._nodes_by_id: dict[str, Node] = {}
|
||||
self._edges: list[Edge] = []
|
||||
self._edge_counter = 0
|
||||
|
||||
def add_root(self, node: Node) -> "GraphBuilder":
|
||||
"""Register the root node. Must be called exactly once."""
|
||||
|
||||
if self._nodes:
|
||||
raise ValueError("Root node has already been added")
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
return self
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
node: Node,
|
||||
*,
|
||||
from_node_id: str | None = None,
|
||||
source_handle: str = "source",
|
||||
) -> "GraphBuilder":
|
||||
"""Append a node and connect it from the specified predecessor."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Root node must be added before adding other nodes")
|
||||
|
||||
predecessor_id = from_node_id or self._nodes[-1].id
|
||||
if predecessor_id not in self._nodes_by_id:
|
||||
raise ValueError(f"Predecessor node '{predecessor_id}' not found")
|
||||
|
||||
predecessor = self._nodes_by_id[predecessor_id]
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
|
||||
"""Connect two existing nodes without adding a new node."""
|
||||
|
||||
if tail not in self._nodes_by_id:
|
||||
raise ValueError(f"Tail node '{tail}' not found")
|
||||
if head not in self._nodes_by_id:
|
||||
raise ValueError(f"Head node '{head}' not found")
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def build(self) -> Graph:
|
||||
"""Materialize the graph instance from the accumulated nodes and edges."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Cannot build an empty graph")
|
||||
|
||||
nodes = {node.id: node for node in self._nodes}
|
||||
edges = {edge.id: edge for edge in self._edges}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for edge in self._edges:
|
||||
out_edges[edge.tail].append(edge.id)
|
||||
in_edges[edge.head].append(edge.id)
|
||||
|
||||
return self._graph_cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=dict(in_edges),
|
||||
out_edges=dict(out_edges),
|
||||
root_node=self._nodes[0],
|
||||
)
|
||||
|
||||
def _register_node(self, node: Node) -> None:
|
||||
if not node.id:
|
||||
raise ValueError("Node must have a non-empty id")
|
||||
if node.id in self._nodes_by_id:
|
||||
raise ValueError(f"Duplicate node id detected: {node.id}")
|
||||
self._nodes_by_id[node.id] = node
|
||||
|
||||
@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
@ -41,6 +41,7 @@ class RedisChannel:
|
||||
self._redis = redis_client
|
||||
self._key = channel_key
|
||||
self._command_ttl = command_ttl
|
||||
self._pending_key = f"{channel_key}:pending"
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
@ -49,6 +50,9 @@ class RedisChannel:
|
||||
Returns:
|
||||
List of pending commands (drains the Redis list)
|
||||
"""
|
||||
if not self._has_pending_commands():
|
||||
return []
|
||||
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Use pipeline for atomic operations
|
||||
@ -85,6 +89,7 @@ class RedisChannel:
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.rpush(self._key, command_json)
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.set(self._pending_key, "1", ex=self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
@ -106,9 +111,25 @@ class RedisChannel:
|
||||
|
||||
if command_type == CommandType.ABORT:
|
||||
return AbortCommand.model_validate(data)
|
||||
else:
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _has_pending_commands(self) -> bool:
|
||||
"""
|
||||
Check and consume the pending marker to avoid unnecessary list reads.
|
||||
|
||||
Returns:
|
||||
True if commands should be fetched from Redis.
|
||||
"""
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.get(self._pending_key)
|
||||
pipe.delete(self._pending_key)
|
||||
pending_value, _ = pipe.execute()
|
||||
|
||||
return pending_value is not None
|
||||
|
||||
@ -5,10 +5,11 @@ This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
]
|
||||
|
||||
@ -1,14 +1,10 @@
|
||||
"""
|
||||
Command handler implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -16,17 +12,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
"""
|
||||
Handle an abort command.
|
||||
|
||||
Args:
|
||||
command: The abort command
|
||||
execution: Graph execution to abort
|
||||
"""
|
||||
assert isinstance(command, AbortCommand)
|
||||
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.abort(command.reason or "User requested abort")
|
||||
|
||||
|
||||
@final
|
||||
class PauseCommandHandler(CommandHandler):
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, PauseCommand)
|
||||
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.pause(command.reason)
|
||||
|
||||
@ -40,6 +40,8 @@ class GraphExecutionState(BaseModel):
|
||||
started: bool = Field(default=False)
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
@ -103,6 +105,8 @@ class GraphExecution:
|
||||
started: bool = False
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reason: str | None = None
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
@ -126,6 +130,17 @@ class GraphExecution:
|
||||
self.aborted = True
|
||||
self.error = RuntimeError(f"Aborted: {reason}")
|
||||
|
||||
def pause(self, reason: str | None = None) -> None:
|
||||
"""Pause the graph execution without marking it complete."""
|
||||
if self.completed:
|
||||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
if self.aborted:
|
||||
raise RuntimeError("Cannot pause execution that has been aborted")
|
||||
if self.paused:
|
||||
return
|
||||
self.paused = True
|
||||
self.pause_reason = reason
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
self.error = error
|
||||
@ -140,7 +155,12 @@ class GraphExecution:
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the execution is currently running."""
|
||||
return self.started and not self.completed and not self.aborted
|
||||
return self.started and not self.completed and not self.aborted and not self.paused
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Check if the execution is currently paused."""
|
||||
return self.paused
|
||||
|
||||
@property
|
||||
def has_error(self) -> bool:
|
||||
@ -173,6 +193,8 @@ class GraphExecution:
|
||||
started=self.started,
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
paused=self.paused,
|
||||
pause_reason=self.pause_reason,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
@ -197,6 +219,8 @@ class GraphExecution:
|
||||
self.started = state.started
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.paused = state.paused
|
||||
self.pause_reason = state.pause_reason
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
|
||||
@ -16,7 +16,6 @@ class CommandType(StrEnum):
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
RESUME = "resume"
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
@ -31,3 +30,10 @@ class AbortCommand(GraphEngineCommand):
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for abort")
|
||||
|
||||
|
||||
class PauseCommand(GraphEngineCommand):
|
||||
"""Command to pause a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for pause")
|
||||
|
||||
@ -7,8 +7,8 @@ from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
@ -23,11 +23,13 @@ from core.workflow.graph_events import (
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
@ -125,6 +127,7 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
self._graph_runtime_state.increment_node_run_steps()
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
@ -163,6 +166,8 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
@ -199,6 +204,18 @@ class EventHandler:
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
"""Handle pause requests emitted by nodes."""
|
||||
|
||||
pause_reason = event.reason or "Awaiting human input"
|
||||
self._graph_execution.pause(pause_reason)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
if event.node_id in self._graph.nodes:
|
||||
self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
|
||||
self._graph_runtime_state.register_paused_node(event.node_id)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
@ -212,6 +229,8 @@ class EventHandler:
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
@ -235,6 +254,8 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
@ -286,6 +307,19 @@ class EventHandler:
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
|
||||
"""Accumulate token usage into the shared runtime state."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
self._graph_runtime_state.add_tokens(usage.total_tokens)
|
||||
|
||||
current_usage = self._graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self._graph_runtime_state.llm_usage = usage
|
||||
else:
|
||||
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
@ -97,6 +97,10 @@ class EventManager:
|
||||
"""
|
||||
self._layers = layers
|
||||
|
||||
def notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""Notify registered layers about an event without buffering it."""
|
||||
self._notify_layers(event)
|
||||
|
||||
def collect(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Thread-safe method to collect an event.
|
||||
|
||||
@ -9,28 +9,29 @@ import contextvars
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from typing import final
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor
|
||||
from .domain import GraphExecution
|
||||
from .entities.commands import AbortCommand
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
|
||||
from .entities.commands import AbortCommand, PauseCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
@ -38,10 +39,13 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state
|
||||
from .response_coordinator import ResponseStreamCoordinator
|
||||
from .ready_queue import ReadyQueue
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -67,17 +71,16 @@ class GraphEngine:
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
if graph_runtime_state.graph_execution_json != "":
|
||||
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
|
||||
|
||||
# === Core Dependencies ===
|
||||
# Graph structure and configuration
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||
self._graph_execution.workflow_id = workflow_id
|
||||
|
||||
# === Worker Management Parameters ===
|
||||
# Parameters for dynamic worker pool scaling
|
||||
self._min_workers = min_workers
|
||||
@ -86,13 +89,7 @@ class GraphEngine:
|
||||
self._scale_down_idle_time = scale_down_idle_time
|
||||
|
||||
# === Execution Queues ===
|
||||
# Create ready queue from saved state or initialize new one
|
||||
self._ready_queue: ReadyQueue
|
||||
if self._graph_runtime_state.ready_queue_json == "":
|
||||
self._ready_queue = InMemoryReadyQueue()
|
||||
else:
|
||||
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
|
||||
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
|
||||
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
||||
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
@ -103,11 +100,7 @@ class GraphEngine:
|
||||
|
||||
# === Response Coordination ===
|
||||
# Coordinates response streaming from response nodes
|
||||
self._response_coordinator = ResponseStreamCoordinator(
|
||||
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
||||
)
|
||||
if graph_runtime_state.response_coordinator_json != "":
|
||||
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
|
||||
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
|
||||
|
||||
# === Event Management ===
|
||||
# Event manager handles both collection and emission of events
|
||||
@ -133,19 +126,6 @@ class GraphEngine:
|
||||
skip_propagator=self._skip_propagator,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# === Command Processing ===
|
||||
# Processes external commands (e.g., abort requests)
|
||||
self._command_processor = CommandProcessor(
|
||||
@ -153,12 +133,12 @@ class GraphEngine:
|
||||
graph_execution=self._graph_execution,
|
||||
)
|
||||
|
||||
# Register abort command handler
|
||||
# Register command handlers
|
||||
abort_handler = AbortCommandHandler()
|
||||
self._command_processor.register_handler(
|
||||
AbortCommand,
|
||||
abort_handler,
|
||||
)
|
||||
self._command_processor.register_handler(AbortCommand, abort_handler)
|
||||
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Capture Flask app context for worker threads
|
||||
@ -191,12 +171,23 @@ class GraphEngine:
|
||||
self._execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self._graph_execution,
|
||||
state_manager=self._state_manager,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_manager,
|
||||
command_processor=self._command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# Dispatches events and manages execution flow
|
||||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
@ -237,26 +228,41 @@ class GraphEngine:
|
||||
# Initialize layers
|
||||
self._initialize_layers()
|
||||
|
||||
# Start execution
|
||||
self._graph_execution.start()
|
||||
is_resume = self._graph_execution.started
|
||||
if not is_resume:
|
||||
self._graph_execution.start()
|
||||
else:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reason = None
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
self._event_manager.notify_layers(start_event)
|
||||
yield start_event
|
||||
|
||||
# Start subsystems
|
||||
self._start_execution()
|
||||
self._start_execution(resume=is_resume)
|
||||
|
||||
# Yield events as they occur
|
||||
yield from self._event_manager.emit_events()
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.aborted:
|
||||
if self._graph_execution.is_paused:
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reason=self._graph_execution.pause_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
yield paused_event
|
||||
elif self._graph_execution.aborted:
|
||||
abort_reason = "Workflow execution aborted by user command"
|
||||
if self._graph_execution.error:
|
||||
abort_reason = str(self._graph_execution.error)
|
||||
yield GraphRunAbortedEvent(
|
||||
aborted_event = GraphRunAbortedEvent(
|
||||
reason=abort_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(aborted_event)
|
||||
yield aborted_event
|
||||
elif self._graph_execution.has_error:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
@ -264,20 +270,26 @@ class GraphEngine:
|
||||
outputs = self._graph_runtime_state.outputs
|
||||
exceptions_count = self._graph_execution.exceptions_count
|
||||
if exceptions_count > 0:
|
||||
yield GraphRunPartialSucceededEvent(
|
||||
partial_event = GraphRunPartialSucceededEvent(
|
||||
exceptions_count=exceptions_count,
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(partial_event)
|
||||
yield partial_event
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
succeeded_event = GraphRunSucceededEvent(
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(succeeded_event)
|
||||
yield succeeded_event
|
||||
|
||||
except Exception as e:
|
||||
yield GraphRunFailedEvent(
|
||||
failed_event = GraphRunFailedEvent(
|
||||
error=str(e),
|
||||
exceptions_count=self._graph_execution.exceptions_count,
|
||||
)
|
||||
self._event_manager.notify_layers(failed_event)
|
||||
yield failed_event
|
||||
raise
|
||||
|
||||
finally:
|
||||
@ -299,8 +311,12 @@ class GraphEngine:
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||
|
||||
def _start_execution(self) -> None:
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
paused_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
|
||||
@ -309,10 +325,15 @@ class GraphEngine:
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._response_coordinator.register(node.id)
|
||||
|
||||
# Enqueue root node
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
if not resume:
|
||||
# Enqueue root node
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
else:
|
||||
for node_id in paused_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Start dispatcher
|
||||
self._dispatcher.start()
|
||||
|
||||
@ -7,9 +7,9 @@ intercept and respond to GraphEngine events.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
|
||||
410
api/core/workflow/graph_engine/layers/persistence.py
Normal file
410
api/core/workflow/graph_engine/layers/persistence.py
Normal file
@ -0,0 +1,410 @@
|
||||
"""Workflow persistence layer for GraphEngine.
|
||||
|
||||
This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
|
||||
listening to ``GraphEngineEvent`` instances directly and persisting workflow
|
||||
and node execution state via the injected repositories.
|
||||
|
||||
The design keeps domain persistence concerns inside the engine thread, while
|
||||
allowing presentation layers to remain read-only observers of repository
|
||||
state.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import (
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PersistenceWorkflowInfo:
|
||||
"""Static workflow metadata required for persistence."""
|
||||
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
version: str
|
||||
graph_data: Mapping[str, Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeRuntimeSnapshot:
|
||||
"""Lightweight cache to keep node metadata across event phases."""
|
||||
|
||||
node_id: str
|
||||
title: str
|
||||
predecessor_node_id: str | None
|
||||
iteration_id: str | None
|
||||
loop_id: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
"""GraphEngine layer that persists workflow and node execution state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_info: PersistenceWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._trace_manager = trace_manager
|
||||
|
||||
self._workflow_execution: WorkflowExecution | None = None
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
|
||||
self._node_sequence: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GraphEngineLayer lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
def on_graph_start(self) -> None:
|
||||
self._workflow_execution = None
|
||||
self._node_execution_cache.clear()
|
||||
self._node_snapshots.clear()
|
||||
self._node_sequence = 0
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._handle_graph_run_started()
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
self._handle_graph_run_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self._handle_graph_run_partial_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
self._handle_graph_run_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunAbortedEvent):
|
||||
self._handle_graph_run_aborted(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._handle_graph_run_paused(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._handle_node_started(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunRetryEvent):
|
||||
self._handle_node_retry(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self._handle_node_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
self._handle_node_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunExceptionEvent):
|
||||
self._handle_node_exception(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunPauseRequestedEvent):
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Graph-level handlers
|
||||
# ------------------------------------------------------------------
|
||||
def _handle_graph_run_started(self) -> None:
|
||||
execution_id = self._get_execution_id()
|
||||
workflow_execution = WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id=self._workflow_info.workflow_id,
|
||||
workflow_type=self._workflow_info.workflow_type,
|
||||
workflow_version=self._workflow_info.version,
|
||||
graph=self._workflow_info.graph_data,
|
||||
inputs=self._prepare_workflow_inputs(),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
self._workflow_execution_repository.save(workflow_execution)
|
||||
self._workflow_execution = workflow_execution
|
||||
|
||||
def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.outputs = event.outputs
|
||||
execution.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.outputs = event.outputs
|
||||
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
execution.exceptions_count = event.exceptions_count
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.FAILED
|
||||
execution.error_message = event.error
|
||||
execution.exceptions_count = event.exceptions_count
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._fail_running_node_executions(error_message=event.error)
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.STOPPED
|
||||
execution.error_message = event.reason or "Workflow execution aborted"
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._fail_running_node_executions(error_message=execution.error_message or "")
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.PAUSED
|
||||
execution.error_message = event.reason or "Workflow execution paused"
|
||||
execution.outputs = event.outputs
|
||||
self._populate_completion_statistics(execution, update_finished=False)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Node-level handlers
|
||||
# ------------------------------------------------------------------
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
domain_execution = WorkflowNodeExecution(
|
||||
id=event.id,
|
||||
node_execution_id=event.id,
|
||||
workflow_id=execution.workflow_id,
|
||||
workflow_execution_id=execution.id_,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
index=self._next_node_sequence(),
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
metadata=metadata,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
|
||||
self._node_execution_cache[event.id] = domain_execution
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
snapshot = _NodeRuntimeSnapshot(
|
||||
node_id=event.node_id,
|
||||
title=event.node_title,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
self._node_snapshots[event.id] = snapshot
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
domain_execution.status = WorkflowNodeExecutionStatus.RETRY
|
||||
domain_execution.error = event.error
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.PAUSED,
|
||||
error=event.reason,
|
||||
update_outputs=False,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_execution_id(self) -> str:
|
||||
workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
|
||||
if not workflow_execution_id:
|
||||
raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
|
||||
return str(workflow_execution_id)
|
||||
|
||||
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for field_name, value in self._system_variables().items():
|
||||
if field_name == SystemVariableKey.CONVERSATION_ID.value:
|
||||
# Conversation IDs are tied to the current session; omit them so persisted
|
||||
# workflow inputs stay reusable without binding future runs to this conversation.
|
||||
continue
|
||||
inputs[f"sys.{field_name}"] = value
|
||||
handled = WorkflowEntry.handle_special_values(inputs)
|
||||
return handled or {}
|
||||
|
||||
def _get_workflow_execution(self) -> WorkflowExecution:
|
||||
if self._workflow_execution is None:
|
||||
raise ValueError("workflow execution not initialized")
|
||||
return self._workflow_execution
|
||||
|
||||
def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
if node_execution_id not in self._node_execution_cache:
|
||||
raise ValueError(f"Node execution not found for id={node_execution_id}")
|
||||
return self._node_execution_cache[node_execution_id]
|
||||
|
||||
def _next_node_sequence(self) -> int:
|
||||
self._node_sequence += 1
|
||||
return self._node_sequence
|
||||
|
||||
def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
|
||||
if update_finished:
|
||||
execution.finished_at = naive_utc_now()
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
domain_execution: WorkflowNodeExecution,
|
||||
node_result: NodeRunResult,
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
*,
|
||||
error: str | None = None,
|
||||
update_outputs: bool = True,
|
||||
) -> None:
|
||||
finished_at = naive_utc_now()
|
||||
snapshot = self._node_snapshots.get(domain_execution.id)
|
||||
start_at = snapshot.created_at if snapshot else domain_execution.created_at
|
||||
domain_execution.status = status
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
|
||||
|
||||
if error:
|
||||
domain_execution.error = error
|
||||
|
||||
if update_outputs:
|
||||
domain_execution.update_from_mapping(
|
||||
inputs=node_result.inputs,
|
||||
process_data=node_result.process_data,
|
||||
outputs=node_result.outputs,
|
||||
metadata=node_result.metadata,
|
||||
)
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
|
||||
def _fail_running_node_executions(self, *, error_message: str) -> None:
|
||||
now = naive_utc_now()
|
||||
for execution in self._node_execution_cache.values():
|
||||
if execution.status == WorkflowNodeExecutionStatus.RUNNING:
|
||||
execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
execution.error = error_message
|
||||
execution.finished_at = now
|
||||
execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0)
|
||||
self._workflow_node_execution_repository.save(execution)
|
||||
|
||||
def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
|
||||
if not self._trace_manager:
|
||||
return
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
external_trace_id = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
|
||||
|
||||
trace_task = TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_execution=execution,
|
||||
conversation_id=conversation_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
)
|
||||
self._trace_manager.add_trace_task(trace_task)
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return {}
|
||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
@ -9,7 +9,7 @@ Supports stop, pause, and resume operations.
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ class GraphEngineManager:
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop, pause, and resume operations.
|
||||
Supports stop and pause operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@ -32,19 +32,29 @@ class GraphEngineManager:
|
||||
task_id: The task ID of the workflow to stop
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
GraphEngineManager._send_command(task_id, abort_command)
|
||||
|
||||
@staticmethod
|
||||
def send_pause_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
# Create Redis channel for this task
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
# Create and send abort command
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
|
||||
try:
|
||||
channel.send_command(abort_command)
|
||||
channel.send_command(command)
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
# The legacy stop flag mechanism will still work
|
||||
# The legacy control mechanisms will still work
|
||||
pass
|
||||
|
||||
@ -8,7 +8,12 @@ import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from ..event_management import EventManager
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
@ -28,6 +33,12 @@ class Dispatcher:
|
||||
with timeout and completion detection.
|
||||
"""
|
||||
|
||||
_COMMAND_TRIGGER_EVENTS = (
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
@ -76,22 +87,37 @@ class Dispatcher:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
# Check for commands
|
||||
self._execution_coordinator.check_commands()
|
||||
commands_checked = False
|
||||
should_check_commands = False
|
||||
should_break = False
|
||||
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
should_check_commands = True
|
||||
should_break = True
|
||||
else:
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Check if execution is complete
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
break
|
||||
# Process events
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
should_check_commands = self._should_check_commands(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Process commands even when no new events arrive so abort requests are not missed
|
||||
should_check_commands = True
|
||||
time.sleep(0.1)
|
||||
|
||||
if should_check_commands and not commands_checked:
|
||||
self._execution_coordinator.check_commands()
|
||||
commands_checked = True
|
||||
|
||||
if should_break:
|
||||
if not commands_checked:
|
||||
self._execution_coordinator.check_commands()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
@ -102,3 +128,7 @@ class Dispatcher:
|
||||
# Signal the event emitter that execution is complete
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
|
||||
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
|
||||
"""Return True if the event represents a node completion."""
|
||||
return isinstance(event, self._COMMAND_TRIGGER_EVENTS)
|
||||
|
||||
@ -2,17 +2,13 @@
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, final
|
||||
from typing import final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..event_management import EventManager
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
@ -27,8 +23,6 @@ class ExecutionCoordinator:
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
state_manager: GraphStateManager,
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: WorkerPool,
|
||||
) -> None:
|
||||
@ -38,15 +32,11 @@ class ExecutionCoordinator:
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
state_manager: Unified state manager
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event manager for collecting events
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
@ -65,15 +55,24 @@ class ExecutionCoordinator:
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
# Check if aborted or failed
|
||||
# Treat paused, aborted, or failed executions as terminal states
|
||||
if self._graph_execution.is_paused:
|
||||
return True
|
||||
|
||||
if self._graph_execution.aborted or self._graph_execution.has_error:
|
||||
return True
|
||||
|
||||
# Complete if no work remains
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Expose whether the underlying graph execution is paused."""
|
||||
return self._graph_execution.is_paused
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if self._graph_execution.is_paused:
|
||||
return
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
@ -85,3 +84,21 @@ class ExecutionCoordinator:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self._graph_execution.fail(error)
|
||||
|
||||
def handle_pause_if_needed(self) -> None:
|
||||
"""If the execution has been paused, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.is_paused:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
def handle_abort_if_needed(self) -> None:
|
||||
"""If the execution has been aborted, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.aborted:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
@ -14,11 +14,11 @@ from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
@ -13,6 +13,7 @@ from .graph import (
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
@ -37,6 +38,7 @@ from .loop import (
|
||||
from .node import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
@ -51,6 +53,7 @@ __all__ = [
|
||||
"GraphRunAbortedEvent",
|
||||
"GraphRunFailedEvent",
|
||||
"GraphRunPartialSucceededEvent",
|
||||
"GraphRunPausedEvent",
|
||||
"GraphRunStartedEvent",
|
||||
"GraphRunSucceededEvent",
|
||||
"NodeRunAgentLogEvent",
|
||||
@ -64,6 +67,7 @@ __all__ = [
|
||||
"NodeRunLoopNextEvent",
|
||||
"NodeRunLoopStartedEvent",
|
||||
"NodeRunLoopSucceededEvent",
|
||||
"NodeRunPauseRequestedEvent",
|
||||
"NodeRunRetrieverResourceEvent",
|
||||
"NodeRunRetryEvent",
|
||||
"NodeRunStartedEvent",
|
||||
|
||||
@ -8,7 +8,12 @@ class GraphRunStartedEvent(BaseGraphEvent):
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
outputs: dict[str, object] = Field(default_factory=dict)
|
||||
"""Event emitted when a run completes successfully with final outputs."""
|
||||
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Final workflow outputs keyed by output selector.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
@ -17,12 +22,30 @@ class GraphRunFailedEvent(BaseGraphEvent):
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
"""Event emitted when a run finishes with partial success and failures."""
|
||||
|
||||
exceptions_count: int = Field(..., description="exception count")
|
||||
outputs: dict[str, object] = Field(default_factory=dict)
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs that were materialised before failures occurred.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunAbortedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is aborted by user command."""
|
||||
|
||||
reason: str | None = Field(default=None, description="reason for abort")
|
||||
outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs produced before the abort was requested.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunPausedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is paused by user command."""
|
||||
|
||||
reason: str | None = Field(default=None, description="reason for pause")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs available to the client while the run is paused.",
|
||||
)
|
||||
|
||||
@ -51,3 +51,7 @@ class NodeRunExceptionEvent(GraphNodeEventBase):
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
|
||||
|
||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||
reason: str | None = Field(default=None, description="Optional pause reason")
|
||||
|
||||
@ -14,6 +14,7 @@ from .loop import (
|
||||
)
|
||||
from .node import (
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
StreamChunkEvent,
|
||||
@ -33,6 +34,7 @@ __all__ = [
|
||||
"ModelInvokeCompletedEvent",
|
||||
"NodeEventBase",
|
||||
"NodeRunResult",
|
||||
"PauseRequestedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunRetryEvent",
|
||||
"StreamChunkEvent",
|
||||
|
||||
@ -40,3 +40,7 @@ class StreamChunkEvent(NodeEventBase):
|
||||
|
||||
class StreamCompletedEvent(NodeEventBase):
|
||||
node_run_result: NodeRunResult = Field(..., description="run result")
|
||||
|
||||
|
||||
class PauseRequestedEvent(NodeEventBase):
|
||||
reason: str | None = Field(default=None, description="Optional pause reason")
|
||||
|
||||
@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import (
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, ClassVar
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
@ -20,6 +20,7 @@ from core.workflow.graph_events import (
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
@ -37,10 +38,12 @@ from core.workflow.node_events import (
|
||||
LoopSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
|
||||
@ -385,6 +388,16 @@ class Node:
|
||||
f"Node {self._node_id} does not support status {event.node_run_result.status}"
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
|
||||
return NodeRunPauseRequestedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
|
||||
reason=event.reason,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
return NodeRunAgentLogEvent(
|
||||
|
||||
@ -19,7 +19,6 @@ from core.file.enums import FileTransferMethod, FileType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
@ -27,6 +26,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
|
||||
@ -15,7 +15,7 @@ from core.file import file_manager
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
|
||||
3
api/core/workflow/nodes/human_input/__init__.py
Normal file
3
api/core/workflow/nodes/human_input/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .human_input_node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode"]
|
||||
10
api/core/workflow/nodes/human_input/entities.py
Normal file
10
api/core/workflow/nodes/human_input/entities.py
Normal file
@ -0,0 +1,10 @@
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Configuration schema for the HumanInput node."""
|
||||
|
||||
required_variables: list[str] = Field(default_factory=list)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
132
api/core/workflow/nodes/human_input/human_input_node.py
Normal file
132
api/core/workflow/nodes/human_input/human_input_node.py
Normal file
@ -0,0 +1,132 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
|
||||
|
||||
class HumanInputNode(Node):
|
||||
node_type = NodeType.HUMAN_INPUT
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_BRANCH_SELECTION_KEYS: tuple[str, ...] = (
|
||||
"edge_source_handle",
|
||||
"edgeSourceHandle",
|
||||
"source_handle",
|
||||
"selected_branch",
|
||||
"selectedBranch",
|
||||
"branch",
|
||||
"branch_id",
|
||||
"branchId",
|
||||
"handle",
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = HumanInputNodeData(**data)
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def _run(self): # type: ignore[override]
|
||||
if self._is_completion_ready():
|
||||
branch_handle = self._resolve_branch_selection()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
edge_source_handle=branch_handle or "source",
|
||||
)
|
||||
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
yield PauseRequestedEvent(reason=self._node_data.pause_reason)
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
if not self._node_data.required_variables:
|
||||
return False
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for selector_str in self._node_data.required_variables:
|
||||
parts = selector_str.split(".")
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
segment = variable_pool.get(parts)
|
||||
if segment is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_branch_selection(self) -> str | None:
|
||||
"""Determine the branch handle selected by human input if available."""
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for key in self._BRANCH_SELECTION_KEYS:
|
||||
handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
|
||||
if handle:
|
||||
return handle
|
||||
|
||||
default_values = self._node_data.default_value_dict
|
||||
for key in self._BRANCH_SELECTION_KEYS:
|
||||
handle = self._normalize_branch_value(default_values.get(key))
|
||||
if handle:
|
||||
return handle
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_branch_handle(segment: Any) -> str | None:
|
||||
if segment is None:
|
||||
return None
|
||||
|
||||
candidate = getattr(segment, "to_object", None)
|
||||
raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
|
||||
if raw_value is None:
|
||||
return None
|
||||
|
||||
return HumanInputNode._normalize_branch_value(raw_value)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_branch_value(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
|
||||
candidate = value.get(key)
|
||||
if isinstance(candidate, str) and candidate:
|
||||
return candidate
|
||||
|
||||
return None
|
||||
@ -3,12 +3,12 @@ from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
@ -38,6 +37,7 @@ from core.workflow.node_events import (
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
@ -557,11 +557,12 @@ class IterationNode(Node):
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
@ -9,13 +9,13 @@ from sqlalchemy import func, select
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ from .exc import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -15,9 +15,9 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import Conversation
|
||||
|
||||
@ -52,7 +52,7 @@ from core.variables import (
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
@ -71,6 +71,7 @@ from core.workflow.node_events import (
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from . import llm_utils
|
||||
from .entities import (
|
||||
@ -93,7 +94,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -406,11 +406,12 @@ class LoopNode(Node):
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
@ -10,7 +10,8 @@ from libs.typing import is_str, is_str_dict
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
@final
|
||||
|
||||
@ -9,6 +9,7 @@ from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.if_else import IfElseNode
|
||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
@ -134,6 +135,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
|
||||
"2": AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
NodeType.HUMAN_INPUT: {
|
||||
LATEST_VERSION: HumanInputNode,
|
||||
"1": HumanInputNode,
|
||||
},
|
||||
NodeType.DATASOURCE: {
|
||||
LATEST_VERSION: DatasourceNode,
|
||||
"1": DatasourceNode,
|
||||
|
||||
@ -27,13 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.llm import ModelConfig, llm_utils
|
||||
from core.workflow.runtime import VariablePool
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@ -40,7 +41,7 @@ from .template_prompts import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class QuestionClassifierNode(Node):
|
||||
@ -194,6 +195,8 @@ class QuestionClassifierNode(Node):
|
||||
|
||||
category_name = node_data.classes[0].name
|
||||
category_id = node_data.classes[0].id
|
||||
if "<think>" in result_text:
|
||||
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||
if "category_name" in result_text_json and "category_id" in result_text_json:
|
||||
|
||||
@ -36,7 +36,7 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
class ToolNode(Node):
|
||||
|
||||
@ -18,7 +18,7 @@ from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
14
api/core/workflow/runtime/__init__.py
Normal file
14
api/core/workflow/runtime/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
|
||||
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
|
||||
from .variable_pool import VariablePool, VariableValue
|
||||
|
||||
__all__ = [
|
||||
"GraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeStateWrapper",
|
||||
"ReadOnlyVariablePool",
|
||||
"ReadOnlyVariablePoolWrapper",
|
||||
"VariablePool",
|
||||
"VariableValue",
|
||||
]
|
||||
393
api/core/workflow/runtime/graph_runtime_state.py
Normal file
393
api/core/workflow/runtime/graph_runtime_state.py
Normal file
@ -0,0 +1,393 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping as TypingMapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
|
||||
class ReadyQueueProtocol(Protocol):
|
||||
"""Structural interface required from ready queue implementations."""
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""Enqueue the identifier of a node that is ready to run."""
|
||||
...
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""Return the next node identifier, blocking until available or timeout expires."""
|
||||
...
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""Signal that the most recently dequeued node has completed processing."""
|
||||
...
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Return True when the queue contains no pending nodes."""
|
||||
...
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Approximate the number of pending nodes awaiting execution."""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the queue contents for persistence."""
|
||||
...
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore the queue contents from a serialized payload."""
|
||||
...
|
||||
|
||||
|
||||
class GraphExecutionProtocol(Protocol):
|
||||
"""Structural interface for graph execution aggregate."""
|
||||
|
||||
workflow_id: str
|
||||
started: bool
|
||||
completed: bool
|
||||
aborted: bool
|
||||
error: Exception | None
|
||||
exceptions_count: int
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
...
|
||||
|
||||
def complete(self) -> None:
|
||||
"""Mark execution as successfully completed."""
|
||||
...
|
||||
|
||||
def abort(self, reason: str) -> None:
|
||||
"""Abort execution in response to an external stop request."""
|
||||
...
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Record an unrecoverable error and end execution."""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize execution state into a JSON payload."""
|
||||
...
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore execution state from a previously serialized payload."""
|
||||
...
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorProtocol(Protocol):
|
||||
"""Structural interface for response stream coordinator."""
|
||||
|
||||
def register(self, response_node_id: str) -> None:
|
||||
"""Register a response node so its outputs can be streamed."""
|
||||
...
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore coordinator state from a serialized payload."""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize coordinator state for persistence."""
|
||||
...
|
||||
|
||||
|
||||
class GraphProtocol(Protocol):
|
||||
"""Structural interface required from graph instances attached to the runtime state."""
|
||||
|
||||
nodes: TypingMapping[str, object]
|
||||
edges: TypingMapping[str, object]
|
||||
root_node: object
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||
|
||||
|
||||
class GraphRuntimeState:
|
||||
"""Mutable runtime state shared across graph execution components."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
start_at: float,
|
||||
total_tokens: int = 0,
|
||||
llm_usage: LLMUsage | None = None,
|
||||
outputs: dict[str, object] | None = None,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue: ReadyQueueProtocol | None = None,
|
||||
graph_execution: GraphExecutionProtocol | None = None,
|
||||
response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
|
||||
graph: GraphProtocol | None = None,
|
||||
) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
self._start_at = start_at
|
||||
|
||||
if total_tokens < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = total_tokens
|
||||
|
||||
self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
|
||||
self._outputs = deepcopy(outputs) if outputs is not None else {}
|
||||
|
||||
if node_run_steps < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = node_run_steps
|
||||
|
||||
self._graph: GraphProtocol | None = None
|
||||
|
||||
self._ready_queue = ready_queue
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Context binding helpers
|
||||
# ------------------------------------------------------------------
|
||||
def attach_graph(self, graph: GraphProtocol) -> None:
|
||||
"""Attach the materialized graph to the runtime state."""
|
||||
if self._graph is not None and self._graph is not graph:
|
||||
raise ValueError("GraphRuntimeState already attached to a different graph instance")
|
||||
|
||||
self._graph = graph
|
||||
|
||||
if self._response_coordinator is None:
|
||||
self._response_coordinator = self._build_response_coordinator(graph)
|
||||
|
||||
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
|
||||
self._response_coordinator.loads(self._pending_response_coordinator_dump)
|
||||
self._pending_response_coordinator_dump = None
|
||||
|
||||
def configure(self, *, graph: GraphProtocol | None = None) -> None:
|
||||
"""Ensure core collaborators are initialized with the provided context."""
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
# Ensure collaborators are instantiated
|
||||
_ = self.ready_queue
|
||||
_ = self.graph_execution
|
||||
if self._graph is not None:
|
||||
_ = self.response_coordinator
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Primary collaborators
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def variable_pool(self) -> VariablePool:
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def ready_queue(self) -> ReadyQueueProtocol:
|
||||
if self._ready_queue is None:
|
||||
self._ready_queue = self._build_ready_queue()
|
||||
return self._ready_queue
|
||||
|
||||
@property
|
||||
def graph_execution(self) -> GraphExecutionProtocol:
|
||||
if self._graph_execution is None:
|
||||
self._graph_execution = self._build_graph_execution()
|
||||
return self._graph_execution
|
||||
|
||||
@property
|
||||
def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
|
||||
if self._response_coordinator is None:
|
||||
if self._graph is None:
|
||||
raise ValueError("Graph must be attached before accessing response coordinator")
|
||||
self._response_coordinator = self._build_response_coordinator(self._graph)
|
||||
return self._response_coordinator
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Scalar state
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
return self._start_at
|
||||
|
||||
@start_at.setter
|
||||
def start_at(self, value: float) -> None:
|
||||
self._start_at = value
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self._total_tokens
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, value: int) -> None:
|
||||
if value < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = value
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
return self._llm_usage.model_copy()
|
||||
|
||||
@llm_usage.setter
|
||||
def llm_usage(self, value: LLMUsage) -> None:
|
||||
self._llm_usage = value.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
return deepcopy(self._outputs)
|
||||
|
||||
@outputs.setter
|
||||
def outputs(self, value: dict[str, Any]) -> None:
|
||||
self._outputs = deepcopy(value)
|
||||
|
||||
def set_output(self, key: str, value: object) -> None:
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
def get_output(self, key: str, default: object = None) -> object:
|
||||
return deepcopy(self._outputs.get(key, default))
|
||||
|
||||
def update_outputs(self, updates: dict[str, object]) -> None:
|
||||
for key, value in updates.items():
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
return self._node_run_steps
|
||||
|
||||
@node_run_steps.setter
|
||||
def node_run_steps(self, value: int) -> None:
|
||||
if value < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = value
|
||||
|
||||
def increment_node_run_steps(self) -> None:
|
||||
self._node_run_steps += 1
|
||||
|
||||
def add_tokens(self, tokens: int) -> None:
|
||||
if tokens < 0:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ------------------------------------------------------------------
|
||||
def dumps(self) -> str:
|
||||
"""Serialize runtime state into a JSON string."""
|
||||
|
||||
snapshot: dict[str, Any] = {
|
||||
"version": "1.0",
|
||||
"start_at": self._start_at,
|
||||
"total_tokens": self._total_tokens,
|
||||
"node_run_steps": self._node_run_steps,
|
||||
"llm_usage": self._llm_usage.model_dump(mode="json"),
|
||||
"outputs": self.outputs,
|
||||
"variable_pool": self.variable_pool.model_dump(mode="json"),
|
||||
"ready_queue": self.ready_queue.dumps(),
|
||||
"graph_execution": self.graph_execution.dumps(),
|
||||
"paused_nodes": list(self._paused_nodes),
|
||||
}
|
||||
|
||||
if self._response_coordinator is not None and self._graph is not None:
|
||||
snapshot["response_coordinator"] = self._response_coordinator.dumps()
|
||||
|
||||
return json.dumps(snapshot, default=pydantic_encoder)
|
||||
|
||||
def loads(self, data: str | Mapping[str, Any]) -> None:
|
||||
"""Restore runtime state from a serialized snapshot."""
|
||||
|
||||
payload: dict[str, Any]
|
||||
if isinstance(data, str):
|
||||
payload = json.loads(data)
|
||||
else:
|
||||
payload = dict(data)
|
||||
|
||||
version = payload.get("version")
|
||||
if version != "1.0":
|
||||
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
|
||||
|
||||
self._start_at = float(payload.get("start_at", 0.0))
|
||||
total_tokens = int(payload.get("total_tokens", 0))
|
||||
if total_tokens < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = total_tokens
|
||||
|
||||
node_run_steps = int(payload.get("node_run_steps", 0))
|
||||
if node_run_steps < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = node_run_steps
|
||||
|
||||
llm_usage_payload = payload.get("llm_usage", {})
|
||||
self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
|
||||
|
||||
self._outputs = deepcopy(payload.get("outputs", {}))
|
||||
|
||||
variable_pool_payload = payload.get("variable_pool")
|
||||
if variable_pool_payload is not None:
|
||||
self._variable_pool = VariablePool.model_validate(variable_pool_payload)
|
||||
|
||||
ready_queue_payload = payload.get("ready_queue")
|
||||
if ready_queue_payload is not None:
|
||||
self._ready_queue = self._build_ready_queue()
|
||||
self._ready_queue.loads(ready_queue_payload)
|
||||
else:
|
||||
self._ready_queue = None
|
||||
|
||||
graph_execution_payload = payload.get("graph_execution")
|
||||
self._graph_execution = None
|
||||
self._pending_graph_execution_workflow_id = None
|
||||
if graph_execution_payload is not None:
|
||||
try:
|
||||
execution_payload = json.loads(graph_execution_payload)
|
||||
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
self._pending_graph_execution_workflow_id = None
|
||||
self.graph_execution.loads(graph_execution_payload)
|
||||
|
||||
response_payload = payload.get("response_coordinator")
|
||||
if response_payload is not None:
|
||||
if self._graph is not None:
|
||||
self.response_coordinator.loads(response_payload)
|
||||
else:
|
||||
self._pending_response_coordinator_dump = response_payload
|
||||
else:
|
||||
self._pending_response_coordinator_dump = None
|
||||
self._response_coordinator = None
|
||||
|
||||
paused_nodes_payload = payload.get("paused_nodes", [])
|
||||
self._paused_nodes = set(map(str, paused_nodes_payload))
|
||||
|
||||
def register_paused_node(self, node_id: str) -> None:
|
||||
"""Record a node that should resume when execution is continued."""
|
||||
|
||||
self._paused_nodes.add(node_id)
|
||||
|
||||
def consume_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear the list of paused nodes awaiting resume."""
|
||||
|
||||
nodes = list(self._paused_nodes)
|
||||
self._paused_nodes.clear()
|
||||
return nodes
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Builders
|
||||
# ------------------------------------------------------------------
|
||||
def _build_ready_queue(self) -> ReadyQueueProtocol:
|
||||
# Import lazily to avoid breaching architecture boundaries enforced by import-linter.
|
||||
module = importlib.import_module("core.workflow.graph_engine.ready_queue")
|
||||
in_memory_cls = module.InMemoryReadyQueue
|
||||
return in_memory_cls()
|
||||
|
||||
def _build_graph_execution(self) -> GraphExecutionProtocol:
|
||||
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
|
||||
module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution")
|
||||
graph_execution_cls = module.GraphExecution
|
||||
workflow_id = self._pending_graph_execution_workflow_id or ""
|
||||
self._pending_graph_execution_workflow_id = None
|
||||
return graph_execution_cls(workflow_id=workflow_id)
|
||||
|
||||
def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
|
||||
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
|
||||
module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
|
||||
coordinator_cls = module.ResponseStreamCoordinator
|
||||
return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
|
||||
@ -16,6 +16,10 @@ class ReadOnlyVariablePool(Protocol):
|
||||
"""Get all variables for a node (read-only)."""
|
||||
...
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
|
||||
"""Get all variables stored under a given node prefix (read-only)."""
|
||||
...
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeState(Protocol):
|
||||
"""
|
||||
@ -56,6 +60,20 @@ class ReadOnlyGraphRuntimeState(Protocol):
|
||||
"""Get the node run steps count (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def ready_queue_size(self) -> int:
|
||||
"""Get the number of nodes currently in the ready queue."""
|
||||
...
|
||||
|
||||
@property
|
||||
def exceptions_count(self) -> int:
|
||||
"""Get the number of node execution exceptions recorded."""
|
||||
...
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the runtime state into a JSON snapshot (read-only)."""
|
||||
...
|
||||
@ -1,77 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .variable_pool import VariablePool
|
||||
|
||||
|
||||
class ReadOnlyVariablePoolWrapper:
|
||||
"""Wrapper that provides read-only access to VariablePool."""
|
||||
"""Provide defensive, read-only access to ``VariablePool``."""
|
||||
|
||||
def __init__(self, variable_pool: VariablePool):
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
"""Get a variable value (returns a defensive copy)."""
|
||||
"""Return a copy of a variable value if present."""
|
||||
value = self._variable_pool.get([node_id, variable_key])
|
||||
return deepcopy(value) if value is not None else None
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
|
||||
"""Get all variables for a node (returns defensive copies)."""
|
||||
"""Return a copy of all variables for the specified node."""
|
||||
variables: dict[str, object] = {}
|
||||
if node_id in self._variable_pool.variable_dictionary:
|
||||
for key, var in self._variable_pool.variable_dictionary[node_id].items():
|
||||
# Variables have a value property that contains the actual data
|
||||
variables[key] = deepcopy(var.value)
|
||||
for key, variable in self._variable_pool.variable_dictionary[node_id].items():
|
||||
variables[key] = deepcopy(variable.value)
|
||||
return variables
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
|
||||
"""Return a copy of all variables stored under the given prefix."""
|
||||
return self._variable_pool.get_by_prefix(prefix)
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeStateWrapper:
|
||||
"""
|
||||
Wrapper that provides read-only access to GraphRuntimeState.
|
||||
"""Expose a defensive, read-only view of ``GraphRuntimeState``."""
|
||||
|
||||
This wrapper ensures that layers can observe the state without
|
||||
modifying it. All returned values are defensive copies.
|
||||
"""
|
||||
|
||||
def __init__(self, state: GraphRuntimeState):
|
||||
def __init__(self, state: GraphRuntimeState) -> None:
|
||||
self._state = state
|
||||
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
|
||||
"""Get read-only access to the variable pool."""
|
||||
return self._variable_pool_wrapper
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time (read-only)."""
|
||||
return self._state.start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count (read-only)."""
|
||||
return self._state.total_tokens
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get a copy of LLM usage info (read-only)."""
|
||||
# Return a copy to prevent modification
|
||||
return self._state.llm_usage.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a defensive copy of outputs (read-only)."""
|
||||
return deepcopy(self._state.outputs)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count (read-only)."""
|
||||
return self._state.node_run_steps
|
||||
|
||||
@property
|
||||
def ready_queue_size(self) -> int:
|
||||
return self._state.ready_queue.qsize()
|
||||
|
||||
@property
|
||||
def exceptions_count(self) -> int:
|
||||
return self._state.graph_execution.exceptions_count
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
return self._state.get_output(key, default)
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the underlying runtime state for external persistence."""
|
||||
return self._state.dumps()
|
||||
@ -1,6 +1,7 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Annotated, Any, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@ -235,6 +236,20 @@ class VariablePool(BaseModel):
|
||||
return segment
|
||||
return None
|
||||
|
||||
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
|
||||
"""Return a copy of all variables stored under the given node prefix."""
|
||||
|
||||
nodes = self.variable_dictionary.get(prefix)
|
||||
if not nodes:
|
||||
return {}
|
||||
|
||||
result: dict[str, object] = {}
|
||||
for key, variable in nodes.items():
|
||||
value = variable.value
|
||||
result[key] = deepcopy(value)
|
||||
|
||||
return result
|
||||
|
||||
def _add_system_variables(self, system_variable: SystemVariable):
|
||||
sys_var_mapping = system_variable.to_dict()
|
||||
for key, value in sys_var_mapping.items():
|
||||
@ -5,7 +5,7 @@ from typing import Literal, NamedTuple
|
||||
from core.file import FileAttribute, file_manager
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .entities import Condition, SubCondition, SupportedComparisonOperator
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
class VariableLoader(Protocol):
|
||||
|
||||
@ -1,459 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.entities import (
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.enums import (
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
|
||||
@dataclass
|
||||
class CycleManagerWorkflowInfo:
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
version: str
|
||||
graph_data: Mapping[str, Any]
|
||||
|
||||
|
||||
class WorkflowCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: SystemVariable,
|
||||
workflow_info: CycleManagerWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
# Initialize caches for workflow execution cycle
|
||||
# These caches avoid redundant repository calls during a single workflow execution
|
||||
self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
|
||||
def handle_workflow_run_start(self) -> WorkflowExecution:
|
||||
inputs = self._prepare_workflow_inputs()
|
||||
execution_id = self._get_or_generate_execution_id()
|
||||
|
||||
execution = WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id=self._workflow_info.workflow_id,
|
||||
workflow_type=self._workflow_info.workflow_type,
|
||||
workflow_version=self._workflow_info.version,
|
||||
graph=self._workflow_info.graph_data,
|
||||
inputs=inputs,
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
return self._save_and_cache_workflow_execution(execution)
|
||||
|
||||
def handle_workflow_run_success(
|
||||
self,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
self._update_workflow_execution_completion(
|
||||
workflow_execution,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
total_tokens=total_tokens,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
|
||||
|
||||
self._workflow_execution_repository.save(workflow_execution)
|
||||
return workflow_execution
|
||||
|
||||
def handle_workflow_run_partial_success(
|
||||
self,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
exceptions_count: int = 0,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
self._update_workflow_execution_completion(
|
||||
execution,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
outputs=outputs,
|
||||
total_tokens=total_tokens,
|
||||
total_steps=total_steps,
|
||||
exceptions_count=exceptions_count,
|
||||
)
|
||||
|
||||
self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
return execution
|
||||
|
||||
def handle_workflow_run_failed(
|
||||
self,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowExecutionStatus,
|
||||
error_message: str,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
exceptions_count: int = 0,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
now = naive_utc_now()
|
||||
|
||||
self._update_workflow_execution_completion(
|
||||
workflow_execution,
|
||||
status=status,
|
||||
total_tokens=total_tokens,
|
||||
total_steps=total_steps,
|
||||
error_message=error_message,
|
||||
exceptions_count=exceptions_count,
|
||||
finished_at=now,
|
||||
)
|
||||
|
||||
self._fail_running_node_executions(workflow_execution.id_, error_message, now)
|
||||
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
|
||||
|
||||
self._workflow_execution_repository.save(workflow_execution)
|
||||
return workflow_execution
|
||||
|
||||
def handle_node_execution_start(
|
||||
self,
|
||||
*,
|
||||
workflow_execution_id: str,
|
||||
event: QueueNodeStartedEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
|
||||
|
||||
domain_execution = self._create_node_execution_from_event(
|
||||
workflow_execution=workflow_execution,
|
||||
event=event,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
)
|
||||
|
||||
return self._save_and_cache_node_execution(domain_execution)
|
||||
|
||||
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
|
||||
|
||||
self._update_node_execution_completion(
|
||||
domain_execution,
|
||||
event=event,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
)
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
return domain_execution
|
||||
|
||||
def handle_workflow_node_execution_failed(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
|
||||
|
||||
status = (
|
||||
WorkflowNodeExecutionStatus.EXCEPTION
|
||||
if isinstance(event, QueueNodeExceptionEvent)
|
||||
else WorkflowNodeExecutionStatus.FAILED
|
||||
)
|
||||
|
||||
self._update_node_execution_completion(
|
||||
domain_execution,
|
||||
event=event,
|
||||
status=status,
|
||||
error=event.error,
|
||||
handle_special_values=True,
|
||||
)
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
return domain_execution
|
||||
|
||||
def handle_workflow_node_execution_retried(
|
||||
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
|
||||
|
||||
domain_execution = self._create_node_execution_from_event(
|
||||
workflow_execution=workflow_execution,
|
||||
event=event,
|
||||
status=WorkflowNodeExecutionStatus.RETRY,
|
||||
error=event.error,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
|
||||
# Handle inputs and outputs
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = event.outputs
|
||||
metadata = self._merge_event_metadata(event)
|
||||
|
||||
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
|
||||
|
||||
execution = self._save_and_cache_node_execution(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(execution)
|
||||
return execution
|
||||
|
||||
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
|
||||
# Check cache first
|
||||
if id in self._workflow_execution_cache:
|
||||
return self._workflow_execution_cache[id]
|
||||
|
||||
raise WorkflowRunNotFoundError(id)
|
||||
|
||||
def _prepare_workflow_inputs(self) -> dict[str, Any]:
|
||||
"""Prepare workflow inputs by merging application inputs with system variables."""
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
|
||||
if self._workflow_system_variables:
|
||||
for field_name, value in self._workflow_system_variables.to_dict().items():
|
||||
if field_name != SystemVariableKey.CONVERSATION_ID:
|
||||
inputs[f"sys.{field_name}"] = value
|
||||
|
||||
return dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
def _get_or_generate_execution_id(self) -> str:
|
||||
"""Get execution ID from system variables or generate a new one."""
|
||||
if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
|
||||
return str(self._workflow_system_variables.workflow_execution_id)
|
||||
return str(uuidv7())
|
||||
|
||||
def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
|
||||
"""Save workflow execution to repository and cache it."""
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._workflow_execution_cache[execution.id_] = execution
|
||||
return execution
|
||||
|
||||
def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
|
||||
"""Save node execution to repository and cache it if it has an ID.
|
||||
|
||||
This does not persist the `inputs` / `process_data` / `outputs` fields of the execution model.
|
||||
"""
|
||||
self._workflow_node_execution_repository.save(execution)
|
||||
if execution.node_execution_id:
|
||||
self._node_execution_cache[execution.node_execution_id] = execution
|
||||
return execution
|
||||
|
||||
def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
"""Get node execution from cache or raise error if not found."""
|
||||
domain_execution = self._node_execution_cache.get(node_execution_id)
|
||||
if not domain_execution:
|
||||
raise ValueError(f"Domain node execution not found: {node_execution_id}")
|
||||
return domain_execution
|
||||
|
||||
def _update_workflow_execution_completion(
|
||||
self,
|
||||
execution: WorkflowExecution,
|
||||
*,
|
||||
status: WorkflowExecutionStatus,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
error_message: str | None = None,
|
||||
exceptions_count: int = 0,
|
||||
finished_at: datetime | None = None,
|
||||
):
|
||||
"""Update workflow execution with completion data."""
|
||||
execution.status = status
|
||||
execution.outputs = outputs or {}
|
||||
execution.total_tokens = total_tokens
|
||||
execution.total_steps = total_steps
|
||||
execution.finished_at = finished_at or naive_utc_now()
|
||||
execution.exceptions_count = exceptions_count
|
||||
if error_message:
|
||||
execution.error_message = error_message
|
||||
|
||||
def _add_trace_task_if_needed(
|
||||
self,
|
||||
trace_manager: TraceQueueManager | None,
|
||||
workflow_execution: WorkflowExecution,
|
||||
conversation_id: str | None,
|
||||
external_trace_id: str | None,
|
||||
):
|
||||
"""Add trace task if trace manager is provided."""
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_execution=workflow_execution,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
)
|
||||
)
|
||||
|
||||
def _fail_running_node_executions(
|
||||
self,
|
||||
workflow_execution_id: str,
|
||||
error_message: str,
|
||||
now: datetime,
|
||||
):
|
||||
"""Fail all running node executions for a workflow."""
|
||||
running_node_executions = [
|
||||
node_exec
|
||||
for node_exec in self._node_execution_cache.values()
|
||||
if node_exec.workflow_execution_id == workflow_execution_id
|
||||
and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
]
|
||||
|
||||
for node_execution in running_node_executions:
|
||||
if node_execution.node_execution_id:
|
||||
node_execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
node_execution.error = error_message
|
||||
node_execution.finished_at = now
|
||||
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
|
||||
self._workflow_node_execution_repository.save(node_execution)
|
||||
|
||||
def _create_node_execution_from_event(
|
||||
self,
|
||||
*,
|
||||
workflow_execution: WorkflowExecution,
|
||||
event: QueueNodeStartedEvent,
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
error: str | None = None,
|
||||
created_at: datetime | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a node execution from an event."""
|
||||
now = naive_utc_now()
|
||||
created_at = created_at or now
|
||||
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
domain_execution = WorkflowNodeExecution(
|
||||
id=event.node_execution_id,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
workflow_execution_id=workflow_execution.id_,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
index=event.node_run_index,
|
||||
node_execution_id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
status=status,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
error=error,
|
||||
)
|
||||
|
||||
if status == WorkflowNodeExecutionStatus.RETRY:
|
||||
domain_execution.finished_at = now
|
||||
domain_execution.elapsed_time = (now - created_at).total_seconds()
|
||||
|
||||
return domain_execution
|
||||
|
||||
def _update_node_execution_completion(
|
||||
self,
|
||||
domain_execution: WorkflowNodeExecution,
|
||||
*,
|
||||
event: Union[
|
||||
QueueNodeSucceededEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
],
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
error: str | None = None,
|
||||
handle_special_values: bool = False,
|
||||
):
|
||||
"""Update node execution with completion data."""
|
||||
finished_at = naive_utc_now()
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
# Process data
|
||||
if handle_special_values:
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
else:
|
||||
inputs = event.inputs
|
||||
process_data = event.process_data
|
||||
|
||||
outputs = event.outputs
|
||||
|
||||
# Convert metadata
|
||||
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
if event.execution_metadata:
|
||||
execution_metadata_dict.update(event.execution_metadata)
|
||||
|
||||
# Update domain model
|
||||
domain_execution.status = status
|
||||
domain_execution.update_from_mapping(
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
metadata=execution_metadata_dict,
|
||||
)
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = elapsed_time
|
||||
|
||||
if error:
|
||||
domain_execution.error = error
|
||||
|
||||
def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
|
||||
"""Merge event metadata with origin metadata."""
|
||||
origin_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
|
||||
if event.execution_metadata:
|
||||
execution_metadata_dict.update(event.execution_metadata)
|
||||
|
||||
return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
|
||||
@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
@ -20,6 +20,7 @@ from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, Gra
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from factories import file_factory
|
||||
|
||||
Reference in New Issue
Block a user